"""Text generation and output parsing.""" from __future__ import annotations from dataclasses import dataclass from typing import List, Tuple from torch.nn.functional import softmax from src.config import settings @dataclass class GenerationResult: code: str explanation: str confidence: float important_tokens: List[str] def _split_code_and_explanation(text: str) -> Tuple[str, str]: marker = "Explanation:" if marker in text: code, explanation = text.split(marker, 1) return code.strip(), explanation.strip() return text.strip(), "Model did not provide explicit explanation." def generate_response(model_bundle, prompt: str) -> GenerationResult: """Generate model response with token-level confidence signals.""" if getattr(model_bundle, "is_mock", False): # Keep API runnable even when model download/loading is unavailable. fallback_code = ( "def solve_task(input_data):\n" " \"\"\"Fallback implementation when model is unavailable.\"\"\"\n" " return input_data\n" ) fallback_explanation = ( "Running in mock fallback mode because no pretrained model could be loaded. " "Set MODEL_NAME/FALLBACK_MODEL_NAME and ensure network/model access." ) load_error = getattr(model_bundle, "load_error", "") if load_error: fallback_explanation = f"{fallback_explanation}\n\nLoader error: {load_error}" return GenerationResult( code=fallback_code, explanation=fallback_explanation, confidence=0.15, important_tokens=[""], ) tokenizer = model_bundle.tokenizer model = model_bundle.model inputs = tokenizer(prompt, return_tensors="pt").to(model.device) outputs = model.generate( **inputs, max_new_tokens=settings.max_new_tokens, temperature=settings.temperature, top_p=settings.top_p, do_sample=True, return_dict_in_generate=True, output_scores=True, ) generated_ids = outputs.sequences[0][inputs["input_ids"].shape[1] :] generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True) token_probs = [] important = [] for step_scores, token_id in zip(outputs.scores, generated_ids): probs = softmax(step_scores[0], dim=-1) p = probs[token_id].item() token_probs.append(p) if p < 0.30: important.append(tokenizer.decode([token_id])) confidence = float(sum(token_probs) / max(len(token_probs), 1)) code, explanation = _split_code_and_explanation(generated_text) return GenerationResult( code=code, explanation=explanation, confidence=confidence, important_tokens=important[:20], )