File size: 3,035 Bytes
2f4af3f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36331c6
 
 
 
2f4af3f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from config import Config
from utils.gpu_diagnostics import log_model_device, register_processor, reclaim_vram_for

class HuggingFaceModels:
    """Egyptian (Gardiner-code) translator backed by a Seq2Seq model on CPU.

    The model is not loaded in ``__init__``. It is loaded lazily on the first
    call to ``_translate_fn``, which is exposed as ``self.translator``. If
    loading fails, the failure is remembered: later calls return an empty
    result immediately instead of re-attempting the load on every request.
    """

    def __init__(self):
        self.config = Config()
        # Force the Egyptian translator to CPU to save GPU VRAM.
        self.device = torch.device("cpu")
        self._tokenizer = None
        self._model = None
        # Set after a failed load attempt so it is not retried on every call.
        self._load_failed = False
        self.translator = self._translate_fn
        print("[INFO] Egyptian translator initialized (Forced to CPU)")

    @staticmethod
    def _empty_result():
        """Return a fresh pipeline-shaped empty result (new list each call)."""
        return [{"generated_text": "", "translation_text": ""}]

    def setup_translation_model(self):
        """Load the Seq2Seq tokenizer and model onto ``self.device`` (CPU).

        The model id is taken from ``self.config.HF_TRANSLATOR_MODEL`` when that
        attribute exists, otherwise a built-in default. The ``HF_TOKEN``
        environment variable, if set, is passed as the access token.

        Tokenizer and model are only stored on ``self`` once every step has
        succeeded. On any exception:

        * both are reset to ``None``;
        * ``self.translator`` is replaced by a mock that returns empty results;
        * the failure is recorded in ``self._load_failed``.
        """
        import os

        model_name = getattr(self.config, 'HF_TRANSLATOR_MODEL', 'AnushS/Hieroglyph-Translator-Using-Gardiner-Codes')
        try:
            print(f"[INFO] Lazily loading Hugging Face translation model on CPU: {model_name}...")
            hf_token = os.getenv("HF_TOKEN")
            tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
            model = AutoModelForSeq2SeqLM.from_pretrained(model_name, token=hf_token)
            model.to(self.device)
            model.eval()
            log_model_device("Egyptian T5 Translator", self.device)
        except Exception as e:
            print(f"[ERROR] Failed to load translation model '{model_name}': {e}")
            # Don't leave a half-loaded state (e.g. tokenizer set, model None).
            self._tokenizer = None
            self._model = None
            self._load_failed = True
            self.translator = self._get_mock_translator()
            return
        self._tokenizer = tokenizer
        self._model = model
        self._load_failed = False
        print("[INFO] Translation model loaded successfully on CPU (Seq2Seq direct)")

    def _translate_fn(self, prompt, max_new_tokens=128, **kwargs):
        """Translate *prompt* with the Seq2Seq model, loading it on first use.

        Args:
            prompt: Input text for the tokenizer; it is truncated to 512 tokens.
            max_new_tokens: Upper bound on the number of generated tokens.
            **kwargs: ``num_beams`` (default 4) and ``do_sample`` (default False)
                are forwarded to ``generate``; other keys are ignored.

        Returns:
            A one-element list ``[{"generated_text": str, "translation_text": str}]``
            mimicking ``transformers`` pipeline output. Both fields are empty
            strings when the model is unavailable or inference raises.
        """
        try:
            # Only attempt the (expensive) lazy load if it hasn't already failed.
            if self._model is None and not getattr(self, "_load_failed", False):
                self.setup_translation_model()
            if self._model is None or self._tokenizer is None:
                # Model unavailable: return the empty result rather than
                # calling a None tokenizer and logging a misleading error.
                return self._empty_result()

            inputs = self._tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

            with torch.inference_mode():
                outputs = self._model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    num_beams=kwargs.get("num_beams", 4),
                    do_sample=kwargs.get("do_sample", False),
                )

            decoded = self._tokenizer.decode(outputs[0], skip_special_tokens=True)
            return [{"generated_text": decoded, "translation_text": decoded}]
        except Exception as e:
            print(f"[ERROR] Translation inference failed: {e}")
            return self._empty_result()

    def get_translator(self):
        """Return the loaded translation function or mock fallback"""
        return self.translator

    def _get_mock_translator(self):
        """Return a stand-in translator that mimics pipeline output with empty text."""
        print("[INFO] Setting up mock fallback translator")
        empty = self._empty_result

        def mock_pipeline(prompt, *args, **kwargs):
            return empty()

        return mock_pipeline