kleervoyans committed on
Commit
ff2d1f9
·
verified ·
1 Parent(s): 922e95d

Create model_loader.py

Browse files
Files changed (1) hide show
  1. models/model_loader.py +40 -0
models/model_loader.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import logging

from transformers import AutoTokenizer, BitsAndBytesConfig, pipeline

# NOTE(review): calling basicConfig() at import time is usually left to the
# application entry point, not a library module; kept here to preserve the
# module's existing side effect for current callers.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
7
class ModelLoader:
    """Load a translation model + tokenizer, with optional 8-bit quantization.

    ``load()`` first tries an 8-bit (bitsandbytes) pipeline when ``quantize``
    is True, and falls back to full precision if that fails for any reason
    (missing bitsandbytes, unsupported hardware, etc.).
    """

    def __init__(self, quantize: bool = True):
        # When True, attempt 8-bit loading before falling back to full
        # precision; when False, load full precision directly.
        self.quantize = quantize

    def load(self, model_name: str):
        """Return ``(tokenizer, pipeline)`` for *model_name*.

        Parameters:
            model_name: Hugging Face model id of a translation model.

        Raises:
            AttributeError: if the tokenizer lacks a ``lang_code_to_id``
                mapping (expected of multilingual translation tokenizers).
        """
        # 1) Tokenizer
        logger.info("Loading tokenizer for %s", model_name)
        tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
        if not hasattr(tokenizer, "lang_code_to_id"):
            raise AttributeError(f"Tokenizer for {model_name} has no lang_code_to_id mapping")

        # 2) Pipeline. Only attempt the 8-bit path when requested: the
        # original built BitsAndBytesConfig(load_in_8bit=False) even for
        # quantize=False and then logged a (false) "8-bit" success message.
        if self.quantize:
            try:
                bnb_cfg = BitsAndBytesConfig(load_in_8bit=True)
                pipe = pipeline(
                    "translation",
                    model=model_name,
                    tokenizer=tokenizer,
                    device_map="auto",
                    # quantization_config is a from_pretrained() argument,
                    # not a pipeline() one — forward it via model_kwargs.
                    model_kwargs={"quantization_config": bnb_cfg},
                )
                logger.info("Loaded %s with 8-bit quantization", model_name)
                return tokenizer, pipe
            except Exception as e:
                # Best-effort fallback by design: any failure in the 8-bit
                # path (no bitsandbytes, no CUDA, ...) degrades gracefully.
                logger.warning("8-bit quantization failed (%s), loading full-precision", e)

        pipe = pipeline(
            "translation",
            model=model_name,
            tokenizer=tokenizer,
            device_map="auto",
        )
        logger.info("Loaded %s in full precision", model_name)
        return tokenizer, pipe