"""
Generation utilities for text correction.
Handles beam search, constrained decoding, and post-generation cleanup.
"""

import torch
from transformers import PreTrainedModel, PreTrainedTokenizer
from typing import Dict, List
from loguru import logger


def generate_correction(
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    generation_config: Dict,
) -> str:
    """Generate corrected text from input tokens."""
    # Build generation kwargs from config
    gen_kwargs = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "max_new_tokens": generation_config.get("max_new_tokens", 512),
        "num_beams": generation_config.get("num_beams", 5),
        "length_penalty": generation_config.get("length_penalty", 1.0),
        "no_repeat_ngram_size": generation_config.get("no_repeat_ngram_size", 3),
        "min_length": generation_config.get("min_length", 10),
        "early_stopping": generation_config.get("early_stopping", True),
    }
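    # num_beams > 1 selects beam search; early_stopping halts the search once
    # num_beams finished hypotheses exist, and min_length keeps the model from
    # collapsing to a near-empty correction.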

    # Optional sampling parameters
    if generation_config.get("do_sample", False):
        gen_kwargs["do_sample"] = True
        gen_kwargs["temperature"] = generation_config.get("temperature", 0.7)
        gen_kwargs["top_p"] = generation_config.get("top_p", 0.9)
    else:
        gen_kwargs["do_sample"] = False

    with torch.no_grad():
        output_ids = model.generate(**gen_kwargs)

    # Decode, skipping special tokens
    generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return generated_text.strip()
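

# Example usage (a sketch, not part of the module API): "t5-small" is a
# placeholder checkpoint name, standing in for the project's fine-tuned
# correction model.
#
#     from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
#     model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
#     tokenizer = AutoTokenizer.from_pretrained("t5-small")
#     enc = tokenizer("teh cat sat on teh mat", return_tensors="pt")
#     fixed = generate_correction(
#         model, tokenizer, enc["input_ids"], enc["attention_mask"],
#         {"num_beams": 4, "max_new_tokens": 64},
#     )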


def batch_generate(
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    texts: List[str],
    generation_config: Dict,
    max_length: int = 512,
) -> List[str]:
    """Generate corrections for a batch of texts."""
    if not texts:
        return []

    results = []
    batch_size = generation_config.get("batch_size", 4)
    device = next(model.parameters()).device

    # Build generation kwargs once; they are identical for every mini-batch
    gen_kwargs = {
        "max_new_tokens": generation_config.get("max_new_tokens", 512),
        "num_beams": generation_config.get("num_beams", 5),
        "length_penalty": generation_config.get("length_penalty", 1.0),
        "no_repeat_ngram_size": generation_config.get("no_repeat_ngram_size", 3),
        "early_stopping": generation_config.get("early_stopping", True),
    }

    # Optional sampling parameters (kept consistent with generate_correction)
    if generation_config.get("do_sample", False):
        gen_kwargs["do_sample"] = True
        gen_kwargs["temperature"] = generation_config.get("temperature", 0.7)
        gen_kwargs["top_p"] = generation_config.get("top_p", 0.9)

    # Process in mini-batches to bound peak memory
    for i in range(0, len(texts), batch_size):
        batch_texts = texts[i:i + batch_size]

        # Tokenise the mini-batch and move tensors to the model's device
        inputs = tokenizer(
            batch_texts,
            max_length=max_length,
            padding=True,
            truncation=True,
            return_tensors="pt",
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            output_ids = model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                **gen_kwargs,
            )

        # Decode each output
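        # Note: this assumes a seq2seq model, whose generate() output contains
        # only the correction; a decoder-only model would echo the prompt
        # tokens here as well.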
        for output in output_ids:
            text = tokenizer.decode(output, skip_special_tokens=True)
            results.append(text.strip())

        logger.debug(f"Generated batch {i // batch_size + 1}: {len(batch_texts)} texts")

    return results
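

if __name__ == "__main__":
    # Minimal smoke-test sketch: "t5-small" is a stand-in checkpoint, not the
    # project's actual fine-tuned correction model.
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

    demo_model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
    demo_tokenizer = AutoTokenizer.from_pretrained("t5-small")
    demo_texts = ["teh quick brwn fox", "she dont like it"]
    demo_config = {"num_beams": 4, "max_new_tokens": 64, "batch_size": 2}

    corrections = batch_generate(demo_model, demo_tokenizer, demo_texts, demo_config)
    for source, corrected in zip(demo_texts, corrections):
        logger.info(f"{source!r} -> {corrected!r}")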