# src/model/generation_utils.py
"""
Generation utilities for text correction.
Handles beam search, constrained decoding, and post-generation cleanup.
"""
import torch
from transformers import PreTrainedModel, PreTrainedTokenizer
from typing import Dict, Optional, List
from loguru import logger


def generate_correction(
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    generation_config: Dict,
) -> str:
    """Generate corrected text from input tokens."""
    # Build generation kwargs from config, with sensible defaults.
    gen_kwargs = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "max_new_tokens": generation_config.get("max_new_tokens", 512),
        "num_beams": generation_config.get("num_beams", 5),
        "length_penalty": generation_config.get("length_penalty", 1.0),
        "no_repeat_ngram_size": generation_config.get("no_repeat_ngram_size", 3),
        "min_length": generation_config.get("min_length", 10),
        "early_stopping": generation_config.get("early_stopping", True),
    }

    # Optional sampling parameters (beam-sampling when num_beams > 1).
    if generation_config.get("do_sample", False):
        gen_kwargs["do_sample"] = True
        gen_kwargs["temperature"] = generation_config.get("temperature", 0.7)
        gen_kwargs["top_p"] = generation_config.get("top_p", 0.9)
    else:
        gen_kwargs["do_sample"] = False

    with torch.no_grad():
        output_ids = model.generate(**gen_kwargs)

    # Decode, skipping special tokens. For seq2seq models the output holds
    # only generated tokens; a decoder-only model would echo the prompt and
    # need it sliced off before decoding.
    generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return generated_text.strip()
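
# Example call (a minimal sketch; "t5-small" is an assumed stand-in
# checkpoint, not necessarily the model this project ships with):
#
#   from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
#   model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
#   tokenizer = AutoTokenizer.from_pretrained("t5-small")
#   enc = tokenizer("he go to school yesterday", return_tensors="pt")
#   fixed = generate_correction(
#       model, tokenizer, enc["input_ids"], enc["attention_mask"],
#       {"num_beams": 4, "max_new_tokens": 64},
#   )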


def batch_generate(
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    texts: List[str],
    generation_config: Dict,
    max_length: int = 512,
) -> List[str]:
    """Generate corrections for a batch of texts."""
    if not texts:
        return []

    results = []
    device = next(model.parameters()).device  # resolve the model device once

    # Process in mini-batches to manage memory on CPU.
    batch_size = generation_config.get("batch_size", 4)
    for i in range(0, len(texts), batch_size):
        batch_texts = texts[i:i + batch_size]

        # Tokenise the batch, padding/truncating to a common length.
        inputs = tokenizer(
            batch_texts,
            max_length=max_length,
            padding=True,
            truncation=True,
            return_tensors="pt",
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Generate with beam search (plus optional sampling, mirroring
        # generate_correction, including top_p for consistency).
        gen_kwargs = {
            "max_new_tokens": generation_config.get("max_new_tokens", 512),
            "num_beams": generation_config.get("num_beams", 5),
            "length_penalty": generation_config.get("length_penalty", 1.0),
            "no_repeat_ngram_size": generation_config.get("no_repeat_ngram_size", 3),
            "early_stopping": generation_config.get("early_stopping", True),
        }
        if generation_config.get("do_sample", False):
            gen_kwargs["do_sample"] = True
            gen_kwargs["temperature"] = generation_config.get("temperature", 0.7)
            gen_kwargs["top_p"] = generation_config.get("top_p", 0.9)

        with torch.no_grad():
            output_ids = model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                **gen_kwargs,
            )

        # Decode the whole batch at once.
        for text in tokenizer.batch_decode(output_ids, skip_special_tokens=True):
            results.append(text.strip())
        logger.debug(f"Generated batch {i // batch_size + 1}: {len(batch_texts)} texts")

    return results
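

if __name__ == "__main__":
    # Minimal smoke test (a sketch under assumptions: "t5-small" stands in
    # for whatever correction checkpoint the project actually uses).
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

    demo_model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
    demo_tokenizer = AutoTokenizer.from_pretrained("t5-small")
    demo_config = {"num_beams": 4, "max_new_tokens": 64, "batch_size": 2}
    print(batch_generate(
        demo_model,
        demo_tokenizer,
        ["he go to school yesterday", "she have two cat"],
        demo_config,
    ))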