# cot-anc / app/generation/service.py
# Last change: "Switch default model to HRM-Text-1B" (commit 2620860, verified), author: BART-ender.
from __future__ import annotations
import re
from typing import Any
import torch
from app.analysis.sentence_split import normalize_trace_text
from app.core.model_support import add_prefix_token_type_ids
from app.core.schemas import GenerationMetadata, GenerationResult
from app.generation.prompting import render_prompt
THINK_BLOCK_RE = re.compile(r"<think>(.*?)</think>", re.IGNORECASE | re.DOTALL)
ANSWER_MARKER_RE = re.compile(r"(?:^|\n)(?:final answer|answer)\s*:\s*", re.IGNORECASE)
def _extract_trace_and_answer(text: str) -> tuple[str, str]:
match = THINK_BLOCK_RE.search(text)
if match:
raw_trace = match.group(0)
answer = text[match.end() :].strip()
if not answer:
answer = match.group(1).strip()
return raw_trace, answer
raw_trace = text.strip()
answer_match = ANSWER_MARKER_RE.search(text)
if answer_match:
answer = text[answer_match.end() :].strip()
else:
paragraphs = [part.strip() for part in text.split("\n\n") if part.strip()]
answer = paragraphs[-1] if paragraphs else raw_trace
return raw_trace, answer
def _strip_trailing_special_tokens(text: str, tokenizer: Any) -> str:
    """Return *text* without trailing EOS/pad token strings or whitespace.

    Generations are decoded with ``skip_special_tokens=False``, so output that
    stopped on EOS ends with the EOS token's text. Stripping it keeps that text
    out of the extracted answer/trace.
    """
    markers = [
        token
        for token in (
            getattr(tokenizer, "eos_token", None),
            getattr(tokenizer, "pad_token", None),
        )
        if isinstance(token, str) and token
    ]
    stripped = text.rstrip()
    while True:
        marker = next((m for m in markers if stripped.endswith(m)), None)
        if marker is None:
            return stripped
        # Each pass removes a non-empty suffix, so the loop terminates.
        stripped = stripped[: -len(marker)].rstrip()


def generate_answer_and_trace(
    *,
    question: str,
    model_name: str,
    model: Any,
    tokenizer: Any,
    max_new_tokens: int = 512,
    temperature: float = 0.6,
    top_p: float = 0.95,
) -> GenerationResult:
    """Generate a response to *question* and split it into trace and answer.

    The prompt is built with ``render_prompt``, tokenized, moved to the
    device of the model's first parameter, and passed to ``model.generate``
    without gradient tracking.

    Args:
        question: The question to answer.
        model_name: Identifier recorded on the result.
        model: Model exposing ``parameters()`` and ``generate(...)``
            (presumably a Hugging Face causal LM — confirm against callers).
        tokenizer: Tokenizer used for encoding/decoding and for the
            pad/EOS token ids.
        max_new_tokens: Upper bound on generated tokens.
        temperature: Sampling temperature. ``<= 0`` selects greedy decoding,
            in which case neither ``temperature`` nor ``top_p`` is sent to
            ``generate``.
        top_p: Nucleus-sampling threshold, used only when sampling.

    Returns:
        A ``GenerationResult`` with the full decoded text, the raw and
        normalized trace, the extracted answer, and the decoding settings.
    """
    prompt_text = render_prompt(tokenizer, question)
    encoded = tokenizer(prompt_text, return_tensors="pt")
    # Project hook for model-specific inputs; see app.core.model_support.
    encoded = add_prefix_token_type_ids(model, encoded)
    model_device = next(model.parameters()).device
    encoded = {key: value.to(model_device) for key, value in encoded.items()}
    input_length = int(encoded["input_ids"].shape[-1])

    do_sample = temperature > 0.0
    # Fall back to EOS when the tokenizer has no pad token; this is what
    # generate() would do itself, minus the per-call warning.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id
    generation_kwargs: dict[str, Any] = {
        "max_new_tokens": max_new_tokens,
        "do_sample": do_sample,
        "pad_token_id": pad_token_id,
        "eos_token_id": tokenizer.eos_token_id,
    }
    if do_sample:
        # Sampling-only knobs: passing them with do_sample=False is ignored
        # by generate() and triggers a warning.
        generation_kwargs["temperature"] = temperature
        generation_kwargs["top_p"] = top_p

    with torch.no_grad():
        output_ids = model.generate(**encoded, **generation_kwargs)

    # Keep only the newly generated tokens of the single sequence.
    generated_ids = output_ids[0, input_length:]
    # Special tokens are kept, presumably so think tags survive decoding when
    # they are registered as special tokens — TODO confirm per tokenizer.
    generated_text = tokenizer.decode(generated_ids, skip_special_tokens=False)
    raw_trace_text, answer = _extract_trace_and_answer(
        _strip_trailing_special_tokens(generated_text, tokenizer)
    )
    normalized_trace_text = normalize_trace_text(raw_trace_text)
    return GenerationResult(
        question=question,
        model_name=model_name,
        answer=answer,
        raw_generation_text=generated_text,
        raw_trace_text=raw_trace_text,
        normalized_trace_text=normalized_trace_text,
        generation_metadata=GenerationMetadata(
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=do_sample,
        ),
    )