Pikeras commited on
Commit
06cebb7
verified
1 Parent(s): 427fa9b

Create model_manager.py

Browse files
Files changed (1) hide show
  1. src/modules/model_manager.py +65 -0
src/modules/model_manager.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoModelForSequenceClassification
2
+ import sys
3
+ from pathlib import Path
4
+
5
+ try:
6
+ from utils.reproducibility import set_seed
7
+ except ModuleNotFoundError:
8
+ sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
9
+ from utils.reproducibility import set_seed
10
+ # Establecer una semilla para reproducibilidad (descomenta para activar)
11
+ # set_seed(72)
12
+
13
+ # Clase encargada de cargar y gestionar los modelos y tokenizadores utilizados en el proyecto.
14
+ class ModelManager:
15
+ def __init__(self, config):
16
+ self.config = config
17
+ self.models = {}
18
+ self.tokenizers = {}
19
+
20
+ # Carga el modelo generador de prompts y su tokenizer con cuantizaci贸n 4-bit (bitsandbytes)
21
+ def load_generator(self):
22
+ modelo_id = self.config.modelo_generador['id_modelo']
23
+ modo = self.config.modelo_generador['modo_interaccion']
24
+ if modelo_id != '' and str(modo).lower() == 'local':
25
+ bnb_config = BitsAndBytesConfig(
26
+ load_in_4bit=True,
27
+ bnb_4bit_use_double_quant=True,
28
+ bnb_4bit_compute_dtype="float16",
29
+ bnb_4bit_quant_type="nf4"
30
+ )
31
+ tokenizer = AutoTokenizer.from_pretrained(modelo_id, use_fast=True)
32
+ model = AutoModelForCausalLM.from_pretrained(modelo_id, quantization_config=bnb_config, low_cpu_mem_usage=True, device_map="auto")
33
+ self.models['generator'] = model
34
+ self.tokenizers['generator'] = tokenizer
35
+
36
+ # Carga el modelo LLM a evaluar y su tokenizer con cuantizaci贸n 4-bit.
37
+ def load_evaluator(self):
38
+ modelo_id = self.config.modelo_a_evaluar['id_modelo']
39
+ modo = self.config.modelo_a_evaluar['modo_interaccion']
40
+ if modelo_id != '' and str(modo).lower() == 'local':
41
+ bnb_config = BitsAndBytesConfig(
42
+ load_in_4bit=True,
43
+ bnb_4bit_use_double_quant=True,
44
+ bnb_4bit_compute_dtype="float16",
45
+ bnb_4bit_quant_type="nf4"
46
+ )
47
+ tokenizer = AutoTokenizer.from_pretrained(modelo_id, use_fast=True)
48
+ model = AutoModelForCausalLM.from_pretrained(modelo_id, quantization_config=bnb_config, low_cpu_mem_usage=True, device_map="auto")
49
+ self.models['evaluator'] = model
50
+ self.tokenizers['evaluator'] = tokenizer
51
+
52
+ # Carga el modelo de an谩lisis de sentimiento (SequenceClassification) en CUDA y lo pone en eval().
53
+ def load_sentiment(self):
54
+ modelo_id = self.config.modelo_analisis_de_sentimiento['id_modelo']
55
+ modo = self.config.modelo_analisis_de_sentimiento['modo_interaccion']
56
+ if modelo_id != '' and str(modo).lower() == 'local':
57
+ tokenizer = AutoTokenizer.from_pretrained(modelo_id, use_fast=True)
58
+ model = AutoModelForSequenceClassification.from_pretrained(modelo_id, low_cpu_mem_usage=True)
59
+ model = model.to("cuda")
60
+ model.eval()
61
+ self.models['sentiment'] = model
62
+ self.tokenizers['sentiment'] = tokenizer
63
+
64
+ def get_model(self, key):
65
+ return self.models.get(key), self.tokenizers.get(key)