ZedLow commited on
Commit
9b51450
·
verified ·
1 Parent(s): d9a3b5a

Update rag/models.py

Browse files
Files changed (1) hide show
  1. rag/models.py +18 -21
rag/models.py CHANGED
@@ -6,57 +6,54 @@ from transformers import (
6
  Qwen2VLForConditionalGeneration,
7
  AutoProcessor,
8
  )
9
-
10
  from gliner import GLiNER
 
11
  from rag.config import Settings
12
  from rag.logging_utils import get_logger
13
 
14
  logger = get_logger(__name__)
15
 
16
- def _has_cuda() -> bool:
17
- return torch.cuda.is_available()
18
-
19
- def _dtype_for_device():
20
- # BF16 is great on recent GPUs; on CPU it's often slower / unsupported in ops
21
- return torch.bfloat16 if _has_cuda() else torch.float32
22
-
23
  class Models:
24
  def __init__(self, settings: Settings):
25
  self.settings = settings
26
 
27
- # Router (CPU)
28
- logger.info("Loading Router (GLiNER) on CPU: %s", settings.router_model_id)
29
- self.router = GLiNER.from_pretrained(settings.router_model_id).to("cpu")
30
- self.router.eval()
31
 
32
- # Embedder
33
- logger.info("Loading Embedder: %s", settings.embed_model_id)
34
  self.embed_tokenizer = AutoTokenizer.from_pretrained(settings.embed_model_id, trust_remote_code=False)
35
  self.embed_model = AutoModel.from_pretrained(
36
  settings.embed_model_id,
37
  trust_remote_code=False,
38
- torch_dtype=_dtype_for_device(),
39
  device_map="auto",
40
  )
41
  self.embed_model.eval()
42
 
43
  # Reranker
44
- logger.info("Loading Reranker: %s", settings.rerank_model_id)
45
  self.rerank_tokenizer = AutoTokenizer.from_pretrained(settings.rerank_model_id)
46
  self.rerank_model = AutoModelForSequenceClassification.from_pretrained(
47
  settings.rerank_model_id,
48
- torch_dtype=_dtype_for_device(),
49
  device_map="auto",
50
  )
51
  self.rerank_model.eval()
52
 
53
- # Vision generator
54
- logger.info("Loading Vision model: %s", settings.gen_model_id)
55
  self.gen_model = Qwen2VLForConditionalGeneration.from_pretrained(
56
  settings.gen_model_id,
57
- torch_dtype=_dtype_for_device(),
58
  device_map="auto",
59
  )
60
  self.gen_model.eval()
61
-
62
  self.gen_processor = AutoProcessor.from_pretrained(settings.gen_model_id)
 
 
 
 
 
 
6
  Qwen2VLForConditionalGeneration,
7
  AutoProcessor,
8
  )
 
9
  from gliner import GLiNER
10
+
11
  from rag.config import Settings
12
  from rag.logging_utils import get_logger
13
 
14
  logger = get_logger(__name__)
15
 
 
 
 
 
 
 
 
class Models:
    """Eagerly loads and holds every model used by the RAG pipeline.

    Components loaded at construction time:
      * ``router_model``  — GLiNER entity router, pinned to CPU,
      * ``embed_model``/``embed_tokenizer`` — embedding model,
      * ``rerank_model``/``rerank_tokenizer`` — cross-encoder reranker,
      * ``gen_model``/``gen_processor`` — Qwen2-VL vision generator.

    All HF models are placed with ``device_map="auto"`` and loaded in
    bfloat16 on CUDA machines, float32 otherwise.
    """

    def __init__(self, settings: Settings):
        """Load all models described by *settings*. Performs network/disk I/O."""
        self.settings = settings

        # BF16 is fast on recent GPUs, but on CPU it is often slower or
        # unsupported by some ops — fall back to float32 when no CUDA device
        # is available. Computed once and shared by all three loads.
        dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

        # Router: GLiNER is lightweight, so it stays on CPU regardless of GPUs.
        logger.info("🧠 Loading GLiNER router on CPU: %s", settings.router_model_id)
        self.router_model = GLiNER.from_pretrained(settings.router_model_id).to("cpu")
        self.router_model.eval()

        # Embedding model (device_map="auto" places it on GPU when available).
        logger.info("🔹 Loading embedder: %s", settings.embed_model_id)
        self.embed_tokenizer = AutoTokenizer.from_pretrained(settings.embed_model_id, trust_remote_code=False)
        self.embed_model = AutoModel.from_pretrained(
            settings.embed_model_id,
            trust_remote_code=False,
            torch_dtype=dtype,
            device_map="auto",
        )
        self.embed_model.eval()

        # Reranker: sequence-classification cross-encoder.
        logger.info("⚖️ Loading reranker: %s", settings.rerank_model_id)
        self.rerank_tokenizer = AutoTokenizer.from_pretrained(settings.rerank_model_id)
        self.rerank_model = AutoModelForSequenceClassification.from_pretrained(
            settings.rerank_model_id,
            torch_dtype=dtype,
            device_map="auto",
        )
        self.rerank_model.eval()

        # Vision-language generator and its matching processor.
        logger.info("👁️ Loading vision model: %s", settings.gen_model_id)
        self.gen_model = Qwen2VLForConditionalGeneration.from_pretrained(
            settings.gen_model_id,
            torch_dtype=dtype,
            device_map="auto",
        )
        self.gen_model.eval()

        self.gen_processor = AutoProcessor.from_pretrained(settings.gen_model_id)
def load_models(settings: Settings | None = None) -> Models:
    """Build and return a fully-loaded :class:`Models` bundle.

    When *settings* is ``None`` a default :class:`Settings` instance is
    constructed; otherwise the provided settings are used as-is.
    """
    return Models(Settings() if settings is None else settings)