Sayed121 committed on
Commit
82cfbbf
·
1 Parent(s): 43ea0ee

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +125 -20
app.py CHANGED
@@ -3,49 +3,80 @@
3
 
4
  # In[5]:
5
 
 
6
  import streamlit as st
7
  from PIL import Image
8
  import torch
9
- from transformers import BlipProcessor, BlipForQuestionAnswering, BlipImageProcessor
 
 
 
 
 
 
 
 
 
10
 
11
- # Load model and processors
12
  text_processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
13
  image_processor = BlipImageProcessor.from_pretrained("Salesforce/blip-vqa-base")
14
- model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")
 
15
 
16
- # Function to preprocess image
17
  def preprocess_image(image):
18
- # Resize image to an appropriate size
19
- image = image.resize((256, 256))
20
- image_encoding = image_processor(image, return_tensors="pt")
 
 
 
 
21
  return image_encoding["pixel_values"][0]
22
 
23
- # Function to preprocess text
24
  def preprocess_text(text, max_length=32):
25
- encoding = text_processor(text, padding="max_length", truncation=True, max_length=max_length, return_tensors="pt")
 
 
 
 
 
 
 
 
 
26
  for k, v in encoding.items():
27
  encoding[k] = v.squeeze()
28
  return encoding
29
 
30
- # Function to make predictions
31
  def predict(image, question):
 
 
 
 
 
 
 
 
 
 
 
 
32
  model.eval()
33
- with torch.no_grad():
34
- pixel_values = preprocess_image(image).unsqueeze(0)
35
- encoding = preprocess_text(question)
36
- outputs = model(pixel_values=pixel_values, input_ids=encoding['input_ids'].unsqueeze(0))
37
 
38
- prediction_result = text_processor.decode(outputs[0][0], skip_special_tokens=True)
39
  return prediction_result
40
 
41
- # Streamlit app
42
  def main():
 
43
  st.set_page_config(
44
  page_title="PathoAgent",
45
  page_icon=":microscope:",
46
  layout="wide"
47
  )
48
 
 
49
  st.title(":microscope: PathoAgent")
50
  st.markdown(
51
  """
@@ -67,10 +98,11 @@ def main():
67
  """,
68
  unsafe_allow_html=True
69
  )
70
-
71
  st.markdown("<div class='header'><h3 class='subheader'>Medical Image Analysis for Pathology</h3></div>", unsafe_allow_html=True)
72
  st.markdown("<hr style='border: 1px solid #ddd;'>", unsafe_allow_html=True)
73
 
 
74
  nav_option = st.sidebar.radio("Navigation", ["Home", "Sample Images", "Upload Image"])
75
 
76
  if nav_option == "Home":
@@ -80,7 +112,80 @@ def main():
80
  elif nav_option == "Upload Image":
81
  upload_image()
82
 
83
- # Other functions remain unchanged...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
 
85
- if _name_ == "_main_":
86
- main()
 
3
 
4
  # In[5]:
5
 
6
+
7
  import streamlit as st
8
  from PIL import Image
9
  import torch
10
+ import requests
11
+ from transformers import BlipProcessor, BlipForQuestionAnswering,BlipImageProcessor, AutoProcessor
12
+ from transformers import BlipConfig
13
+ from datasets import load_dataset
14
+ from torch.utils.data import DataLoader
15
+ from tqdm.notebook import tqdm
16
+
17
+ import numpy as np
18
+ import matplotlib.pyplot as plt
19
+ from IPython.display import display
20
 
 
@st.cache_resource(show_spinner=False)
def _load_blip_components():
    """Load the BLIP processors and the fine-tuned VQA model once per process.

    Streamlit re-executes this script on every widget interaction; caching
    the loaded objects avoids re-reading the checkpoints on each rerun.
    ``show_spinner=False`` keeps this call from emitting any UI element
    before ``st.set_page_config`` runs in ``main``.

    Returns:
        A ``(text_processor, image_processor, model)`` tuple.
    """
    loaded_text_processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
    loaded_image_processor = BlipImageProcessor.from_pretrained("Salesforce/blip-vqa-base")
    # Local checkpoint directory, resolved relative to the working directory.
    loaded_model = BlipForQuestionAnswering.from_pretrained(r"blip_model_v2_epo89")
    return loaded_text_processor, loaded_image_processor, loaded_model


# Module-level names used by preprocess_image, preprocess_text and predict.
text_processor, image_processor, model = _load_blip_components()
24
+
25
 
 
def preprocess_image(image):
    """Turn an image into the pixel tensor fed to the model.

    The image is first resized to 128x128 via its ``resize`` method and then
    passed through the module-level ``image_processor`` with the same target
    size.

    Args:
        image: Object exposing ``resize((w, h))`` — callers in this file pass
            a PIL image converted to RGB.

    Returns:
        The first entry of the processor's ``"pixel_values"`` output
        (i.e. without the leading batch axis).
    """
    target_size = (128, 128)
    resized = image.resize(target_size)
    processed = image_processor(
        resized,
        do_resize=True,
        size=target_size,
        return_tensors="pt",
    )
    pixel_batch = processed["pixel_values"]
    return pixel_batch[0]
35
 
 
def preprocess_text(text, max_length=32):
    """Tokenize a question with the module-level ``text_processor``.

    Args:
        text: The question string.
        max_length: Length the token sequence is padded / truncated to.

    Returns:
        The processor's encoding object with every tensor passed through
        ``squeeze()`` (size-1 axes removed).
    """
    tokenizer_options = {
        "padding": "max_length",
        "truncation": True,
        "max_length": max_length,
        "return_tensors": "pt",
    }
    # First positional argument is the image slot; only text is encoded here.
    encoding = text_processor(None, text, **tokenizer_options)

    # Snapshot the keys first, then overwrite each tensor in place.
    for name in list(encoding.keys()):
        encoding[name] = encoding[name].squeeze()
    return encoding
