Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| from typing import Any | |
| import torch | |
| from app.analysis.sentence_split import split_sentences | |
| from app.analysis.summaries import ( | |
| compute_incoming_importance, | |
| compute_outgoing_importance, | |
| compute_top_edges, | |
| ) | |
| from app.analysis.suppression import compute_attribution_matrix | |
| from app.analysis.token_boundaries import tokenize_with_sentence_ranges, truncate_text_to_token_limit | |
| from app.analysis.validation import validate_top_edges | |
| from app.core.config import get_settings | |
| from app.core.runtime import load_model_bundle | |
| from app.core.schemas import AnalysisResult, GenerationResult | |
| from app.generation.service import generate_answer_and_trace | |
def compute_attribution_analysis(
    *,
    question: str,
    model_name: str | None = None,
    take_log: bool | None = None,
    max_sentences: int | None = None,
    max_trace_tokens: int | None = None,
    validate_top_k: int = 0,
    max_new_tokens: int = 512,
    temperature: float = 0.6,
    top_p: float = 0.95,
    device_preference: str | None = None,
    dtype_preference: str | None = None,
    attn_implementation: str | None = None,
    trust_remote_code: bool | None = None,
    low_cpu_mem_usage: bool | None = None,
) -> AnalysisResult:
    """Generate a fresh answer/trace for *question* and run attribution analysis on it.

    This is a convenience entry point equivalent to calling
    ``analyze_generation_result`` with ``generation=None``. Because no
    generation is supplied, the answer and trace are always produced first by
    the resolved model, using ``max_new_tokens``, ``temperature`` and ``top_p``.
    Every other argument is forwarded unchanged.
    """
    # Sampling knobs; they only matter because generation is always performed here.
    sampling_options = {
        "max_new_tokens": max_new_tokens,
        "temperature": temperature,
        "top_p": top_p,
    }
    # Model-loading / runtime knobs; None means "use the configured setting".
    runtime_options = {
        "device_preference": device_preference,
        "dtype_preference": dtype_preference,
        "attn_implementation": attn_implementation,
        "trust_remote_code": trust_remote_code,
        "low_cpu_mem_usage": low_cpu_mem_usage,
    }
    return analyze_generation_result(
        question=question,
        generation=None,
        model_name=model_name,
        take_log=take_log,
        max_sentences=max_sentences,
        max_trace_tokens=max_trace_tokens,
        validate_top_k=validate_top_k,
        **sampling_options,
        **runtime_options,
    )
def _split_and_cap_sentences(text: str, max_sentences: int) -> tuple[str, Any]:
    """Split *text* into sentence spans, keeping at most *max_sentences* (<= 0 keeps all)."""
    spans = split_sentences(text)
    if max_sentences > 0 and len(spans) > max_sentences:
        # Cut the text at the end of the last kept sentence, then re-split so the
        # returned spans are computed against exactly the returned text.
        text = text[: spans[max_sentences - 1].end_char]
        spans = split_sentences(text)
    if not spans:
        raise RuntimeError("Trace normalization produced no analyzable sentences.")
    return text, spans


