| | import torch |
| | from typing import List |
| | from tqdm import tqdm |
| |
|
| | from .llm_iface import LLM |
| | from .utils import dbg |
| |
|
# Generic English nouns used as the default contrast set: `get_concept_vector`
# averages the hidden states obtained for these words and subtracts that mean
# from the hidden state of the target concept.
BASELINE_WORDS = [
    "thing", "place", "idea", "person", "object",
    "time", "way", "day", "man", "world",
    "life", "hand", "part", "child", "eye",
    "woman", "fact", "group", "case", "point",
]
| |
|
@torch.no_grad()
def _get_last_token_hidden_state(llm: LLM, prompt: str) -> torch.Tensor:
    """Return the hidden state of the last token of ``prompt``.

    The prompt is tokenized as a single sequence and passed through
    ``llm.model`` with ``output_hidden_states=True``. The vector at the last
    sequence position of the last entry in ``outputs.hidden_states``
    (presumably the final layer's output; confirm against the model class)
    is moved to the CPU and returned.

    Args:
        llm: Wrapper exposing ``tokenizer``, ``model`` (with a ``device``
            attribute) and ``stable_config.hidden_dim``.
        prompt: Text to encode.

    Returns:
        A 1-D CPU tensor of shape ``(llm.stable_config.hidden_dim,)``.

    Raises:
        AssertionError: If the extracted vector does not have the expected
            shape. This is an internal sanity check and is skipped when
            Python runs with ``-O``.
    """
    # The decorator already disables autograd for the whole call, so the
    # forward pass needs no additional nested ``torch.no_grad()`` block.
    inputs = llm.tokenizer(prompt, return_tensors="pt").to(llm.model.device)
    outputs = llm.model(**inputs, output_hidden_states=True)

    # Batch index 0, sequence position -1 of the last hidden-state entry.
    # NOTE(review): position -1 is the final prompt token only if the
    # tokenizer does not append padding for a single prompt; confirm.
    last_hidden_state = outputs.hidden_states[-1][0, -1, :].cpu()

    expected_size = llm.stable_config.hidden_dim
    assert last_hidden_state.shape == (expected_size,), \
        f"Hidden state shape mismatch. Expected {(expected_size,)}, got {last_hidden_state.shape}"
    return last_hidden_state
| |
|
@torch.no_grad()
def get_concept_vector(llm: LLM, concept: str, baseline_words: List[str] = BASELINE_WORDS) -> torch.Tensor:
    """Extract a concept vector using the contrastive method.

    Every word is inserted into the same prompt template. The result is the
    last-token hidden state obtained for ``concept`` minus the element-wise
    mean of the last-token hidden states obtained for ``baseline_words``.

    Args:
        llm: Model wrapper forwarded to ``_get_last_token_hidden_state``.
        concept: Word or phrase whose concept vector is extracted.
        baseline_words: Words whose averaged hidden state serves as the
            contrast baseline. Defaults to ``BASELINE_WORDS``; the list is
            only read, never modified.

    Returns:
        A tensor with the same shape as the hidden state returned by
        ``_get_last_token_hidden_state``.

    Raises:
        ValueError: If ``baseline_words`` is empty.
        AssertionError: If a baseline hidden state differs in shape from the
            target hidden state, or if the resulting vector contains
            non-finite values. These are internal sanity checks and are
            skipped when Python runs with ``-O``.
    """
    if not baseline_words:
        # Without this guard an empty list would cost one forward pass and
        # then fail inside ``torch.stack`` with an unrelated-looking error.
        raise ValueError("baseline_words must contain at least one word.")

    dbg(f"Extracting contrastive concept vector for '{concept}'...")
    prompt_template = "Here is a sentence about the concept of {}."

    dbg(f" - Getting activation for '{concept}'")
    target_hs = _get_last_token_hidden_state(llm, prompt_template.format(concept))

    progress = tqdm(
        baseline_words,
        desc=f" - Calculating baseline for '{concept}'",
        leave=False,
        bar_format="{l_bar}{bar:10}{r_bar}",
    )
    baseline_hss = [
        _get_last_token_hidden_state(llm, prompt_template.format(word))
        for word in progress
    ]
    assert all(hs.shape == target_hs.shape for hs in baseline_hss)

    mean_baseline_hs = torch.stack(baseline_hss).mean(dim=0)
    dbg(f" - Mean baseline vector computed with norm {torch.norm(mean_baseline_hs).item():.2f}")

    # Contrastive step: subtract the averaged baseline from the target.
    concept_vector = target_hs - mean_baseline_hs
    norm = torch.norm(concept_vector).item()
    dbg(f"Concept vector for '{concept}' extracted with norm {norm:.2f}.")
    assert torch.isfinite(concept_vector).all()
    return concept_vector
| |
|