DivingFox committed on
Commit
45aa061
·
verified ·
1 Parent(s): d40b2b5

Update src/streamlit_app.py

Browse files

fix app to use local dataset

Files changed (1) hide show
  1. src/streamlit_app.py +56 -30
src/streamlit_app.py CHANGED
@@ -21,6 +21,8 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
21
  from PIL import Image
22
  from torchvision import transforms
23
  from io import BytesIO
 
 
24
 
25
  # streamlit_config_dir = "/tmp/.streamlit"
26
  # st.sidebar.write("Streamlit config dir exists:", os.path.exists(streamlit_config_dir))
@@ -29,58 +31,76 @@ from io import BytesIO
29
  torch.classes.__path__ = []
30
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  @st.cache_resource
33
  def load_caption_model():
34
  # load medicap
35
  ckpt_name = 'aehrc/medicap'
36
- medicap = transformers.AutoModel.from_pretrained(ckpt_name, trust_remote_code=True)
 
 
 
 
 
37
  medicap = medicap.to(device)
38
  medicap.eval()
39
 
40
  # transform image
41
- image_processor = transformers.AutoFeatureExtractor.from_pretrained(ckpt_name)
42
- medicap_transforms = transforms.Compose(
43
- [
44
- transforms.Resize(size=image_processor.size['shortest_edge']),
45
- transforms.CenterCrop(size=[
46
- image_processor.size['shortest_edge'],
47
- image_processor.size['shortest_edge'],
48
- ]
49
- ),
50
- transforms.ToTensor(),
51
- transforms.Normalize(
52
- mean=image_processor.image_mean,
53
- std=image_processor.image_std,
54
- ),
55
- ]
56
- )
57
 
58
  # tokenizer
59
- medicap_tokenizer = transformers.AutoTokenizer.from_pretrained(ckpt_name)
60
 
61
  return medicap, medicap_transforms, medicap_tokenizer
62
 
63
  def generate_image_caption(image, model, transformer, tokenizer):
64
- image = transformer(image)
65
- image = image.unsqueeze(0)
66
  outputs = model.generate(
67
  pixel_values=image.to(device),
68
  bos_token_id=tokenizer.bos_token_id,
69
  eos_token_id=tokenizer.eos_token_id,
70
  pad_token_id=tokenizer.pad_token_id,
71
- return_dict_in_generate=True,
72
- use_cache=True,
73
- max_length=256,
74
  num_beams=4,
75
  output_attentions=False
76
  )
77
- return tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
78
 
79
  @st.cache_resource
80
  def load_qa_model():
81
  model_name = "microsoft/BioGPT-Large-PubMedQA"
82
- biogpt_tokenizer = AutoTokenizer.from_pretrained(model_name)
83
- biogpt = AutoModelForCausalLM.from_pretrained(model_name)
 
 
 
 
 
 
84
  biogpt = biogpt.to(device)
85
  biogpt.eval()
86
 
@@ -93,7 +113,7 @@ def generate_answer(description, question, model, tokenizer):
93
 
94
  generated_output = model.generate(
95
  input_ids,
96
- max_new_tokens=100, # Max new tokens for the bot's response
97
  )
98
 
99
  response = tokenizer.decode(generated_output[0], skip_special_tokens=True)
@@ -103,10 +123,16 @@ def generate_answer(description, question, model, tokenizer):
103
  st.set_page_config(page_title="Image Caption + QA", layout="centered")
104
  st.title("🖼️ Caption-Based Question Answering")
105
 
106
- uploaded_file = st.file_uploader("Choose Image", type = ["jpg", "jpeg", "png"])
107
- if uploaded_file is not None:
108
- img = Image.open(uploaded_file)
 
 
 
 
109
  st.image(img)
 
 
110
 
111
  # image description
112
  medicap, medicap_transforms, medicap_tokenizer = load_caption_model()
 
21
  from PIL import Image
22
  from torchvision import transforms
23
  from io import BytesIO
24
+ from pathlib import Path
25
+ import pandas as pd
26
 
27
  # streamlit_config_dir = "/tmp/.streamlit"
28
  # st.sidebar.write("Streamlit config dir exists:", os.path.exists(streamlit_config_dir))
 
