# mediassist/app.py — Streamlit medical assistant app
# Origin: Hugging Face Space by iamak132003 (commit 7a54aca, "Updated app.py")
# unsloth should be imported before torch/transformers so it can patch them.
from unsloth import FastLanguageModel

import os
import re
import tempfile
import time

import google.generativeai as genai
import streamlit as st
import torch
from PIL import Image
from pydantic import BaseModel, Field, ValidationError
# --- Device detection ---------------------------------------------------
# Pick CUDA when available so the fine-tuned model runs on GPU; fall back
# to CPU (with a visible notice) otherwise.  NOTE: the original file used
# torch here without importing it — the import is added at the top of file.
if torch.cuda.is_available():
    device = torch.device("cuda")
    major_version, minor_version = torch.cuda.get_device_capability()
else:
    device = torch.device("cpu")
    major_version, minor_version = 0, 0
    st.write("CUDA not available. Using CPU.")

# --- Gemini API configuration ------------------------------------------
# The key comes from the environment.  genai.configure(api_key=None) only
# fails later at request time, so surface a missing key immediately.
api = os.getenv("GEMINI_API_KEY")
if not api:
    st.warning("GEMINI_API_KEY is not set; Gemini-based features will not work.")
genai.configure(api_key=api)
def custom_css():
    """Inject the app-wide CSS theme into the Streamlit page.

    Streamlit has no hook for styling arbitrary selectors, so the rules are
    injected as a raw ``<style>`` block via ``st.markdown`` with
    ``unsafe_allow_html=True``.  Targets both plain HTML elements and
    Streamlit-specific hooks (``.stButton``, ``[data-testid="stSidebar"]``…).
    """
    st.markdown(
        """
<style>
/* General Page Style */
body {
background-color: #f9fafc;
font-family: 'Roboto', sans-serif;
color: #444;
margin: 0;
padding: 0;
}
/* Title Styling */
.title {
font-size: 3em;
color: #2c3e50;
text-align: center;
font-weight: 700;
margin-bottom: 30px;
letter-spacing: 1px;
text-transform: uppercase;
}
/* Subtitle Styling */
h2 {
font-size: 1.8em;
color: #34495e;
margin-bottom: 15px;
}
/* Inputs */
input, textarea, select {
font-size: 1.1em;
padding: 10px;
border-radius: 8px;
border: 1px solid #ddd;
margin-bottom: 15px;
width: 100%;
box-sizing: border-box;
}
textarea {
min-height: 150px;
}
/* File Uploader Styling */
.stFileUploader label {
font-size: 1.2em;
font-weight: 500;
color: #444;
margin-bottom: 10px;
}
/* Sidebar Styling */
[data-testid="stSidebar"] {
background-color: #2c3e50;
color: #fff;
padding: 20px;
}
[data-testid="stSidebar"] h2 {
color: #ecf0f1;
}
/* Buttons */
.stButton button {
background-color: black;
color: white;
font-size: 1em;
font-weight: bold;
padding: 10px 20px;
border-radius: 6px;
border: none;
cursor: pointer;
transition: background 0.3s ease, border 0.3s ease, color 0.3s ease;
}
.stButton button:hover {
background-color: white;
border: 2px solid black;
color: black;
}
/* Radio Button Styling */
.stRadio label {
font-size: 1.2em;
font-weight: 500;
color: #555;
}
/* Footer */
footer {
font-size: 0.9em;
color: #aaa;
text-align: center;
padding: 20px 0;
margin-top: 30px;
}
/* Transitions for elements */
input, textarea, select, .stButton button {
transition: all 0.3s ease;
}
/* Hover and focus effects */
input:focus, textarea:focus, select:focus {
outline: none;
border-color: #6a82fb;
box-shadow: 0 0 5px rgba(106, 130, 251, 0.5);
}
</style>
""",
        unsafe_allow_html=True,
    )
class UserInput(BaseModel):
    """Validated patient intake for the AI disease diagnoser.

    The original model had no constraints, so the ``ValidationError`` handler
    in ``ai_symptoms_predictor`` could essentially never fire (and ``Field``
    was imported unused).  Bounds mirror the UI widget limits
    (``st.number_input`` with min 5 / max 100), and symptoms must be
    non-empty — diagnosing on no symptoms is meaningless.
    """

    age: int = Field(ge=5, le=100)       # matches the number_input bounds
    gender: str                          # free-form; UI offers Male/Female/Other
    symptoms: str = Field(min_length=1)  # reject an empty symptom description
