| """ |
| Generation utilities for text correction. |
| Handles beam search, constrained decoding, and post-generation cleanup. |
| """ |
|
|
| import torch |
| from transformers import PreTrainedModel, PreTrainedTokenizer |
| from typing import Dict, Optional, List |
| from loguru import logger |
|
|
|
|
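
# For reference, a sketch of the config dict the helpers below consume. The
# name _EXAMPLE_GENERATION_CONFIG is purely illustrative (nothing in this
# module reads it); the keys and fallback values are exactly the ones the
# functions below apply via generation_config.get(...).
_EXAMPLE_GENERATION_CONFIG: Dict = {
    "max_new_tokens": 512,
    "num_beams": 5,
    "length_penalty": 1.0,
    "no_repeat_ngram_size": 3,
    "min_length": 10,          # only read by generate_correction
    "early_stopping": True,
    "do_sample": False,        # set True to enable temperature/top_p sampling
    "temperature": 0.7,
    "top_p": 0.9,
    "batch_size": 4,           # only read by batch_generate
}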


def generate_correction(
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    generation_config: Dict,
) -> str:
    """Generate corrected text from input tokens."""
    gen_kwargs = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "max_new_tokens": generation_config.get("max_new_tokens", 512),
        "num_beams": generation_config.get("num_beams", 5),
        "length_penalty": generation_config.get("length_penalty", 1.0),
        "no_repeat_ngram_size": generation_config.get("no_repeat_ngram_size", 3),
        "min_length": generation_config.get("min_length", 10),
        "early_stopping": generation_config.get("early_stopping", True),
    }

    # Sampling is opt-in; the default path is deterministic beam search.
    if generation_config.get("do_sample", False):
        gen_kwargs["do_sample"] = True
        gen_kwargs["temperature"] = generation_config.get("temperature", 0.7)
        gen_kwargs["top_p"] = generation_config.get("top_p", 0.9)
    else:
        gen_kwargs["do_sample"] = False

    with torch.no_grad():
        output_ids = model.generate(**gen_kwargs)

    generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return generated_text.strip()
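
# Example (a minimal sketch, not part of the module's API): correcting a
# single sentence with an off-the-shelf seq2seq checkpoint. The "t5-small"
# name and the "grammar: " prompt prefix are illustrative assumptions --
# substitute the correction model and prompt format your project uses.
#
#     from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
#
#     tok = AutoTokenizer.from_pretrained("t5-small")
#     mdl = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
#     enc = tok("grammar: Teh cat sat on teh mat.", return_tensors="pt")
#     fixed = generate_correction(
#         mdl, tok, enc["input_ids"], enc["attention_mask"],
#         {"num_beams": 4, "max_new_tokens": 64, "min_length": 1},
#     )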


def batch_generate(
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    texts: List[str],
    generation_config: Dict,
    max_length: int = 512,
) -> List[str]:
    """Generate corrections for a batch of texts."""
    if not texts:
        return []

    results = []
    batch_size = generation_config.get("batch_size", 4)

    for i in range(0, len(texts), batch_size):
        batch_texts = texts[i:i + batch_size]

        inputs = tokenizer(
            batch_texts,
            max_length=max_length,
            padding=True,
            truncation=True,
            return_tensors="pt",
        )

        # Move the tokenized batch onto whatever device the model lives on.
        device = next(model.parameters()).device
        inputs = {k: v.to(device) for k, v in inputs.items()}

        gen_kwargs = {
            "max_new_tokens": generation_config.get("max_new_tokens", 512),
            "num_beams": generation_config.get("num_beams", 5),
            "length_penalty": generation_config.get("length_penalty", 1.0),
            "no_repeat_ngram_size": generation_config.get("no_repeat_ngram_size", 3),
            "early_stopping": generation_config.get("early_stopping", True),
        }

        if generation_config.get("do_sample", False):
            gen_kwargs["do_sample"] = True
            gen_kwargs["temperature"] = generation_config.get("temperature", 0.7)
            # Mirror generate_correction so single and batched sampling use
            # the same nucleus cutoff.
            gen_kwargs["top_p"] = generation_config.get("top_p", 0.9)

        with torch.no_grad():
            output_ids = model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                **gen_kwargs,
            )

        for output in output_ids:
            text = tokenizer.decode(output, skip_special_tokens=True)
            results.append(text.strip())

        logger.debug(f"Generated batch {i // batch_size + 1}: {len(batch_texts)} texts")

    return results
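

if __name__ == "__main__":
    # Smoke test (illustrative only): runs batch_generate against a small
    # public seq2seq checkpoint. "t5-small" and the "grammar: " prefix are
    # assumptions made for this demo; swap in the fine-tuned correction
    # model you actually deploy.
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

    demo_tokenizer = AutoTokenizer.from_pretrained("t5-small")
    demo_model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
    demo_config = {"num_beams": 2, "max_new_tokens": 32, "batch_size": 2}

    corrections = batch_generate(
        demo_model,
        demo_tokenizer,
        ["grammar: She go to school.", "grammar: Teh weather are nice."],
        demo_config,
    )
    for corrected in corrections:
        logger.info(corrected)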