# NOTE(review): the lines "Spaces: / Sleeping / Sleeping" were Hugging Face
# Spaces page-status residue captured by the paste, not part of the program.
import logging

from transformers import AutoTokenizer, pipeline, BitsAndBytesConfig

# Module-level logging: INFO so model-loading progress is visible by default.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class ModelLoader:
    """
    Loads a single translation model + tokenizer, with optional 8-bit quantization.

    When ``quantize`` is True, loading is first attempted with a bitsandbytes
    8-bit config; on any failure (or when ``quantize`` is False) the model is
    loaded in full precision instead.
    """

    def __init__(self, quantize: bool = True):
        # Whether to attempt 8-bit (bitsandbytes) quantized loading.
        self.quantize = quantize

    def load(self, model_name: str):
        """Load the tokenizer and translation pipeline for ``model_name``.

        Args:
            model_name: Hugging Face model identifier or local path.

        Returns:
            ``(tokenizer, pipe)`` — the fast tokenizer and a ``translation``
            pipeline with ``device_map="auto"``.

        Raises:
            AttributeError: if the tokenizer exposes no ``lang_code_to_id``
                mapping (presumably required downstream to select source /
                target languages — confirm against the caller).
        """
        # 1) Tokenizer
        logger.info("Loading tokenizer for %s", model_name)
        tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
        if not hasattr(tokenizer, "lang_code_to_id"):
            raise AttributeError(f"Tokenizer for {model_name} has no lang_code_to_id mapping")

        # 2) Pipeline — only attempt quantization when actually requested.
        # BUGFIX: the original always built BitsAndBytesConfig(load_in_8bit=self.quantize)
        # and logged "with 8-bit quantization" even when quantize=False.
        if self.quantize:
            try:
                bnb_cfg = BitsAndBytesConfig(load_in_8bit=True)
                pipe = pipeline(
                    "translation",
                    model=model_name,
                    tokenizer=tokenizer,
                    device_map="auto",
                    quantization_config=bnb_cfg,
                )
                logger.info("Loaded %s with 8-bit quantization", model_name)
                return tokenizer, pipe
            except Exception as e:
                # Best-effort fallback: bitsandbytes may be missing or the
                # hardware unsupported — degrade to full precision, don't crash.
                logger.warning("8-bit quantization failed (%s), loading full-precision", e)

        pipe = pipeline(
            "translation",
            model=model_name,
            tokenizer=tokenizer,
            device_map="auto",
        )
        logger.info("Loaded %s in full precision", model_name)
        return tokenizer, pipe