Sayed121 committed on
Commit
43ea0ee
·
1 Parent(s): 1d0bb0a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -125
app.py CHANGED
@@ -3,80 +3,49 @@
3
 
4
  # In[5]:
5
 
6
-
7
  import streamlit as st
8
  from PIL import Image
9
  import torch
10
- import requests
11
- from transformers import BlipProcessor, BlipForQuestionAnswering,BlipImageProcessor, AutoProcessor
12
- from transformers import BlipConfig
13
- from datasets import load_dataset
14
- from torch.utils.data import DataLoader
15
- from tqdm.notebook import tqdm
16
-
17
- import numpy as np
18
- import matplotlib.pyplot as plt
19
- from IPython.display import display
20
 
 
21
# Load the shared BLIP-VQA processors (text tokenizer + image preprocessor)
# from the pretrained base checkpoint, and the question-answering model from
# a local fine-tuned checkpoint directory.
text_processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
image_processor = BlipImageProcessor.from_pretrained("Salesforce/blip-vqa-base")
model = BlipForQuestionAnswering.from_pretrained(r"blip_model_v2_epo89" )
24
-
25
 
 
26
def preprocess_image(image):
    """Encode a PIL image into a single pixel-value tensor for the model.

    The image is resized to 128x128 and run through the shared BLIP image
    processor (also configured for 128x128); the leading batch dimension
    is stripped from the result.
    """
    resized = image.resize((128, 128))
    encoded = image_processor(
        resized,
        do_resize=True,
        size=(128, 128),
        return_tensors="pt",
    )
    return encoded["pixel_values"][0]
35
 
 
36
def preprocess_text(text, max_length=32):
    """Tokenize a question string into batch-free tensors.

    Passes ``None`` as the images argument so the BLIP processor only
    tokenizes the text, padded/truncated to ``max_length`` tokens; the
    leading batch dimension is then squeezed out of every tensor.
    """
    encoding = text_processor(
        None,
        text,
        padding="max_length",
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )
    # Squeeze each tensor in place so callers get rank-1 tensors.
    for key in list(encoding.keys()):
        encoding[key] = encoding[key].squeeze()
    return encoding
50
 
 
51
def predict(image, question):
    """Answer ``question`` about ``image`` with the BLIP VQA model.

    Preprocesses both inputs, generates answer token ids, and returns
    the decoded answer string.
    """
    model.eval()

    # Re-add the batch dimension expected by generate().
    pixel_values = preprocess_image(image).unsqueeze(0)
    text_encoding = preprocess_text(question)
    input_ids = text_encoding["input_ids"].unsqueeze(0)

    generated_ids = model.generate(pixel_values=pixel_values, input_ids=input_ids)

    prediction_result = text_processor.decode(generated_ids[0], skip_special_tokens=True)

    return prediction_result
70
 
 
71
  def main():
72
- # Set page title and configure page layout
73
  st.set_page_config(
74
  page_title="PathoAgent",
75
  page_icon=":microscope:",
76
  layout="wide"
77
  )
78
 
79
- # Add header with styled text
80
  st.title(":microscope: PathoAgent")
81
  st.markdown(
82
  """
@@ -98,11 +67,10 @@ def main():
98
  """,
99
  unsafe_allow_html=True
100
  )
101
-
102
  st.markdown("<div class='header'><h3 class='subheader'>Medical Image Analysis for Pathology</h3></div>", unsafe_allow_html=True)
103
  st.markdown("<hr style='border: 1px solid #ddd;'>", unsafe_allow_html=True)
104
 
105
- # Navigation bar
106
  nav_option = st.sidebar.radio("Navigation", ["Home", "Sample Images", "Upload Image"])
107
 
108
  if nav_option == "Home":
@@ -112,80 +80,7 @@ def main():
112
  elif nav_option == "Upload Image":
113
  upload_image()
114
 
115
def home():
    """Render the landing page: a welcome blurb and an About section."""
    welcome_text = (
        "PathoAgent is an AI-powered medical image analysis tool designed for pathology diagnostics. "
        "It empowers healthcare professionals with accurate predictions and insights from medical images. "
        "Choose an option from the sidebar to get started."
    )
    about_text = (
        "PathoAgent leverages advanced VQA algorithms to analyze medical images related to pathology. "
        "Whether you want to upload your own images or use our sample images, PathoAgent provides predictions for pathology-related questions. "
        "Explore the features and capabilities to enhance your diagnostic process."
    )
    st.header("Welcome to PathoAgent!")
    st.write(welcome_text)
    st.header("About PathoAgent")
    st.write(about_text)
129
-
130
def sample_images():
    """Render the sample-image page: show a bundled example image and run
    a VQA prediction on it for a user-supplied question.
    """
    st.header("Sample Images")

    # Mapping of display names to bundled image file paths.
    example_image = {
        "Sample 1": "img_0002.jpg",
    }

    if st.button("Load Example Images"):
        # BUG FIX: Image.open() was previously called with the whole dict
        # (Image.open(example_image)), which raises at runtime; open the
        # sample's file path instead.
        sample_image = Image.open(example_image["Sample 1"]).convert('RGB')
        st.image(sample_image, caption="Example Image", use_column_width=True)

        # Text input for the sample image.
        text_input = st.text_area("Input Question:")

        # NOTE(review): a button nested inside another button's branch only
        # survives a single Streamlit rerun; consider st.session_state for a
        # robust multi-step flow.
        if st.button("Predict"):
            if text_input:
                # Perform prediction.
                prediction_result = predict(sample_image, text_input)

                # Display input text.
                st.subheader("Input Question:")
                st.write(text_input)

                # Display prediction result.
                st.subheader("Prediction Result:")
                st.write(prediction_result)
