Update model.py
Browse files
model.py
CHANGED
|
@@ -8,14 +8,12 @@ class ChestGPTDemo:
|
|
| 8 |
def __init__(self, device=None):
|
| 9 |
self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
|
| 10 |
self.vit = ViTEncoder().to(self.device).eval()
|
| 11 |
-
|
| 12 |
-
self.
|
| 13 |
-
self.lm = AutoModelForCausalLM.from_pretrained("tiiuae/falcon-7b-instruct")
|
| 14 |
-
self.lm = self.lm.to(self.device).eval()
|
| 15 |
|
| 16 |
self.prompt = (
|
| 17 |
-
"[radiology] Please describe this chest X-ray
|
| 18 |
-
"
|
| 19 |
)
|
| 20 |
|
| 21 |
def process_image(self, img: Image.Image):
|
|
@@ -23,18 +21,10 @@ class ChestGPTDemo:
|
|
| 23 |
transforms.Resize((224, 224)),
|
| 24 |
transforms.ToTensor()
|
| 25 |
])
|
| 26 |
-
|
| 27 |
-
return tensor
|
| 28 |
|
| 29 |
def predict(self, img: Image.Image):
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
# Use first 10 features from ViT just as a mock
|
| 35 |
-
prompt = self.prompt + "\n[image_features]: " + ", ".join([f"{x:.3f}" for x in vit_feat[0][:10]])
|
| 36 |
-
inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
|
| 37 |
-
|
| 38 |
-
output_ids = self.lm.generate(**inputs, max_new_tokens=100)
|
| 39 |
-
result = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
|
| 40 |
-
return result
|
|
|
|
| 8 |
def __init__(self, device=None):
    """Set up the demo pipeline: ViT image encoder plus a Falcon causal LM.

    Parameters
    ----------
    device : str | None
        Torch device string. When falsy, picks "cuda" if available,
        otherwise "cpu".
    """
    default_device = "cuda" if torch.cuda.is_available() else "cpu"
    self.device = device or default_device

    # Vision encoder, frozen in eval mode for inference-only use.
    self.vit = ViTEncoder().to(self.device).eval()

    # Language model + tokenizer share the same checkpoint name.
    lm_name = "tiiuae/falcon-rw-1b"
    self.tokenizer = AutoTokenizer.from_pretrained(lm_name)
    self.lm = AutoModelForCausalLM.from_pretrained(lm_name).to(self.device).eval()

    # Fixed instruction prepended to every generation request.
    self.prompt = (
        "[radiology] Please describe this chest X-ray. "
        "Mention global diseases and local findings if visible."
    )
|
| 18 |
|
| 19 |
def process_image(self, img: Image.Image):
    """Turn a PIL image into a batched float tensor on ``self.device``.

    The image is forced to RGB, resized to 224x224, converted to a
    tensor (ToTensor scales pixel values into [0, 1]), and given a
    leading batch dimension of 1.
    """
    rgb = img.convert("RGB")
    pipeline = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])
    tensor = pipeline(rgb)
    return tensor.unsqueeze(0).to(self.device)
|
|
|
| 25 |
|
| 26 |
def predict(self, img: Image.Image):
    """Generate a text description for a chest X-ray image.

    Parameters
    ----------
    img : PIL.Image.Image
        Input radiograph; converted/resized by ``process_image``.

    Returns
    -------
    str
        The decoded LM output (includes the prompt text, since the
        full generated sequence is decoded).

    Notes
    -----
    The ViT features are computed but not yet fed to the LM — the
    visual-reasoning hookup is still a placeholder, as in the
    original code.
    """
    # Fix: run the whole forward path under no_grad. Previously the
    # ViT forward built an autograd graph that was never used, wasting
    # memory/compute; generated output is unchanged.
    with torch.no_grad():
        _ = self.vit(self.process_image(img))  # placeholder for visual reasoning
        input_ids = self.tokenizer(self.prompt, return_tensors="pt").input_ids.to(self.device)
        output = self.lm.generate(input_ids, max_new_tokens=100)
    return self.tokenizer.decode(output[0], skip_special_tokens=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|