@st.cache_resource
def load_model():
    """Fetch and prepare the fine-tuned diagnosis model.

    Cached by Streamlit so the (expensive) download and quantized load run
    once per server process, not once per rerun.

    Returns:
        tuple: ``(model, tokenizer)`` ready for inference.
    """
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name="iamak132003/disease_diagnosis",
        max_seq_length=2048,
        dtype=None,          # let unsloth auto-select the dtype
        load_in_4bit=True,   # 4-bit quantization to fit modest GPUs
    )
    FastLanguageModel.for_inference(model)
    return model, tokenizer


# Loaded eagerly at import time so every page handler shares the same pair.
model, tokenizer = load_model()
def save_temp_file(uploaded_file):
    """Persist a Streamlit upload to disk and return its filesystem path.

    The original wrote ``uploaded_file.name`` straight into the current
    working directory, which risks collisions between concurrent sessions
    and path traversal if the client-supplied name contains separators
    (``../``).  A ``tempfile`` with the original extension avoids both while
    keeping the same contract: the caller receives a real path and remains
    responsible for removing it (see ``delete_file``).

    Args:
        uploaded_file: Streamlit ``UploadedFile`` (has ``.name`` and ``.getvalue()``).

    Returns:
        str: path to the written temporary file.
    """
    # Keep only the extension of the client-supplied name; discard the rest.
    suffix = os.path.splitext(os.path.basename(uploaded_file.name))[1]
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as f:
        f.write(uploaded_file.getvalue())
        return f.name
def delete_file(file_name):
    """Remove *file_name* from disk, returning a human-readable status message.

    Never raises: any OS-level failure is folded into the returned string.
    """
    try:
        if not os.path.exists(file_name):
            return f"File '{file_name}' does not exist."
        os.remove(file_name)
        return f"File '{file_name}' has been deleted."
    except Exception as exc:
        return f"An error occurred while deleting the file: {exc}"
def home():
    """Landing page: app title plus one navigation button per feature.

    Clicking a button stores the page name in session state and reruns so
    ``main`` can dispatch to the matching handler.
    """
    st.markdown('<h1 class="title">Medical Assistant App</h1>', unsafe_allow_html=True)
    st.write("<h2>Your AI-powered health companion. Choose a feature to begin.</h2>", unsafe_allow_html=True)
    # Button labels double as the session-state page keys used by main().
    labels = ("Lab Report Summarizer", "Prescription Generator", "AI Disease Diagnoser")
    for column, label in zip(st.columns(3), labels):
        with column:
            if st.button(label):
                st.session_state.page = label
                st.rerun()
###########################################################
# GEMINI API #
###########################################################
def generate_summary_from_image(image, prompt):
    """Send *prompt* together with *image* to Gemini 1.5 Flash; return the text reply."""
    gemini = genai.GenerativeModel('models/gemini-1.5-flash')
    reply = gemini.generate_content([prompt, image])
    return reply.text
def lab_report_summarizer():
    """Page: upload a lab report image and summarize it with Gemini.

    Fixes over the original:
    - The uploader accepts PDFs but only image MIME types were ever handled,
      so a PDF upload silently did nothing; it now shows an explicit notice.
    - ``Image.open`` is lazy and keeps the file handle open, so deleting the
      temp file could fail (notably on Windows); ``img.load()`` reads the
      pixel data up front.
    """
    st.markdown('<h1 class="title">Lab Report Summarizer</h1>', unsafe_allow_html=True)
    uploaded_file = st.file_uploader("Upload a PDF/Image", type=["pdf", "jpg", "jpeg", "png"])
    if uploaded_file:
        file_path = save_temp_file(uploaded_file)
        if uploaded_file.type in ["image/jpeg", "image/png", "image/jpg"]:
            img = Image.open(file_path)
            img.load()  # materialize pixels so the temp file can be removed safely
            if st.button('Analyze'):
                st.markdown("#### Analyzing the uploaded image...")
                prompt = "Summarize the key findings in the uploaded lab report image."
                response = generate_summary_from_image(img, prompt)
                st.markdown(response)
        else:
            # Previously this case fell through with no feedback at all.
            st.info("PDF summarization is not supported yet; please upload an image.")
        delete_file(file_path)
    if st.sidebar.button("Home"):
        st.session_state.page = "Home"
        st.rerun()
