Faisal commited on
Commit
adaeccf
Β·
1 Parent(s): 008f93e
Files changed (2) hide show
  1. app.py +5 -14
  2. requirements.txt +4 -6
app.py CHANGED
@@ -5,26 +5,17 @@ import torch
5
  import requests
6
 
7
  # ----------------------------
8
- # MODEL LOADING (MedVLM-R1) - CPU Compatible
9
  # ----------------------------
10
  MODEL_PATH = 'JZPeterPan/MedVLM-R1'
11
 
12
- # Check if CUDA is available, otherwise use CPU
13
- device = "cuda" if torch.cuda.is_available() else "cpu"
14
- print(f"Using device: {device}")
15
-
16
  model = Qwen2VLForConditionalGeneration.from_pretrained(
17
  MODEL_PATH,
18
- torch_dtype=torch.float32 if device == "cpu" else torch.bfloat16,
19
- device_map="auto" if device == "cuda" else None,
20
- low_cpu_mem_usage=True if device == "cpu" else False,
21
  )
22
  processor = AutoProcessor.from_pretrained(MODEL_PATH)
23
 
24
- # Move model to device
25
- if device == "cpu":
26
- model = model.to(device)
27
-
28
  temp_generation_config = GenerationConfig(
29
  max_new_tokens=1024,
30
  do_sample=False,
@@ -80,7 +71,7 @@ def process_pipeline(image, user_question):
80
  videos=video_inputs,
81
  padding=True,
82
  return_tensors="pt",
83
- ).to(device)
84
 
85
  # Generate output from MedVLM
86
  generated_ids = model.generate(
@@ -162,4 +153,4 @@ with gr.Blocks(title="Brain MRI QA") as demo:
162
  )
163
 
164
  if __name__ == "__main__":
165
- demo.launch()
 
5
  import requests
6
 
7
  # ----------------------------
8
+ # MODEL LOADING (MedVLM-R1)
9
  # ----------------------------
10
  MODEL_PATH = 'JZPeterPan/MedVLM-R1'
11
 
 
 
 
 
12
  model = Qwen2VLForConditionalGeneration.from_pretrained(
13
  MODEL_PATH,
14
+ torch_dtype=torch.bfloat16,
15
+ device_map="auto",
 
16
  )
17
  processor = AutoProcessor.from_pretrained(MODEL_PATH)
18
 
 
 
 
 
19
  temp_generation_config = GenerationConfig(
20
  max_new_tokens=1024,
21
  do_sample=False,
 
71
  videos=video_inputs,
72
  padding=True,
73
  return_tensors="pt",
74
+ ).to("cuda")
75
 
76
  # Generate output from MedVLM
77
  generated_ids = model.generate(
 
153
  )
154
 
155
  if __name__ == "__main__":
156
+ demo.launch()
requirements.txt CHANGED
@@ -1,15 +1,13 @@
1
  gradio==5.42.0
2
- transformers==4.36.0
3
- --find-links https://download.pytorch.org/whl/torch_stable.html
4
- torch==2.1.0+cpu
5
- torchvision==0.16.0+cpu
6
- torchaudio==2.1.0+cpu
7
  requests>=2.31.0
8
  Pillow>=10.0.0
9
  accelerate>=0.20.0
10
  safetensors>=0.3.0
11
  tokenizers>=0.15.0
12
- numpy<2.0.0
13
  scipy>=1.10.0
14
  qwen-vl-utils
15
  ipython>=8.0.0
 
1
  gradio==5.42.0
2
+ transformers>=4.40.0
3
+ torch>=2.0.0
4
+ torchvision>=0.15.0
 
 
5
  requests>=2.31.0
6
  Pillow>=10.0.0
7
  accelerate>=0.20.0
8
  safetensors>=0.3.0
9
  tokenizers>=0.15.0
10
+ numpy>=1.24.0
11
  scipy>=1.10.0
12
  qwen-vl-utils
13
  ipython>=8.0.0