from __future__ import annotations

from typing import Any

import torch

from app.analysis.sentence_split import split_sentences
from app.analysis.summaries import (
    compute_incoming_importance,
    compute_outgoing_importance,
    compute_top_edges,
)
from app.analysis.suppression import compute_attribution_matrix
from app.analysis.token_boundaries import tokenize_with_sentence_ranges, truncate_text_to_token_limit
from app.analysis.validation import validate_top_edges
from app.core.config import get_settings
from app.core.runtime import load_model_bundle
from app.core.schemas import AnalysisResult, GenerationResult
from app.generation.service import generate_answer_and_trace


def compute_attribution_analysis(
    *,
    question: str,
    model_name: str | None = None,
    take_log: bool | None = None,
    max_sentences: int | None = None,
    max_trace_tokens: int | None = None,
    validate_top_k: int = 0,
    max_new_tokens: int = 512,
    temperature: float = 0.6,
    top_p: float = 0.95,
    device_preference: str | None = None,
    dtype_preference: str | None = None,
    attn_implementation: str | None = None,
    trust_remote_code: bool | None = None,
    low_cpu_mem_usage: bool | None = None,
    top_edges_k: int = 10,
) -> AnalysisResult:
    """Generate a fresh answer/trace for *question* and run attribution analysis on it.

    Convenience wrapper around :func:`analyze_generation_result` that always passes
    ``generation=None``, so a new generation is produced with the loaded model.
    Every other keyword argument is forwarded unchanged; see
    :func:`analyze_generation_result` for their meaning.

    Returns:
        The ``AnalysisResult`` built by :func:`analyze_generation_result`.

    Raises:
        RuntimeError: Propagated from :func:`analyze_generation_result`.
    """
    return analyze_generation_result(
        question=question,
        generation=None,
        model_name=model_name,
        take_log=take_log,
        max_sentences=max_sentences,
        max_trace_tokens=max_trace_tokens,
        validate_top_k=validate_top_k,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        device_preference=device_preference,
        dtype_preference=dtype_preference,
        attn_implementation=attn_implementation,
        trust_remote_code=trust_remote_code,
        low_cpu_mem_usage=low_cpu_mem_usage,
        top_edges_k=top_edges_k,
    )


def _split_and_cap_sentences(text: str, max_sentences: int) -> tuple[str, Any]:
    """Split *text* into sentences, keeping at most *max_sentences* of them.

    A non-positive *max_sentences* disables the cap. When the cap applies, the text
    is cut at the end of the last kept sentence and re-split, so the returned spans
    describe exactly the returned text.

    Returns:
        ``(text, sentence_spans)`` where ``sentence_spans`` is whatever
        ``split_sentences`` returns for the (possibly shortened) text.
    """
    sentence_spans = split_sentences(text)
    if max_sentences > 0 and len(sentence_spans) > max_sentences:
        sentence_spans = sentence_spans[:max_sentences]
        text = text[: sentence_spans[-1].end_char]
        # Re-split so span offsets are computed against the shortened text itself.
        sentence_spans = split_sentences(text)
    return text, sentence_spans


