Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import re | |
| from typing import Any | |
| import torch | |
| from app.analysis.sentence_split import normalize_trace_text | |
| from app.core.model_support import add_prefix_token_type_ids | |
| from app.core.schemas import GenerationMetadata, GenerationResult | |
| from app.generation.prompting import render_prompt | |
| THINK_BLOCK_RE = re.compile(r"<think>(.*?)</think>", re.IGNORECASE | re.DOTALL) | |
| ANSWER_MARKER_RE = re.compile(r"(?:^|\n)(?:final answer|answer)\s*:\s*", re.IGNORECASE) | |
| def _extract_trace_and_answer(text: str) -> tuple[str, str]: | |
| match = THINK_BLOCK_RE.search(text) | |
| if match: | |
| raw_trace = match.group(0) | |
| answer = text[match.end() :].strip() | |
| if not answer: | |
| answer = match.group(1).strip() | |
| return raw_trace, answer | |
| raw_trace = text.strip() | |
| answer_match = ANSWER_MARKER_RE.search(text) | |
| if answer_match: | |
| answer = text[answer_match.end() :].strip() | |
| else: | |
| paragraphs = [part.strip() for part in text.split("\n\n") if part.strip()] | |
| answer = paragraphs[-1] if paragraphs else raw_trace | |
| return raw_trace, answer | |
def _build_generation_kwargs(
    tokenizer: Any,
    *,
    max_new_tokens: int,
    do_sample: bool,
    temperature: float,
    top_p: float,
) -> dict[str, Any]:
    """Assemble keyword arguments for ``model.generate``.

    Sampling-only knobs (``temperature``, ``top_p``) are included only when
    sampling, so greedy decoding is not handed flags it ignores.
    """
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        # Tokenizers without a pad token: fall back to EOS explicitly rather
        # than passing ``None`` through to ``generate``.
        pad_token_id = tokenizer.eos_token_id
    kwargs: dict[str, Any] = {
        "max_new_tokens": max_new_tokens,
        "do_sample": do_sample,
        "pad_token_id": pad_token_id,
        "eos_token_id": tokenizer.eos_token_id,
    }
    if do_sample:
        kwargs["temperature"] = temperature
        kwargs["top_p"] = top_p
    return kwargs


def _strip_trailing_special_tokens(text: str, tokenizer: Any) -> str:
    """Remove trailing EOS/pad token strings (and surrounding whitespace) from ``text``.

    Returns ``text`` unchanged when it does not end with one of those tokens.
    """
    markers = [
        token
        for token in (getattr(tokenizer, "eos_token", None), getattr(tokenizer, "pad_token", None))
        if isinstance(token, str) and token
    ]
    cleaned = text
    while True:
        candidate = cleaned.rstrip()
        for marker in markers:
            if candidate.endswith(marker):
                cleaned = candidate[: -len(marker)]
                break
        else:
            return cleaned


def generate_answer_and_trace(
    *,
    question: str,
    model_name: str,
    model: Any,
    tokenizer: Any,
    max_new_tokens: int = 512,
    temperature: float = 0.6,
    top_p: float = 0.95,
) -> GenerationResult:
    """Generate a response to ``question`` and split it into trace and answer.

    The prompt is rendered with ``render_prompt`` and tokenized. The inputs are
    moved to the device of the model's first parameter, and a single sequence
    is generated without gradient tracking. Sampling is enabled only when
    ``temperature > 0``; otherwise decoding is greedy.

    Args:
        question: The question to answer.
        model_name: Identifier recorded on the result.
        model: Model object exposing ``parameters()`` and ``generate()``.
        tokenizer: Tokenizer used for encoding, decoding and special token ids.
        max_new_tokens: Upper bound on generated tokens.
        temperature: Sampling temperature; ``<= 0`` selects greedy decoding.
        top_p: Nucleus-sampling threshold, used only when sampling.

    Returns:
        A ``GenerationResult`` containing:
        - ``raw_generation_text``: the decoded continuation with special tokens kept.
        - ``raw_trace_text``: the extracted trace.
        - ``normalized_trace_text``: the trace after ``normalize_trace_text``.
        - ``answer``: the extracted answer.
        - ``generation_metadata``: the requested generation settings.
    """
    prompt_text = render_prompt(tokenizer, question)
    encoded = tokenizer(prompt_text, return_tensors="pt")
    encoded = add_prefix_token_type_ids(model, encoded)
    model_device = next(model.parameters()).device
    encoded = {key: value.to(model_device) for key, value in encoded.items()}
    input_length = int(encoded["input_ids"].shape[-1])

    do_sample = temperature > 0.0
    generation_kwargs = _build_generation_kwargs(
        tokenizer,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        temperature=temperature,
        top_p=top_p,
    )
    with torch.no_grad():
        output_ids = model.generate(**encoded, **generation_kwargs)

    # ``generate`` returns prompt + continuation; keep only the new tokens of
    # the first (only) sequence.
    generated_ids = output_ids[0, input_length:]
    # Special tokens are kept, presumably so think tags registered as added or
    # special tokens survive decoding (NOTE(review): confirm). The side effect
    # is a trailing EOS string, stripped below before extraction.
    generated_text = tokenizer.decode(generated_ids, skip_special_tokens=False)
    extraction_text = _strip_trailing_special_tokens(generated_text, tokenizer)
    raw_trace_text, answer = _extract_trace_and_answer(extraction_text)
    normalized_trace_text = normalize_trace_text(raw_trace_text)
    return GenerationResult(
        question=question,
        model_name=model_name,
        answer=answer,
        raw_generation_text=generated_text,
        raw_trace_text=raw_trace_text,
        normalized_trace_text=normalized_trace_text,
        generation_metadata=GenerationMetadata(
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=do_sample,
        ),
    )