thlinhares commited on
Commit
5ea6c11
·
1 Parent(s): 9176d4d

remove cuda

Browse files
Files changed (2) hide show
  1. app.py +1 -1
  2. app2.py +48 -0
app.py CHANGED
@@ -24,7 +24,7 @@ def generate(images, prompt, processor, model, device, dtype, generation_config)
24
 
25
  def main():
26
  # step 1: Setup constant
27
- device = "cuda"
28
  dtype = torch.float16
29
 
30
  # step 2: Load Processor and Model
 
24
 
25
  def main():
26
  # step 1: Setup constant
27
+ #device = "cuda"
28
  dtype = torch.float16
29
 
30
  # step 2: Load Processor and Model
app2.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import requests
3
+ import torch
4
+ from PIL import Image
5
+ import streamlit as st
6
+ from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
7
+
8
+ def download_image(url):
9
+ resp = requests.get(url)
10
+ resp.raise_for_status()
11
+ return Image.open(io.BytesIO(resp.content)).convert("RGB")
12
+
13
+ def generate(images, prompt, processor, model, device, dtype, generation_config):
14
+ inputs = processor(
15
+ images=images[:2], text=f" USER: <s>{prompt} ASSISTANT: <s>", return_tensors="pt"
16
+ ).to(device=device, dtype=dtype)
17
+ output = model.generate(**inputs, generation_config=generation_config)[0]
18
+ response = processor.tokenizer.decode(output, skip_special_tokens=True)
19
+ return response
20
+
21
+ def main():
22
+ st.title("Medical Image Analysis")
23
+
24
+ device = "cuda"
25
+ dtype = torch.float16
26
+
27
+ processor = AutoProcessor.from_pretrained("StanfordAIMI/CheXagent-8b", trust_remote_code=True)
28
+ generation_config = GenerationConfig.from_pretrained("StanfordAIMI/CheXagent-8b")
29
+ model = AutoModelForCausalLM.from_pretrained(
30
+ "StanfordAIMI/CheXagent-8b", torch_dtype=dtype, trust_remote_code=True
31
+ ).to(device)
32
+
33
+ image_path = "https://upload.wikimedia.org/wikipedia/commons/3/3b/Pleural_effusion-Metastatic_breast_carcinoma_Case_166_%285477628658%29.jpg"
34
+ images = [download_image(image_path)]
35
+
36
+ anatomies = [
37
+ "Airway", "Breathing", "Cardiac", "Diaphragm",
38
+ "Everything else (e.g., mediastinal contours, bones, soft tissues, tubes, valves, and pacemakers)"
39
+ ]
40
+
41
+ for anatomy in anatomies:
42
+ prompt = f'Describe "{anatomy}"'
43
+ response = generate(images, prompt, processor, model, device, dtype, generation_config)
44
+ st.subheader(f"Findings for [{anatomy}]:")
45
+ st.write(response)
46
+
47
+ if __name__ == '__main__':
48
+ main()