ZedLow committed on
Commit
ee3f04c
·
verified ·
1 Parent(s): 82f0b9f

Create models.py

Browse files
Files changed (1) hide show
  1. rag/models.py +62 -0
rag/models.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import (
3
+ AutoModel,
4
+ AutoTokenizer,
5
+ AutoModelForSequenceClassification,
6
+ Qwen2VLForConditionalGeneration,
7
+ AutoProcessor,
8
+ )
9
+
10
+ from gliner import GLiNER
11
+ from rag.config import Settings
12
+ from rag.logging_utils import get_logger
13
+
14
+ logger = get_logger(__name__)
15
+
16
+ def _has_cuda() -> bool:
17
+ return torch.cuda.is_available()
18
+
19
+ def _dtype_for_device():
20
+ # BF16 is great on recent GPUs; on CPU it's often slower / unsupported in ops
21
+ return torch.bfloat16 if _has_cuda() else torch.float32
22
+
23
class Models:
    """Loads and holds every model used by the RAG pipeline.

    Construction is eager: all four components (router, embedder, reranker,
    vision generator) are downloaded/loaded immediately, each in its own
    private helper so failures are easy to localize in the logs.

    Attributes set by __init__:
        router           — GLiNER entity router, pinned to CPU.
        embed_tokenizer / embed_model   — dense embedder pair.
        rerank_tokenizer / rerank_model — cross-encoder reranker pair.
        gen_model / gen_processor       — Qwen2-VL generator pair.
    """

    def __init__(self, settings: Settings):
        self.settings = settings
        self._load_router()
        self._load_embedder()
        self._load_reranker()
        self._load_generator()

    def _load_router(self) -> None:
        # Router (GLiNER) — small model, kept on CPU and put in eval mode.
        logger.info("Loading Router (GLiNER) on CPU: %s", self.settings.router_model_id)
        self.router = GLiNER.from_pretrained(self.settings.router_model_id).to("cpu")
        self.router.eval()

    def _load_embedder(self) -> None:
        # Embedder — tokenizer + encoder; device placement via device_map="auto".
        logger.info("Loading Embedder: %s", self.settings.embed_model_id)
        self.embed_tokenizer = AutoTokenizer.from_pretrained(
            self.settings.embed_model_id, trust_remote_code=False
        )
        self.embed_model = AutoModel.from_pretrained(
            self.settings.embed_model_id,
            trust_remote_code=False,
            torch_dtype=_dtype_for_device(),
            device_map="auto",
        )
        self.embed_model.eval()

    def _load_reranker(self) -> None:
        # Reranker — sequence-classification cross-encoder plus its tokenizer.
        logger.info("Loading Reranker: %s", self.settings.rerank_model_id)
        self.rerank_tokenizer = AutoTokenizer.from_pretrained(self.settings.rerank_model_id)
        self.rerank_model = AutoModelForSequenceClassification.from_pretrained(
            self.settings.rerank_model_id,
            torch_dtype=_dtype_for_device(),
            device_map="auto",
        )
        self.rerank_model.eval()

    def _load_generator(self) -> None:
        # Vision generator — Qwen2-VL conditional-generation model + processor.
        logger.info("Loading Vision model: %s", self.settings.gen_model_id)
        self.gen_model = Qwen2VLForConditionalGeneration.from_pretrained(
            self.settings.gen_model_id,
            torch_dtype=_dtype_for_device(),
            device_map="auto",
        )
        self.gen_model.eval()
        self.gen_processor = AutoProcessor.from_pretrained(self.settings.gen_model_id)