"""
Generation utilities for text correction.

Handles beam search, constrained decoding, and post-generation cleanup.
"""

from typing import Dict, Optional, List

import torch
from loguru import logger
from transformers import PreTrainedModel, PreTrainedTokenizer


def _build_generation_kwargs(
    generation_config: Dict,
    default_min_length: Optional[int] = None,
) -> Dict:
    """Translate a generation config mapping into ``model.generate`` kwargs.

    Shared by ``generate_correction`` and ``batch_generate`` so that both
    entry points interpret the same config keys with the same defaults.

    Args:
        generation_config: Mapping that may contain ``max_new_tokens``,
            ``num_beams``, ``length_penalty``, ``no_repeat_ngram_size``,
            ``early_stopping``, ``min_length``, ``do_sample``,
            ``temperature`` and ``top_p``. Missing keys fall back to the
            defaults written below.
        default_min_length: Value used for ``min_length`` when the config
            does not specify one. ``None`` means "do not pass ``min_length``
            unless the config asks for it".

    Returns:
        A new dict of keyword arguments (without ``input_ids`` /
        ``attention_mask``) ready to be splatted into ``model.generate``.
    """
    gen_kwargs = {
        "max_new_tokens": generation_config.get("max_new_tokens", 512),
        "num_beams": generation_config.get("num_beams", 5),
        "length_penalty": generation_config.get("length_penalty", 1.0),
        "no_repeat_ngram_size": generation_config.get("no_repeat_ngram_size", 3),
        "early_stopping": generation_config.get("early_stopping", True),
    }

    min_length = generation_config.get("min_length", default_min_length)
    if min_length is not None:
        gen_kwargs["min_length"] = min_length

    if generation_config.get("do_sample", False):
        gen_kwargs["do_sample"] = True
        gen_kwargs["temperature"] = generation_config.get("temperature", 0.7)
        gen_kwargs["top_p"] = generation_config.get("top_p", 0.9)
    else:
        # Set explicitly so the call does not fall back to whatever sampling
        # default the model's own generation config carries.
        gen_kwargs["do_sample"] = False

    return gen_kwargs


def generate_correction(
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    generation_config: Dict,
) -> str:
    """Generate corrected text from input tokens.

    Args:
        model: Model whose ``generate`` method produces the output ids.
        tokenizer: Tokenizer used to decode the generated ids.
        input_ids: Token ids passed to ``model.generate``.
        attention_mask: Attention mask passed alongside ``input_ids``.
        generation_config: Generation options; see
            ``_build_generation_kwargs`` for the recognised keys.
            ``min_length`` defaults to 10 here.

    Returns:
        The decoded text of the first returned sequence, with special tokens
        skipped and surrounding whitespace stripped. Any further sequences in
        the output are ignored.
    """
    gen_kwargs = _build_generation_kwargs(generation_config, default_min_length=10)

    with torch.no_grad():
        output_ids = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **gen_kwargs,
        )

    # Only the first sequence is decoded.
    generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return generated_text.strip()


def batch_generate(
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    texts: List[str],
    generation_config: Dict,
    max_length: int = 512,
) -> List[str]:
    """Generate corrections for a batch of texts.

    The texts are tokenised and generated in mini-batches of
    ``generation_config["batch_size"]`` (default 4) to manage memory.

    Args:
        model: Model whose ``generate`` method produces the output ids.
        tokenizer: Tokenizer used both to encode ``texts`` and to decode the
            generated ids.
        texts: Input strings to correct.
        generation_config: Generation options; see
            ``_build_generation_kwargs`` for the recognised keys, plus
            ``batch_size``. ``min_length`` is only applied when present in
            the config.
        max_length: Maximum tokenised input length; longer inputs are
            truncated.

    Returns:
        One stripped, decoded string per sequence returned by
        ``model.generate``, in input order. Empty list if ``texts`` is empty.

    Raises:
        ValueError: If the configured ``batch_size`` is smaller than 1.
    """
    if not texts:
        return []

    batch_size = generation_config.get("batch_size", 4)
    if batch_size < 1:
        # A negative step would make range() empty and silently drop every
        # input; fail loudly instead.
        raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")

    # Loop-invariant: resolve the target device and generation kwargs once.
    device = next(model.parameters()).device
    gen_kwargs = _build_generation_kwargs(generation_config)

    results = []
    for i in range(0, len(texts), batch_size):
        batch_texts = texts[i:i + batch_size]

        # Tokenise batch
        inputs = tokenizer(
            batch_texts,
            max_length=max_length,
            padding=True,
            truncation=True,
            return_tensors="pt",
        )

        # Move to model device
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            output_ids = model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                **gen_kwargs,
            )

        # NOTE(review): the full output sequence is decoded; this presumably
        # targets an encoder-decoder model (a decoder-only model would echo
        # the prompt) - confirm against the caller.
        results.extend(
            tokenizer.decode(output, skip_special_tokens=True).strip()
            for output in output_ids
        )

        logger.debug(f"Generated batch {i // batch_size + 1}: {len(batch_texts)} texts")

    return results