# decipherai-api / models/huggingface_models.py
# Commit 36331c6 by Akshay30: "Fix Greek OCR and update Latin OCR model"
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from config import Config
from utils.gpu_diagnostics import log_model_device, register_processor, reclaim_vram_for
class HuggingFaceModels:
    """Egyptian translator backed by a Hugging Face Seq2Seq (T5) model.

    The model is loaded lazily on first use and kept on the CPU so it does not
    compete for GPU VRAM. Results mimic a ``transformers`` text2text pipeline:
    ``[{"generated_text": text, "translation_text": text}]``. If loading
    fails, translation degrades to a fallback that returns the same shape with
    empty text.
    """

    def __init__(self):
        self.config = Config()
        # Force the Egyptian translator onto the CPU to save GPU VRAM.
        self.device = torch.device("cpu")
        self._tokenizer = None
        self._model = None
        # Set when a load attempt fails, so lazy loading is not retried (and
        # the model not re-downloaded) on every request. A successful explicit
        # call to setup_translation_model() clears it.
        self._load_failed = False
        self.translator = self._translate_fn
        print("[INFO] Egyptian translator initialized (Forced to CPU)")

    @staticmethod
    def _empty_result():
        """Return a fresh pipeline-shaped result with empty text."""
        return [{"generated_text": "", "translation_text": ""}]

    def setup_translation_model(self):
        """Load the T5 Seq2Seq tokenizer and model onto ``self.device`` (CPU).

        The model name comes from ``config.HF_TRANSLATOR_MODEL``, falling back
        to a built-in default. The optional ``HF_TOKEN`` environment variable
        is forwarded as the Hub access token.

        The tokenizer and model are stored on the instance only after both have
        loaded and been moved to the device, so a failure never leaves a
        half-initialized translator. On failure the error is printed, further
        lazy load attempts are disabled, and ``self.translator`` is replaced by
        the mock fallback. Calling this method again explicitly retries the
        load; on success the real translator is restored.
        """
        import os

        model_name = getattr(
            self.config,
            'HF_TRANSLATOR_MODEL',
            'AnushS/Hieroglyph-Translator-Using-Gardiner-Codes',
        )
        try:
            print(f"[INFO] Lazily loading Hugging Face translation model on CPU: {model_name}...")
            hf_token = os.getenv("HF_TOKEN")
            tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
            model = AutoModelForSeq2SeqLM.from_pretrained(model_name, token=hf_token)
            model.to(self.device)
            model.eval()
            log_model_device("Egyptian T5 Translator", self.device)
        except Exception as e:
            print(f"[ERROR] Failed to load translation model '{model_name}': {e}")
            self._load_failed = True
            self.translator = self._get_mock_translator()
            return
        # Commit only after every loading step succeeded.
        self._tokenizer = tokenizer
        self._model = model
        self._load_failed = False
        self.translator = self._translate_fn
        print("[INFO] Translation model loaded successfully on CPU (Seq2Seq direct)")

    def _translate_fn(self, prompt, max_new_tokens=128, **kwargs):
        """Translate *prompt* with the T5 model on the CPU.

        The model is loaded on first use. Keyword options read from
        ``kwargs``:

        - ``num_beams`` (default 4)
        - ``do_sample`` (default False)
        - ``max_input_length``: tokenizer truncation length (default 512)

        Other keyword arguments are ignored.

        Returns:
            ``[{"generated_text": text, "translation_text": text}]`` for the
            first sequence in the batch. The text is empty if the model could
            not be loaded or inference raised.
        """
        try:
            if self._model is None:
                # A previous load attempt failed: don't reload on every call.
                # Call setup_translation_model() to retry explicitly.
                if self._load_failed:
                    return self._empty_result()
                self.setup_translation_model()
                if self._model is None:
                    return self._empty_result()
            inputs = self._tokenizer(
                prompt,
                return_tensors="pt",
                truncation=True,
                max_length=kwargs.get("max_input_length", 512),
            )
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            with torch.inference_mode():
                outputs = self._model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    num_beams=kwargs.get("num_beams", 4),
                    do_sample=kwargs.get("do_sample", False),
                )
            decoded = self._tokenizer.decode(outputs[0], skip_special_tokens=True)
            return [{"generated_text": decoded, "translation_text": decoded}]
        except Exception as e:
            print(f"[ERROR] Translation inference failed: {e}")
            return self._empty_result()

    def get_translator(self):
        """Return the current translation callable (real model or mock fallback)."""
        return self.translator

    def _get_mock_translator(self):
        """Return a fallback translator that mimics the pipeline's output shape.

        The returned callable accepts any arguments and always yields empty text.
        """
        print("[INFO] Setting up mock fallback translator")

        def mock_pipeline(prompt, *args, **kwargs):
            return self._empty_result()

        return mock_pipeline