def prescription_exporter():
    """Page: upload a prescription image and extract its details via Gemini.

    Any Gemini failure is reported inline with ``st.error``; the temp file is
    removed afterwards either way.
    """
    st.markdown('<h1 class="title">Prescription Analyzer</h1>', unsafe_allow_html=True)
    uploaded_file = st.file_uploader("Upload an Image", type=["jpg", "jpeg", "png"])
    if uploaded_file:
        file_path = save_temp_file(uploaded_file)
        img = Image.open(file_path)
        prompt = "Please extract and summarize the prescription details from the image provided."
        try:
            st.markdown(generate_summary_from_image(img, prompt))
        except Exception as exc:
            st.error(f"Error generating summary: {exc}")
        delete_file(file_path)
    if st.sidebar.button("Home"):
        st.session_state.page = "Home"
        st.rerun()
##############################################################
# FINE TUNED MODEL #
##############################################################
def ai_disease_diagnosis_loader():
    """Show a staged sequence of toast notifications before a diagnosis.

    Purely cosmetic: the delays are fixed (25s total) and do not track any
    real progress of the model call.
    """
    stages = (
        ('Gathering symptoms...', "🔍", 5),
        ('Analyzing data...', "🧠", 10),
        ('Diagnosing...', "⚡", 5),
        ('Finalizing results...', "🔬", 5),
    )
    msg = None
    for text, icon, delay in stages:
        if msg is None:
            msg = st.toast(text, icon=icon)   # first toast creates the handle
        else:
            msg.toast(text, icon=icon)        # later stages update it in place
        time.sleep(delay)
    msg.toast('Diagnosis ready!', icon="🩺")
def format_diagnosis_html(diagnosis):
    """Convert the model's markdown-ish diagnosis text to inline HTML.

    Internal helper for ``ai_symptoms_predictor``.  The original chained
    ``.replace("**", "<strong>").replace("**", "</strong>")`` — but the first
    replace consumes *every* ``**``, so no closing tag was ever emitted and
    the follow-up de-duplication was dead code.  A non-greedy regex pairs
    each ``**...**`` span correctly instead.
    """
    html = re.sub(r"\*\*(.+?)\*\*", r"<strong>\1</strong>", diagnosis)
    return html.replace("\n", "<br>")


def ai_symptoms_predictor():
    """Page: collect age/gender/symptoms and run the fine-tuned diagnosis model."""
    st.markdown('<h1 class="title">AI Disease Diagnoser</h1>', unsafe_allow_html=True)
    age = st.number_input("Enter Age:", min_value=5, max_value=100, step=1)
    gender = st.selectbox("Select Gender:", ["Male", "Female", "Other"])
    symptoms = st.text_area("Enter your symptoms")
    if st.button("Predict"):
        try:
            user_input = UserInput(age=age, gender=gender, symptoms=symptoms)
            input_data = f"Symptoms: {user_input.symptoms}\nGender: {user_input.gender}\nAge: {user_input.age}"
            # Alpaca-style prompt matching the model's fine-tuning format.
            disease_diagnosis_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Generate a response that appropriately completes the request.
### Instruction:
Provide a diagnosis and all the recommendations based on the patient's symptoms.
### Input:
{}
### Response:
{}
""".format(input_data, "")
            inputs = tokenizer([disease_diagnosis_prompt], return_tensors="pt")
            inputs = {key: value.to(model.device) for key, value in inputs.items()}
            ai_disease_diagnosis_loader()  # cosmetic progress toasts (fixed delays)
            outputs = model.generate(**inputs, max_new_tokens=800, use_cache=True)
            result = tokenizer.batch_decode(outputs)[0]
            response_only = (
                result.replace("<|begin_of_text|>", "")
                .replace("<|end_of_text|>", "")
                .strip()
            )
            if "### Response:" in response_only:
                # Everything after the marker is the model's answer.
                diagnosis = response_only.split("### Response:")[1].strip()
                st.markdown(
                    f"""
                    <div style="border: 2px solid #4CAF50; padding: 20px; border-radius: 15px; background-color: #f9f9f9;">
                    {format_diagnosis_html(diagnosis)}
                    </div>
                    """, unsafe_allow_html=True)
            else:
                st.markdown("### Unable to process the response.")
        except ValidationError as e:
            st.error(f"Validation Error: {e}")
    if st.sidebar.button("Home"):
        st.session_state.page = "Home"
        st.rerun()
def main():
    """Entry point: apply the theme, then route to the current page.

    The current page name lives in ``st.session_state.page`` (default
    ``"Home"``) and is set by the navigation buttons on each page.
    """
    custom_css()
    handlers = {
        "Home": home,
        "Lab Report Summarizer": lab_report_summarizer,
        "Prescription Generator": prescription_exporter,
        "AI Disease Diagnoser": ai_symptoms_predictor,
    }
    page = st.session_state.setdefault("page", "Home")
    handler = handlers.get(page)
    if handler is not None:
        handler()


if __name__ == "__main__":
    main()