def analyze_generation_result(
    *,
    question: str,
    generation: GenerationResult | None = None,
    model_name: str | None = None,
    take_log: bool | None = None,
    max_sentences: int | None = None,
    max_trace_tokens: int | None = None,
    validate_top_k: int = 0,
    max_new_tokens: int = 512,
    temperature: float = 0.6,
    top_p: float = 0.95,
    device_preference: str | None = None,
    dtype_preference: str | None = None,
    attn_implementation: str | None = None,
    trust_remote_code: bool | None = None,
    low_cpu_mem_usage: bool | None = None,
    top_edges_k: int = 10,
) -> AnalysisResult:
    """Run sentence-level attribution analysis on a generated reasoning trace.

    Loads the model bundle and checks that it supports attribution. It reuses
    *generation* or generates a new answer/trace. It then truncates and splits the
    normalized trace into sentences, computes the attribution matrix plus
    importance summaries, and validates the strongest edges.

    Args:
        question: The prompt; used for generation (when needed) and echoed in the result.
        generation: A pre-computed generation to analyze. When ``None``, one is
            produced with ``generate_answer_and_trace`` using *max_new_tokens*,
            *temperature* and *top_p* (those three are unused otherwise).
        model_name: Model to load; falsy values fall back to ``settings.model_name``.
        take_log: Forwarded to ``compute_attribution_matrix``; ``None`` uses settings.
        max_sentences: Maximum number of trace sentences analyzed. ``None`` uses
            settings; a non-positive value disables the cap.
        max_trace_tokens: Token budget for the trace text. ``None`` or ``0`` falls
            back to ``settings.max_trace_tokens``.
        validate_top_k: Forwarded to ``validate_top_edges`` as ``top_k``.
        max_new_tokens, temperature, top_p: Sampling options for generation.
        device_preference, dtype_preference, attn_implementation: Model loading
            options; falsy values fall back to the corresponding settings.
        trust_remote_code, low_cpu_mem_usage: Model loading flags; ``None`` uses settings.
        top_edges_k: Number of edges requested from ``compute_top_edges``.

    Returns:
        An ``AnalysisResult`` with the matrices, summaries, validation output and
        the effective runtime options.

    Raises:
        RuntimeError: If the model does not support attribution, or if the trace
            yields no analyzable sentences.
    """
    settings = get_settings()

    resolved_model_name = model_name or settings.model_name
    resolved_take_log = settings.take_log if take_log is None else take_log
    # Compare against None (like take_log below) so an explicit 0, which the
    # sentence cap treats as "no cap", is not replaced by the configured default.
    resolved_max_sentences = settings.max_sentences if max_sentences is None else max_sentences
    # Falsy values (None or 0) fall back to the configured token budget.
    resolved_max_trace_tokens = max_trace_tokens or settings.max_trace_tokens
    resolved_device = device_preference or settings.device_preference
    resolved_dtype = dtype_preference or settings.dtype_preference
    resolved_attn_implementation = attn_implementation or settings.attn_implementation
    resolved_trust_remote_code = (
        settings.trust_remote_code if trust_remote_code is None else trust_remote_code
    )
    resolved_low_cpu_mem_usage = (
        settings.low_cpu_mem_usage if low_cpu_mem_usage is None else low_cpu_mem_usage
    )

    bundle = load_model_bundle(
        resolved_model_name,
        device_preference=resolved_device,
        dtype_preference=resolved_dtype,
        attn_implementation=resolved_attn_implementation,
        trust_remote_code=resolved_trust_remote_code,
        low_cpu_mem_usage=resolved_low_cpu_mem_usage,
    )
    # Fail before any (potentially expensive) generation if attribution is unsupported.
    if not bundle.capability.supports_attribution:
        reason = bundle.capability.reason or "Model does not support attribution analysis."
        raise RuntimeError(reason)

    if generation is None:
        generation = generate_answer_and_trace(
            question=question,
            model_name=resolved_model_name,
            model=bundle.model,
            tokenizer=bundle.tokenizer,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
        )

    truncated_text = truncate_text_to_token_limit(
        generation.normalized_trace_text,
        bundle.tokenizer,
        resolved_max_trace_tokens,
    )
    truncated_text, sentence_spans = _split_and_cap_sentences(
        truncated_text, resolved_max_sentences
    )
    if not sentence_spans:
        raise RuntimeError("Trace normalization produced no analyzable sentences.")

    mapping = tokenize_with_sentence_ranges(truncated_text, sentence_spans, bundle.tokenizer)
    input_ids = mapping.input_ids.to(bundle.device)

    computation = compute_attribution_matrix(
        input_ids=input_ids,
        token_ranges=mapping.token_ranges,
        model=bundle.model,
        take_log=resolved_take_log,
        max_trace_tokens=resolved_max_trace_tokens,
        max_sentences=resolved_max_sentences,
    )
    # Summaries and edge ranking are all derived from the raw (unnormalized) matrix.
    outgoing = compute_outgoing_importance(computation.raw_matrix)
    incoming = compute_incoming_importance(computation.raw_matrix)
    top_edges = compute_top_edges(computation.raw_matrix, top_k=top_edges_k)
    validation = validate_top_edges(
        model=bundle.model,
        input_ids=input_ids,
        token_ranges=mapping.token_ranges,
        top_edges=top_edges,
        baseline_token_nll=computation.token_nll,
        top_k=validate_top_k,
    )

    return AnalysisResult(
        question=question,
        model_name=resolved_model_name,
        answer=generation.answer,
        raw_trace_text=generation.raw_trace_text,
        normalized_trace_text=truncated_text,
        sentences=[span.text for span in sentence_spans],
        sentence_token_ranges=mapping.token_ranges,
        suppression_matrix=computation.matrix.tolist(),
        raw_suppression_matrix=computation.raw_matrix.tolist(),
        outgoing_importance=outgoing,
        incoming_importance=incoming,
        top_edges=top_edges,
        runtime_metadata=computation.runtime_metadata,
        validation_metadata=validation,
        extra_metadata={
            "raw_generation_text": generation.raw_generation_text,
            "generation_metadata": generation.generation_metadata.model_dump(),
            "effective_runtime": {
                "device_preference": resolved_device,
                "dtype_preference": resolved_dtype,
                "attn_implementation": resolved_attn_implementation,
                "trust_remote_code": resolved_trust_remote_code,
                "low_cpu_mem_usage": resolved_low_cpu_mem_usage,
            },
        },
    )