safiaa02 committed
Commit a34d86a (verified) · 1 Parent(s): fca50d4

Update model.py

Files changed (1):
  model.py +22 -15
model.py CHANGED
@@ -1,30 +1,37 @@
 import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM
 from torchvision import transforms
-from vit_encoder import ViTEncoder
 from PIL import Image
 
 class ChestGPTDemo:
     def __init__(self, device=None):
         self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
-        self.vit = ViTEncoder().to(self.device).eval()
-        self.tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-rw-1b")
-        self.lm = AutoModelForCausalLM.from_pretrained("tiiuae/falcon-rw-1b").to(self.device).eval()
+
+        # Load clinical GPT-2 model
+        self.tokenizer = AutoTokenizer.from_pretrained("mrm8488/GPT-2-finetuned-clinical-notes")
+        self.lm = AutoModelForCausalLM.from_pretrained("mrm8488/GPT-2-finetuned-clinical-notes").to(self.device).eval()
 
+        # Few-shot prompt to guide generation
         self.prompt = (
-            "[radiology] Please describe this chest X-ray. "
-            "Mention global diseases and local findings if visible."
+            "[radiology] Example 1:\n"
+            "Global Disease: Cardiomegaly\n"
+            "Local Finding: Patchy opacity in right lower lobe (BBox: 50,60,120,150)\n\n"
+            "[radiology] Example 2:\n"
+            "Global Disease: Normal\n"
+            "Local Finding: No abnormalities detected\n\n"
+            "[radiology] Please describe this chest X-ray. Mention global diseases and local findings if visible.\n"
        )
 
     def process_image(self, img: Image.Image):
-        transform = transforms.Compose([
-            transforms.Resize((224, 224)),
-            transforms.ToTensor()
-        ])
-        return transform(img.convert("RGB")).unsqueeze(0).to(self.device)
+        # Placeholder for image features – will add ViT later
+        return None
 
     def predict(self, img: Image.Image):
-        _ = self.vit(self.process_image(img))  # placeholder for visual reasoning
-        input_ids = self.tokenizer(self.prompt, return_tensors="pt").input_ids.to(self.device)
-        output = self.lm.generate(input_ids, max_new_tokens=100)
-        return self.tokenizer.decode(output[0], skip_special_tokens=True)
+        _ = self.process_image(img)  # skip vision for now
+        inputs = self.tokenizer(self.prompt, return_tensors="pt", padding=True).to(self.device)
+        outputs = self.lm.generate(
+            **inputs,
+            max_new_tokens=100,
+            pad_token_id=self.tokenizer.eos_token_id
+        )
+        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
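
For reference, a minimal usage sketch of the updated class. It assumes model.py above is importable from the working directory and that a chest X-ray exists at the hypothetical path "sample_cxr.png"; since process_image now returns None, the image is accepted but not used, and generation is driven entirely by the fixed few-shot prompt:

from PIL import Image
from model import ChestGPTDemo

demo = ChestGPTDemo()               # selects CUDA if available, else CPU
img = Image.open("sample_cxr.png")  # hypothetical sample; currently unused by predict()
report = demo.predict(img)          # conditioned only on the few-shot prompt for now
print(report)

The few-shot examples mainly constrain the output format: the model is nudged to continue with "Global Disease:" / "Local Finding:" lines, including the (BBox: x1,y1,x2,y2) convention shown in Example 1, until the ViT features are wired back into predict().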