safiaa02 commited on
Commit
045c6f2
·
verified ·
1 Parent(s): 7131543

Create model.py

Browse files
Files changed (1) hide show
  1. model.py +38 -0
model.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM
3
+ from torchvision import transforms
4
+ from vit_encoder import ViTEncoder
5
+ from PIL import Image
6
+
7
+ class ChestGPTDemo:
8
+ def __init__(self, device="cpu"):
9
+ self.device = device
10
+ self.vit = ViTEncoder().to(device).eval()
11
+ self.tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-7b-instruct")
12
+ self.lm = AutoModelForCausalLM.from_pretrained(
13
+ "tiiuae/falcon-7b-instruct",
14
+ device_map="auto",
15
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
16
+ )
17
+ self.prompt = (
18
+ "[radiology] Please describe this chest X-ray in detail. "
19
+ "List global diseases and any local findings with locations."
20
+ )
21
+
22
+ def process_image(self, img: Image.Image):
23
+ transform = transforms.Compose([
24
+ transforms.Resize((224, 224)),
25
+ transforms.ToTensor()
26
+ ])
27
+ tensor = transform(img.convert("RGB")).unsqueeze(0).to(self.device)
28
+ return tensor
29
+
30
+ def predict(self, img: Image.Image):
31
+ img_tensor = self.process_image(img)
32
+ with torch.no_grad():
33
+ vit_feat = self.vit(img_tensor)
34
+ # Shorten token list for demo
35
+ prompt = self.prompt + "\n[image_features]: " + ", ".join([f"{x:.3f}" for x in vit_feat[0][:10]])
36
+ input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.device)
37
+ out = self.lm.generate(input_ids, max_new_tokens=100)
38
+ return self.tokenizer.decode(out[0], skip_special_tokens=True)