File size: 5,987 Bytes
9d704c1
 
 
 
 
82cfbbf
9d704c1
 
 
9d4621a
82cfbbf
 
 
 
 
 
 
 
 
 
9d704c1
 
 
82cfbbf
 
9d704c1
 
82cfbbf
 
 
 
 
 
 
9d704c1
 
 
82cfbbf
 
 
 
 
 
 
 
 
 
9d704c1
 
 
 
 
82cfbbf
 
 
 
 
 
 
 
 
 
 
 
9d704c1
82cfbbf
 
 
9d704c1
 
 
 
82cfbbf
9d704c1
 
 
 
 
 
82cfbbf
9d704c1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82cfbbf
9d704c1
 
 
82cfbbf
9d704c1
 
 
 
 
 
 
 
 
82cfbbf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9d704c1
82cfbbf
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
#!/usr/bin/env python
# coding: utf-8

# In[5]:


import streamlit as st
from PIL import Image
import torch
import torch
import requests
from transformers import BlipProcessor, BlipForQuestionAnswering,BlipImageProcessor, AutoProcessor
from transformers import BlipConfig
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm

import numpy as np
import matplotlib.pyplot as plt
from IPython.display import display

# Module-level model assets, created once when this script is executed and
# shared by the preprocessing/prediction helpers below.
# - text_processor: used to encode questions and to decode generated answers.
# - image_processor: used to turn input images into pixel tensors.
# Both are loaded from the "Salesforce/blip-vqa-base" checkpoint.
text_processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
image_processor = BlipImageProcessor.from_pretrained("Salesforce/blip-vqa-base")
# The question-answering model itself is loaded from "blip_model_v2_epo89".
# NOTE(review): presumably a local fine-tuned checkpoint directory that must
# sit next to this script -- confirm the deployment layout.
model = BlipForQuestionAnswering.from_pretrained(r"blip_model_v2_epo89" )


def preprocess_image(image):
    """Convert an input image into the pixel tensor passed to the model.

    The image is resized to 128x128 via its own ``resize`` method and then
    run through the module-level ``image_processor`` (which is also asked to
    resize to 128x128) with PyTorch tensors requested as output.

    Args:
        image: Image object exposing ``resize``; callers in this file pass
            PIL images converted to RGB.

    Returns:
        The first element of the processor's ``"pixel_values"`` output, i.e.
        the tensor for this single image without the leading batch entry.
    """
    target_size = (128, 128)
    resized = image.resize(target_size)
    processed = image_processor(
        resized,
        do_resize=True,
        size=target_size,
        return_tensors="pt",
    )
    pixel_values = processed["pixel_values"]
    return pixel_values[0]

def preprocess_text(text, max_length=32):
    """Tokenize a question with the module-level ``text_processor``.

    The text is padded/truncated to ``max_length`` tokens and returned as
    PyTorch tensors. Every tensor in the resulting encoding is then squeezed
    in place, which drops all size-1 dimensions.

    Args:
        text: The question to encode.
        max_length: Padding/truncation length handed to the processor.

    Returns:
        The processor's encoding object with each value replaced by its
        squeezed tensor.
    """
    tokenizer_options = {
        "padding": "max_length",
        "truncation": True,
        "max_length": max_length,
        "return_tensors": "pt",
    }
    # No image is supplied here (first positional argument is None); only the
    # text is encoded.
    batch = text_processor(None, text, **tokenizer_options)

    # Snapshot the keys first, then overwrite each entry with its squeezed
    # tensor so the same encoding object is returned to the caller.
    for field in list(batch.keys()):
        batch[field] = batch[field].squeeze()
    return batch

def predict(image, question):
    """Answer ``question`` about ``image`` using the module-level model.

    Args:
        image: Image accepted by ``preprocess_image``.
        question: Question text accepted by ``preprocess_text``.

    Returns:
        The decoded answer string (special tokens skipped) for the first
        generated sequence.
    """
    # The preprocessing helpers return un-batched tensors, so a leading batch
    # dimension is added back before calling the model.
    image_batch = preprocess_image(image).unsqueeze(0)
    question_ids = preprocess_text(question)["input_ids"].unsqueeze(0)

    # NOTE(review): only input_ids are forwarded; the attention mask produced
    # for the max-length padding is not passed to generate() -- confirm this
    # matches how the checkpoint was trained.
    model.eval()
    generated = model.generate(pixel_values=image_batch, input_ids=question_ids)

    answer_ids = generated[0]
    return text_processor.decode(answer_ids, skip_special_tokens=True)