50
 
 
def predict(image, question):
    """Answer a question about an image with the loaded BLIP VQA model.

    Args:
        image: Image accepted by ``preprocess_image`` (callers pass an RGB
            PIL image).
        question: Question text accepted by ``preprocess_text``.

    Returns:
        The decoded answer string with special tokens stripped.
    """
    # Both helpers return un-batched tensors; restore the batch axis here.
    pixel_values = preprocess_image(image).unsqueeze(0)
    encoding = preprocess_text(question)
    input_ids = encoding["input_ids"].unsqueeze(0)
    # The question is padded to a fixed length, so the mask is required to
    # stop the text encoder from attending to padding tokens.
    attention_mask = encoding["attention_mask"].unsqueeze(0)

    model.eval()
    outputs = model.generate(
        pixel_values=pixel_values,
        input_ids=input_ids,
        attention_mask=attention_mask,
    )

    # A single example was submitted, so decode the first generated sequence.
    prediction_result = text_processor.decode(outputs[0], skip_special_tokens=True)
    return prediction_result
70
 
 
71
  def main():
72
+ # Set page title and configure page layout
73
  st.set_page_config(
74
  page_title="PathoAgent",
75
  page_icon=":microscope:",
76
  layout="wide"
77
  )
78
 
79
+ # Add header with styled text
80
  st.title(":microscope: PathoAgent")
81
  st.markdown(
82
  """
 
98
  """,
99
  unsafe_allow_html=True
100
  )
101
+
102
  st.markdown("<div class='header'><h3 class='subheader'>Medical Image Analysis for Pathology</h3></div>", unsafe_allow_html=True)
103
  st.markdown("<hr style='border: 1px solid #ddd;'>", unsafe_allow_html=True)
104
 
105
+ # Navigation bar
106
  nav_option = st.sidebar.radio("Navigation", ["Home", "Sample Images", "Upload Image"])
107
 
108
  if nav_option == "Home":
 
112
  elif nav_option == "Upload Image":
113
  upload_image()
114
 
def home():
    """Render the landing page: a welcome section and an about section."""
    welcome_text = (
        "PathoAgent is an AI-powered medical image analysis tool designed for pathology diagnostics. "
        "It empowers healthcare professionals with accurate predictions and insights from medical images. "
        "Choose an option from the sidebar to get started."
    )
    about_text = (
        "PathoAgent leverages advanced VQA algorithms to analyze medical images related to pathology. "
        "Whether you want to upload your own images or use our sample images, PathoAgent provides predictions for pathology-related questions. "
        "Explore the features and capabilities to enhance your diagnostic process."
    )

    # Each section is a header followed by its paragraph, in display order.
    sections = (
        ("Welcome to PathoAgent!", welcome_text),
        ("About PathoAgent", about_text),
    )
    for heading, paragraph in sections:
        st.header(heading)
        st.write(paragraph)
129
+
def sample_images():
    """Render the sample-image page and answer questions about each sample.

    Pressing "Load Example Images" reveals every bundled sample together
    with a question box and a "Predict" button. The loaded state is kept in
    ``st.session_state`` because Streamlit reruns the script on every
    interaction and a button only reports ``True`` on the rerun triggered by
    its own click; without persisting it, pressing "Predict" would hide the
    samples again before any prediction could run.
    """
    st.header("Sample Images")

    # Label -> image path, relative to the working directory.
    example_images = {
        "Sample 1": "img_0002.jpg",
    }

    if st.button("Load Example Images"):
        st.session_state["sample_images_loaded"] = True
    if not st.session_state.get("sample_images_loaded", False):
        return

    for label, image_path in example_images.items():
        # Open the path itself (not the whole mapping) for this sample.
        sample_image = Image.open(image_path).convert('RGB')
        st.image(sample_image, caption="Example Image", use_column_width=True)

        # Explicit keys keep widgets distinct if more samples are added.
        text_input = st.text_area("Input Question:", key=f"sample_question_{label}")

        if st.button("Predict", key=f"sample_predict_{label}") and text_input:
            prediction_result = predict(sample_image, text_input)

            st.subheader("Input Question:")
            st.write(text_input)

            st.subheader("Prediction Result:")
            st.write(prediction_result)
160
+
def upload_image():
    """Render the upload page: pick an image, type a question, get an answer.

    Nothing is predicted unless the "Predict" button was pressed and both an
    uploaded file and a non-empty question are present.
    """
    st.header("Upload Image")

    upload = st.file_uploader("Choose a file", type=["jpg", "png", "jpeg"])

    st.subheader("Input Question")
    question = st.text_area("Enter text here:")

    # Show the picture as soon as a file is available.
    picture = None
    has_upload = upload is not None
    if has_upload:
        picture = Image.open(upload).convert('RGB')
        st.image(picture, caption="Uploaded Image.", use_column_width=True)

    predict_clicked = st.button("Predict")
    if not predict_clicked:
        return
    if not has_upload or not question:
        return

    answer = predict(picture, question)

    # Echo the question, then show the model's answer.
    st.subheader("Input Question:")
    st.write(question)

    st.subheader("Prediction Result:")
    st.write(answer)
189
 
# Script entry point: build the Streamlit page only when run directly.
if __name__ == "__main__":
    main()