31
# Keep Streamlit's module watcher from tripping over torch.classes.
torch.classes.__path__ = []
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Bundled sample of the MIMIC-CXR test split: each image path is paired with
# the radiologist's reference report for that study.
_sample_paths = [
    'test/s55512076.jpg',
    'test/s55786650.jpg',
    'test/s56188631.jpg',
    'test/s53690114.jpg',
    'test/s52070116.jpg',
]
_sample_reports = [
    'Comparison is made to prior study performed a day earlier. Lines and tubes are in unchanged standard position. Multifocal consolidations in the right upper and lower lobes bilaterally left greater than right are unchanged. Severe cardiomegaly is stable. There are no new lung abnormalities. Probably small right pleural effusion is unchanged.',
    'As compared to the previous radiograph, there is no relevant change. The monitoring and support devices are constant. Low lung volumes, borderline size of the cardiac silhouette. Mild pulmonary edema. Moderate retrocardiac atelectasis. No evidence of pneumonia.',
    'AP chest compared to ___ through ___. Elevation of the right lung base and hemidiaphragm has been pronounced since at least ___, accounting for atelectasis at the lung base. The right upper lung and the entire left lung are clear and the left lung is hyperinflated suggesting airway obstruction or emphysema. Heart is normal size. There is no pneumonia or pulmonary edema. No pleural effusion or pneumothorax.',
    'Compared to prior study there is no significant interval change.',
    'In comparison to prior radiograph of 1 day earlier, there has been improved aeration at both lung bases. No other relevant change since recent study.',
]
data = {"path": _sample_paths, "text": _sample_reports}

# prepare data
mimic_df_test = pd.DataFrame.from_dict(data)
52
+
53
def load_images(path):
    """Load the image at *path* and return it as an RGB PIL image.

    The context manager closes the underlying file handle: PIL opens
    lazily and would otherwise keep the file open after this returns.
    """
    with Image.open(path) as img:
        # .convert() forces the pixel data to load, so closing is safe.
        return img.convert('RGB')
57
+
58
@st.cache_resource
def load_caption_model():
    """Load the MediCap captioning model, feature extractor, and tokenizer.

    Prefers a local checkpoint in ``model2/`` so the app can run offline,
    falling back to the ``aehrc/medicap`` Hub checkpoint. Cached by
    Streamlit so it loads once per session.

    Returns:
        (model, feature_extractor, tokenizer) — model moved to ``device``
        and set to eval mode.
    """
    ckpt_name = 'aehrc/medicap'
    local_folder = "model2/"
    # Resolve ONE source for all three artifacts. Previously only the model
    # weights came from the local folder while the feature extractor and
    # tokenizer still fetched from the Hub, which broke fully-offline runs.
    # NOTE(review): assumes model2/ contains the extractor/tokenizer files
    # too (a full `save_pretrained` dump) — confirm.
    source = local_folder if os.path.exists(local_folder) else ckpt_name

    medicap = transformers.AutoModel.from_pretrained(source, trust_remote_code=True)
    medicap = medicap.to(device)
    medicap.eval()

    # Image preprocessing (resize/crop/normalize) expected by the encoder.
    medicap_transforms = transformers.AutoFeatureExtractor.from_pretrained(source)

    # GPT-2-style tokenizer used to decode generated caption ids.
    medicap_tokenizer = transformers.GPT2Tokenizer.from_pretrained(source)

    return medicap, medicap_transforms, medicap_tokenizer
78
 
79
def generate_image_caption(image, model, transformer, tokenizer):
    """Generate a caption for *image* with beam search and decode it.

    Args:
        image: input image (as returned by ``load_images``).
        transformer: the feature extractor used to produce pixel values.
        tokenizer: supplies the special-token ids and decodes the output.

    Returns:
        The generated caption string with special tokens stripped.
    """
    encoded = transformer(image, return_tensors="pt")
    pixel_values = encoded["pixel_values"].to(device)

    generated_ids = model.generate(
        pixel_values=pixel_values,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
        max_length=128,
        num_beams=4,
        output_attentions=False,
    )
    return tokenizer.decode(generated_ids[0], skip_special_tokens=True)
92
 
93
  @st.cache_resource
94
  def load_qa_model():
95
  model_name = "microsoft/BioGPT-Large-PubMedQA"
96
+
97
+ local_folder = "BioGPT-Large-PubMedQA/"
98
+ if os.path.exists(local_folder):
99
+ biogpt_tokenizer = AutoTokenizer.from_pretrained(local_folder)
100
+ biogpt = AutoModelForCausalLM.from_pretrained(local_folder)
101
+ else:
102
+ biogpt_tokenizer = AutoTokenizer.from_pretrained(model_name)
103
+ biogpt = AutoModelForCausalLM.from_pretrained(model_name)
104
  biogpt = biogpt.to(device)
105
  biogpt.eval()
106
 
 
113
 
114
  generated_output = model.generate(
115
  input_ids,
116
+ max_new_tokens=128, # Max new tokens for the bot's response
117
  )
118
 
119
  response = tokenizer.decode(generated_output[0], skip_special_tokens=True)
 
123
  st.set_page_config(page_title="Image Caption + QA", layout="centered")
124
  st.title("🖼️ Caption-Based Question Answering")
125
 
126
+ # Dropdown list
127
+ options = range(len(mimic_df_test))
128
+ choice = st.selectbox("Choose an action:", options)
129
+ if choice is not None:
130
+ data = mimic_df_test.iloc[choice]
131
+ label = data['text']
132
+ img = Image.open(Path(data['path']))
133
  st.image(img)
134
+ st.subheader("📝 Original Description")
135
+ st.info(label)
136
 
137
  # image description
138
  medicap, medicap_transforms, medicap_tokenizer = load_caption_model()