safiaa02 commited on
Commit
382d402
·
verified ·
1 Parent(s): 1c3f102

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +9 -19
model.py CHANGED
@@ -8,14 +8,12 @@ class ChestGPTDemo:
8
  def __init__(self, device=None):
9
  self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
10
  self.vit = ViTEncoder().to(self.device).eval()
11
-
12
- self.tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-7b-instruct")
13
- self.lm = AutoModelForCausalLM.from_pretrained("tiiuae/falcon-7b-instruct")
14
- self.lm = self.lm.to(self.device).eval()
15
 
16
  self.prompt = (
17
- "[radiology] Please describe this chest X-ray in detail. "
18
- "List global diseases and any local findings with locations."
19
  )
20
 
21
  def process_image(self, img: Image.Image):
@@ -23,18 +21,10 @@ class ChestGPTDemo:
23
  transforms.Resize((224, 224)),
24
  transforms.ToTensor()
25
  ])
26
- tensor = transform(img.convert("RGB")).unsqueeze(0).to(self.device)
27
- return tensor
28
 
29
  def predict(self, img: Image.Image):
30
- img_tensor = self.process_image(img)
31
- with torch.no_grad():
32
- vit_feat = self.vit(img_tensor)
33
-
34
- # Use first 10 features from ViT just as a mock
35
- prompt = self.prompt + "\n[image_features]: " + ", ".join([f"{x:.3f}" for x in vit_feat[0][:10]])
36
- inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
37
-
38
- output_ids = self.lm.generate(**inputs, max_new_tokens=100)
39
- result = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
40
- return result
 
8
  def __init__(self, device=None):
9
  self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
10
  self.vit = ViTEncoder().to(self.device).eval()
11
+ self.tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-rw-1b")
12
+ self.lm = AutoModelForCausalLM.from_pretrained("tiiuae/falcon-rw-1b").to(self.device).eval()
 
 
13
 
14
  self.prompt = (
15
+ "[radiology] Please describe this chest X-ray. "
16
+ "Mention global diseases and local findings if visible."
17
  )
18
 
19
    def process_image(self, img: Image.Image):
        """Convert a PIL image to a model-ready batch tensor.

        Resizes to 224x224, converts to a float tensor in [0, 1], adds a
        leading batch dimension, and moves the result to ``self.device``.

        Args:
            img: Input PIL image; converted to RGB before transforming.

        Returns:
            A tensor of shape (1, 3, 224, 224) on ``self.device``.
        """
        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor()
        ])
        # convert("RGB") guards against grayscale/RGBA X-ray files;
        # unsqueeze(0) adds the batch axis expected by the ViT encoder.
        return transform(img.convert("RGB")).unsqueeze(0).to(self.device)
 
25
 
26
  def predict(self, img: Image.Image):
27
+ _ = self.vit(self.process_image(img)) # placeholder for visual reasoning
28
+ input_ids = self.tokenizer(self.prompt, return_tensors="pt").input_ids.to(self.device)
29
+ output = self.lm.generate(input_ids, max_new_tokens=100)
30
+ return self.tokenizer.decode(output[0], skip_special_tokens=True)