from __future__ import annotations import re from typing import Any import torch from app.analysis.sentence_split import normalize_trace_text from app.core.model_support import add_prefix_token_type_ids from app.core.schemas import GenerationMetadata, GenerationResult from app.generation.prompting import render_prompt THINK_BLOCK_RE = re.compile(r"(.*?)", re.IGNORECASE | re.DOTALL) ANSWER_MARKER_RE = re.compile(r"(?:^|\n)(?:final answer|answer)\s*:\s*", re.IGNORECASE) def _extract_trace_and_answer(text: str) -> tuple[str, str]: match = THINK_BLOCK_RE.search(text) if match: raw_trace = match.group(0) answer = text[match.end() :].strip() if not answer: answer = match.group(1).strip() return raw_trace, answer raw_trace = text.strip() answer_match = ANSWER_MARKER_RE.search(text) if answer_match: answer = text[answer_match.end() :].strip() else: paragraphs = [part.strip() for part in text.split("\n\n") if part.strip()] answer = paragraphs[-1] if paragraphs else raw_trace return raw_trace, answer def generate_answer_and_trace( *, question: str, model_name: str, model: Any, tokenizer: Any, max_new_tokens: int = 512, temperature: float = 0.6, top_p: float = 0.95, ) -> GenerationResult: prompt_text = render_prompt(tokenizer, question) encoded = tokenizer(prompt_text, return_tensors="pt") encoded = add_prefix_token_type_ids(model, encoded) model_device = next(model.parameters()).device encoded = {key: value.to(model_device) for key, value in encoded.items()} input_length = int(encoded["input_ids"].shape[-1]) do_sample = temperature > 0.0 generation_kwargs: dict[str, Any] = { "max_new_tokens": max_new_tokens, "do_sample": do_sample, "top_p": top_p, "pad_token_id": tokenizer.pad_token_id, "eos_token_id": tokenizer.eos_token_id, } if do_sample: generation_kwargs["temperature"] = temperature with torch.no_grad(): output_ids = model.generate(**encoded, **generation_kwargs) generated_ids = output_ids[0, input_length:] generated_text = tokenizer.decode(generated_ids, skip_special_tokens=False) raw_trace_text, answer = _extract_trace_and_answer(generated_text) normalized_trace_text = normalize_trace_text(raw_trace_text) return GenerationResult( question=question, model_name=model_name, answer=answer, raw_generation_text=generated_text, raw_trace_text=raw_trace_text, normalized_trace_text=normalized_trace_text, generation_metadata=GenerationMetadata( max_new_tokens=max_new_tokens, temperature=temperature, top_p=top_p, do_sample=do_sample, ), )