160
-
161
def upload_image():
    """Render the upload page: accept a user image plus a question and
    display the model's prediction.
    """
    st.header("Upload Image")

    # Image upload widget (jpg/png/jpeg only).
    uploaded_file = st.file_uploader("Choose a file", type=["jpg", "png", "jpeg"])

    # Question input.
    st.subheader("Input Question")
    text_input = st.text_area("Enter text here:")

    # Echo the uploaded image back to the user.
    if uploaded_file is not None:
        image = Image.open(uploaded_file).convert('RGB')
        st.image(image, caption="Uploaded Image.", use_column_width=True)

    if st.button("Predict"):
        # Require both an image and a question before predicting.
        if uploaded_file is not None and text_input:
            prediction_result = predict(image, text_input)

            st.subheader("Input Question:")
            st.write(text_input)

            st.subheader("Prediction Result:")
            st.write(prediction_result)
189
 
190
# Entry point: launch the Streamlit app when the script is run directly.
if __name__ == "__main__":
    main()
 
3
 
4
  # In[5]:
5
 
 
6
  import streamlit as st
7
  from PIL import Image
8
  import torch
9
+ from transformers import BlipProcessor, BlipForQuestionAnswering, BlipImageProcessor
 
 
 
 
 
 
 
 
 
10
 
11
# Load model and processors
# NOTE(review): this version loads the base pretrained weights
# ("Salesforce/blip-vqa-base") where the previous version loaded a local
# fine-tuned checkpoint ("blip_model_v2_epo89") — confirm this model
# downgrade is intentional.
text_processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
image_processor = BlipImageProcessor.from_pretrained("Salesforce/blip-vqa-base")
model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")
 
15
 
16
# Function to preprocess image
def preprocess_image(image):
    """Resize a PIL image and encode it into a single pixel-value tensor.

    The result has the leading batch dimension removed.
    """
    # Resize to an appropriate working size before encoding.
    resized = image.resize((256, 256))
    encoded = image_processor(resized, return_tensors="pt")
    return encoded["pixel_values"][0]
22
 
23
# Function to preprocess text
def preprocess_text(text, max_length=32):
    """Tokenize a question into batch-free tensors.

    BUG FIX: ``text`` was previously passed as the first positional
    argument, which ``BlipProcessor.__call__`` interprets as ``images``;
    pass it via the ``text`` keyword so it is actually tokenized.
    """
    encoding = text_processor(
        text=text,
        padding="max_length",
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )
    # Strip the leading batch dimension from every tensor.
    for k, v in encoding.items():
        encoding[k] = v.squeeze()
    return encoding
29
 
30
# Function to make predictions
def predict(image, question):
    """Answer ``question`` about ``image`` and return the decoded answer.

    BUG FIX: the previous code invoked the model's forward pass
    (``model(...)``) and decoded ``outputs[0][0]``, which are logits, not
    token ids — decoding them fails or yields garbage. Use
    ``model.generate`` to produce answer token ids and decode those, as
    the original implementation did.
    """
    model.eval()
    with torch.no_grad():
        pixel_values = preprocess_image(image).unsqueeze(0)
        encoding = preprocess_text(question)
        outputs = model.generate(
            pixel_values=pixel_values,
            input_ids=encoding['input_ids'].unsqueeze(0),
        )

    prediction_result = text_processor.decode(outputs[0], skip_special_tokens=True)
    return prediction_result
40
 
41
+ # Streamlit app
42
  def main():
 
43
  st.set_page_config(
44
  page_title="PathoAgent",
45
  page_icon=":microscope:",
46
  layout="wide"
47
  )
48
 
 
49
  st.title(":microscope: PathoAgent")
50
  st.markdown(
51
  """
 
67
  """,
68
  unsafe_allow_html=True
69
  )
70
+
71
  st.markdown("<div class='header'><h3 class='subheader'>Medical Image Analysis for Pathology</h3></div>", unsafe_allow_html=True)
72
  st.markdown("<hr style='border: 1px solid #ddd;'>", unsafe_allow_html=True)
73
 
 
74
  nav_option = st.sidebar.radio("Navigation", ["Home", "Sample Images", "Upload Image"])
75
 
76
  if nav_option == "Home":
 
80
  elif nav_option == "Upload Image":
81
  upload_image()
82
 
83
+ # NOTE(review): home(), sample_images() and upload_image() were deleted in
+ # this commit, but main() still calls them from the navigation sidebar —
+ # restore those functions or every navigation option will raise NameError.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
 
85
# Entry point: launch the Streamlit app when the script is run directly.
# BUG FIX: the guard used single underscores (_name_ / _main_), which raises
# NameError at import time and means the app never starts; use the proper
# __name__ / "__main__" dunder spelling.
if __name__ == "__main__":
    main()