bdtimuhammad commited on
Commit
ff3a836
Β·
verified Β·
1 Parent(s): 51ed9c7

Update loader.py

Browse files
Files changed (1) hide show
  1. loader.py +18 -80
loader.py CHANGED
@@ -1,90 +1,28 @@
1
  import torch
2
  import gc
3
  import open_clip
4
- from transformers import AutoTokenizer, AutoModelForCausalLM, AutoProcessor, BitsAndBytesConfig
5
- import os
6
 
7
  class ModelLoader:
8
  def __init__(self):
9
- self.medgemma_model = None
10
- self.medgemma_tokenizer = None
11
-
12
- self.biomedclip_model = None
13
- self.biomedclip_preprocess = None
14
- self.biomedclip_tokenizer = None
15
-
16
- self.maira2_model = None
17
- self.maira2_processor = None
18
-
19
- self.device = "cuda" if torch.cuda.is_available() else "cpu"
20
- self.hf_token = os.getenv("HF_TOKEN", "")
21
-
22
- def init_startup_models(self):
23
- """Loads MedGemma and BiomedCLIP into VRAM."""
24
- print("Pre-loading MedGemma 1.5 4B...")
25
- try:
26
- self.medgemma_tokenizer = AutoTokenizer.from_pretrained(
27
- "google/medgemma-1.5-4b-it", token=self.hf_token
28
- )
29
- bnb_config = BitsAndBytesConfig(
30
- load_in_4bit=True,
31
- bnb_4bit_compute_dtype=torch.float16
32
- )
33
- self.medgemma_model = AutoModelForCausalLM.from_pretrained(
34
- "google/medgemma-1.5-4b-it",
35
- token=self.hf_token,
36
- quantization_config=bnb_config,
37
- device_map="auto"
38
- )
39
- except Exception as e:
40
- print(f"Failed to load MedGemma: {e}")
41
-
42
- print("Pre-loading BiomedCLIP...")
43
- try:
44
- model, preprocess, _ = open_clip.create_model_and_transforms('hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224')
45
- self.biomedclip_tokenizer = open_clip.get_tokenizer('hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224')
46
- self.biomedclip_model = model.to(self.device).eval()
47
- self.biomedclip_preprocess = preprocess
48
- except Exception as e:
49
- print(f"Failed to load BiomedCLIP: {e}")
50
 
51
- def load_maira2_lazy(self):
52
- """Lazy loads MAIRA-2 to GPU."""
53
- if self.maira2_model is not None:
54
- return
55
-
56
- print("Lazy Loading MAIRA-2...")
57
- try:
58
- self.maira2_processor = AutoProcessor.from_pretrained("microsoft/maira-2", token=self.hf_token, trust_remote_code=True)
59
- bnb_config = BitsAndBytesConfig(
60
- load_in_4bit=True,
61
- bnb_4bit_compute_dtype=torch.float16
62
  )
63
- self.maira2_model = AutoModelForCausalLM.from_pretrained(
64
- "microsoft/maira-2",
65
- token=self.hf_token,
66
- quantization_config=bnb_config,
67
- device_map="auto",
68
- trust_remote_code=True
69
- )
70
- except Exception as e:
71
- print(f"Failed to load MAIRA-2: {e}")
72
 
73
  def clear_vram(self):
74
- """Strictly moves MAIRA-2 out of VRAM/memory to ensure T4 capacity."""
75
- if self.maira2_model is not None:
76
- print("Offloading MAIRA-2 to free VRAM...")
77
- del self.maira2_model
78
- self.maira2_model = None
79
- gc.collect()
80
- if torch.cuda.is_available():
81
- torch.cuda.empty_cache()
82
-
83
- def get_medgemma(self):
84
- return self.medgemma_model, self.medgemma_tokenizer
85
-
86
- def get_biomedclip(self):
87
- return self.biomedclip_model, self.biomedclip_preprocess, self.biomedclip_tokenizer
88
-
89
- def get_maira2(self):
90
- return self.maira2_model, self.maira2_processor
 
1
  import torch
2
  import gc
3
  import open_clip
 
 
4
 
5
  class ModelLoader:
6
  def __init__(self):
7
+ self.biomed_model = None
8
+ self.preprocess = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
+ def load_biomed_clip(self):
11
+ """Universal Zero-Shot Auditor (BiomedCLIP)"""
12
+ if self.biomed_model is None:
13
+ print("πŸ”„ Loading BiomedCLIP Universal Auditor...")
14
+ model, _, preprocess = open_clip.create_model_and_transforms(
15
+ 'hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224'
 
 
 
 
 
16
  )
17
+ self.biomed_model = model.to("cuda").eval()
18
+ self.preprocess = preprocess
19
+ return self.biomed_model, self.preprocess
 
 
 
 
 
 
20
 
21
  def clear_vram(self):
22
+ """Safety flush to ensure Council stability on T4."""
23
+ gc.collect()
24
+ if torch.cuda.is_available():
25
+ torch.cuda.empty_cache()
26
+
27
+ # πŸ‘‡ THIS IS THE CRUCIAL LINE THAT WAS MISSING πŸ‘‡
28
+ loader = ModelLoader()