def analyze_generation_result(
    *,
    question: str,
    generation: GenerationResult | None = None,
    model_name: str | None = None,
    take_log: bool | None = None,
    max_sentences: int | None = None,
    max_trace_tokens: int | None = None,
    validate_top_k: int = 0,
    max_new_tokens: int = 512,
    temperature: float = 0.6,
    top_p: float = 0.95,
    device_preference: str | None = None,
    dtype_preference: str | None = None,
    attn_implementation: str | None = None,
    trust_remote_code: bool | None = None,
    low_cpu_mem_usage: bool | None = None,
    top_k_edges: int = 10,
) -> AnalysisResult:
    """Run sentence-level attribution analysis over a reasoning trace.

    When *generation* is ``None``, an answer and trace are first generated
    with the resolved model. Otherwise the supplied generation is analysed
    as-is, and ``max_new_tokens``, ``temperature`` and ``top_p`` are unused.

    Option resolution:
        ``take_log``, ``trust_remote_code`` and ``low_cpu_mem_usage`` fall back
        to the configured settings only when they are ``None``, so an explicit
        ``False`` is honoured. ``model_name``, ``max_sentences``,
        ``max_trace_tokens`` and the device/dtype/attention preferences fall
        back to settings whenever they are falsy (``None``, ``0``, ``""``).

    Args:
        question: The question the trace answers; echoed into the result.
        generation: A previously produced generation to analyse, or ``None``
            to generate one.
        model_name: Model to load; defaults to the configured model.
        take_log: Forwarded to ``compute_attribution_matrix``.
        max_sentences: Sentence cap. A non-positive resolved value disables
            the cap in this function; the value is still forwarded to
            ``compute_attribution_matrix``.
        max_trace_tokens: Token budget the normalized trace is truncated to.
        validate_top_k: Number of top edges handed to ``validate_top_edges``.
        max_new_tokens, temperature, top_p: Sampling options, used only when
            *generation* is ``None``.
        device_preference, dtype_preference, attn_implementation,
        trust_remote_code, low_cpu_mem_usage: Model-loading options.
        top_k_edges: Number of strongest edges reported in ``top_edges``
            (defaults to 10, the previously hard-coded value).

    Returns:
        An ``AnalysisResult``. It includes the truncated trace, its sentences,
        the attribution matrices, importance summaries, top edges, validation
        output and the effective runtime options.

    Raises:
        RuntimeError: If the model does not support attribution analysis, or
            if the normalized trace yields no sentences.
    """
    settings = get_settings()
    resolved_model_name = model_name or settings.model_name
    resolved_take_log = settings.take_log if take_log is None else take_log
    resolved_max_sentences = max_sentences or settings.max_sentences
    resolved_max_trace_tokens = max_trace_tokens or settings.max_trace_tokens
    resolved_device = device_preference or settings.device_preference
    resolved_dtype = dtype_preference or settings.dtype_preference
    resolved_attn_implementation = attn_implementation or settings.attn_implementation
    resolved_trust_remote_code = (
        settings.trust_remote_code if trust_remote_code is None else trust_remote_code
    )
    resolved_low_cpu_mem_usage = (
        settings.low_cpu_mem_usage if low_cpu_mem_usage is None else low_cpu_mem_usage
    )

    bundle = load_model_bundle(
        resolved_model_name,
        device_preference=resolved_device,
        dtype_preference=resolved_dtype,
        attn_implementation=resolved_attn_implementation,
        trust_remote_code=resolved_trust_remote_code,
        low_cpu_mem_usage=resolved_low_cpu_mem_usage,
    )
    # Fail before spending time on generation if attribution is unsupported.
    if not bundle.capability.supports_attribution:
        reason = bundle.capability.reason or "Model does not support attribution analysis."
        raise RuntimeError(reason)

    if generation is None:
        generation = generate_answer_and_trace(
            question=question,
            model_name=resolved_model_name,
            model=bundle.model,
            tokenizer=bundle.tokenizer,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
        )

    truncated_text = truncate_text_to_token_limit(
        generation.normalized_trace_text,
        bundle.tokenizer,
        resolved_max_trace_tokens,
    )
    truncated_text, sentence_spans = _split_and_cap_sentences(
        truncated_text, resolved_max_sentences
    )

    mapping = tokenize_with_sentence_ranges(truncated_text, sentence_spans, bundle.tokenizer)
    input_ids = mapping.input_ids.to(bundle.device)
    computation = compute_attribution_matrix(
        input_ids=input_ids,
        token_ranges=mapping.token_ranges,
        model=bundle.model,
        take_log=resolved_take_log,
        max_trace_tokens=resolved_max_trace_tokens,
        max_sentences=resolved_max_sentences,
    )

    outgoing = compute_outgoing_importance(computation.raw_matrix)
    incoming = compute_incoming_importance(computation.raw_matrix)
    top_edges = compute_top_edges(computation.raw_matrix, top_k=top_k_edges)

    # validate_top_edges only receives the edges it is given. If more edges are
    # requested for validation than are reported, rank a longer list for it, so
    # validate_top_k is not silently limited by top_k_edges. The reported
    # top_edges are left unchanged.
    if validate_top_k > top_k_edges:
        edges_to_validate = compute_top_edges(computation.raw_matrix, top_k=validate_top_k)
    else:
        edges_to_validate = top_edges
    validation = validate_top_edges(
        model=bundle.model,
        input_ids=input_ids,
        token_ranges=mapping.token_ranges,
        top_edges=edges_to_validate,
        baseline_token_nll=computation.token_nll,
        top_k=validate_top_k,
    )

    return AnalysisResult(
        question=question,
        model_name=resolved_model_name,
        answer=generation.answer,
        raw_trace_text=generation.raw_trace_text,
        normalized_trace_text=truncated_text,
        sentences=[span.text for span in sentence_spans],
        sentence_token_ranges=mapping.token_ranges,
        suppression_matrix=computation.matrix.tolist(),
        raw_suppression_matrix=computation.raw_matrix.tolist(),
        outgoing_importance=outgoing,
        incoming_importance=incoming,
        top_edges=top_edges,
        runtime_metadata=computation.runtime_metadata,
        validation_metadata=validation,
        extra_metadata={
            "raw_generation_text": generation.raw_generation_text,
            "generation_metadata": generation.generation_metadata.model_dump(),
            "effective_runtime": {
                "device_preference": resolved_device,
                "dtype_preference": resolved_dtype,
                "attn_implementation": resolved_attn_implementation,
                "trust_remote_code": resolved_trust_remote_code,
                "low_cpu_mem_usage": resolved_low_cpu_mem_usage,
            },
        },
    )