def main():
    """Render the page chrome and dispatch to the page picked in the sidebar.

    Configures the Streamlit page, draws the title, injected CSS, banner and
    divider, then shows a sidebar radio and calls the handler matching the
    selected option (``home``, ``sample_images`` or ``upload_image``).
    """
    # Page configuration is issued before any other Streamlit output.
    st.set_page_config(page_title="PathoAgent", page_icon=":microscope:", layout="wide")

    st.title(":microscope: PathoAgent")

    # Raw CSS injected as HTML for the banner styling.
    page_css = """
        <style>
            body {
                background-color: #f1f1f1;
            }
            .header {
                text-align: center;
                padding: 20px;
                background-color: #3498db;
            }
            .subheader {
                color: #fff;
                text-align: center;
                padding-bottom: 20px;
            }
        </style>
        """
    st.markdown(page_css, unsafe_allow_html=True)

    banner_html = "<div class='header'><h3 class='subheader'>Medical Image Analysis for Pathology</h3></div>"
    divider_html = "<hr style='border: 1px solid #ddd;'>"
    st.markdown(banner_html, unsafe_allow_html=True)
    st.markdown(divider_html, unsafe_allow_html=True)

    # Sidebar navigation: option label -> function that renders that page.
    page_handlers = {
        "Home": home,
        "Sample Images": sample_images,
        "Upload Image": upload_image,
    }
    nav_option = st.sidebar.radio("Navigation", list(page_handlers))

    render_page = page_handlers.get(nav_option)
    if render_page is not None:
        render_page()

def home():
    """Render the landing page: a welcome section and an "about" section."""
    welcome_text = (
        "PathoAgent is an AI-powered medical image analysis tool designed for pathology diagnostics. "
        "It empowers healthcare professionals with accurate predictions and insights from medical images. "
        "Choose an option from the sidebar to get started."
    )
    about_text = (
        "PathoAgent leverages advanced VQA algorithms to analyze medical images related to pathology. "
        "Whether you want to upload your own images or use our sample images, PathoAgent provides predictions for pathology-related questions. "
        "Explore the features and capabilities to enhance your diagnostic process."
    )

    # Each section is a header followed by its paragraph, in display order.
    sections = (
        ("Welcome to PathoAgent!", welcome_text),
        ("About PathoAgent", about_text),
    )
    for heading, paragraph in sections:
        st.header(heading)
        st.write(paragraph)

def sample_images():
    """Render the sample-image page and answer questions about each sample.

    After the user presses "Load Example Images", every image listed in the
    sample mapping is shown together with a question box and a "Predict"
    button. Pressing "Predict" with a non-empty question runs ``predict`` on
    that image and displays the question and the answer.

    The "loaded" state is stored in ``st.session_state`` so that it survives
    the script rerun triggered by the "Predict" button (and therefore also
    persists when the user navigates away and back within the same session).
    """
    st.header("Sample Images")

    # Display label -> image file path.
    example_image = {
        "Sample 1": "img_0002.jpg",
    }

    # st.button() is True only during the single rerun caused by its click.
    # Remember the click, otherwise pressing "Predict" (which reruns the
    # script) would make this branch False again and the result would never
    # be rendered.
    if st.button("Load Example Images"):
        st.session_state["sample_images_loaded"] = True
    if not st.session_state.get("sample_images_loaded", False):
        return

    for label, image_path in example_image.items():
        # Open the individual file path; Image.open cannot take the mapping.
        sample_image = Image.open(image_path).convert('RGB')
        st.image(sample_image, caption="Example Image", use_column_width=True)

        # Per-image widget keys keep the widgets distinct if more samples
        # are added to the mapping above.
        text_input = st.text_area("Input Question:", key=f"sample_question_{label}")

        if st.button("Predict", key=f"sample_predict_{label}"):
            if text_input:
                # Perform prediction
                prediction_result = predict(sample_image, text_input)

                # Display input text
                st.subheader("Input Question:")
                st.write(text_input)

                # Display prediction result
                st.subheader("Prediction Result:")
                st.write(prediction_result)

def upload_image():
    """Render the upload page: pick an image, type a question, get an answer.

    Shows a file uploader (jpg/png/jpeg), a question text area, a preview of
    the uploaded image when one is present, and a "Predict" button. A
    prediction is made only when the button is pressed while both an uploaded
    file and a non-empty question are available; otherwise nothing further is
    displayed.
    """
    st.header("Upload Image")

    uploaded_file = st.file_uploader("Choose a file", type=["jpg", "png", "jpeg"])

    st.subheader("Input Question")
    text_input = st.text_area("Enter text here:")

    # Preview the upload (converted to RGB) when a file has been chosen.
    image = None
    if uploaded_file is not None:
        image = Image.open(uploaded_file).convert('RGB')
        st.image(image, caption="Uploaded Image.", use_column_width=True)

    # The button is always rendered; the remaining work is skipped unless it
    # was clicked and both inputs are present.
    predict_clicked = st.button("Predict")
    if not predict_clicked:
        return
    if uploaded_file is None or not text_input:
        return

    prediction_result = predict(image, text_input)

    st.subheader("Input Question:")
    st.write(text_input)

    st.subheader("Prediction Result:")
    st.write(prediction_result)

# Script entry point: build the UI only when this file is executed as the
# main module, not when it is imported.
if __name__ == "__main__":
    main()