# Spaces: Sleeping  (status text from the hosting page, not part of the app)
| import streamlit as st | |
| from transformers import pipeline | |
| from PIL import Image | |
| import numpy as np | |
| import cv2 | |
| from rag_model import * | |
| from yolo_model import * | |
@st.cache_resource
def load_image_model():
    """Create the wound-image classification pipeline.

    Decorated with ``st.cache_resource`` so the model is built once and reused
    across Streamlit reruns instead of being re-instantiated every time the
    script re-executes (every widget interaction / chat message).

    Returns:
        A transformers ``image-classification`` pipeline for the
        ``Heem2/wound-image-classification`` model.
    """
    # Import under an alias: the module-level name ``pipeline`` is rebound
    # below to the classifier instance, shadowing the transformers factory.
    from transformers import pipeline as hf_pipeline

    return hf_pipeline("image-classification", model="Heem2/wound-image-classification")


# NOTE: the name ``pipeline`` is kept because the rest of the script calls
# ``pipeline(image)``; it intentionally shadows ``transformers.pipeline``.
pipeline = load_image_model()
# Wrap the project's YOLO loader so the detector is also created only once
# rather than on every rerun.
yolo_model = st.cache_resource(load_yolo_model)()
# Add custom CSS.
# The stylesheet is injected as raw HTML through st.markdown(unsafe_allow_html=True).
# It styles buttons (.stButton), the app background gradient (.stApp) and the
# file-uploader widget ([data-testid='stFileUploader']).
# NOTE(review): these selectors target Streamlit's generated DOM / test ids,
# which are not a stable public API — re-check them after Streamlit upgrades.
# The `section > input + div { display: none; }` rule hides a child of the
# uploader section, presumably the drag-and-drop hint text — confirm visually.
css = """
<style>
body {
font-family: 'Arial', sans-serif;
background-color: #f5f5f5;
}
.main {
background-color: #ffffff;
padding: 20px;
border-radius: 10px;
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
}
.stButton button {
background-color: #4CAF50;
color: white;
border: none;
padding: 10px 20px;
text-align: center;
text-decoration: none;
display: inline-block;
font-size: 16px;
margin: 4px 2px;
cursor: pointer;
border-radius: 5px;
}
.stButton button:hover {
background-color: #45a049;
}
.stApp > header {
background-color: transparent;
}
.stApp {
margin: auto;
background-color: #D9AFD9;
background-image: linear-gradient(0deg, #D9AFD9 0%, #97D9E1 100%);
}
[data-testid='stFileUploader'] {
width: max-content;
}
[data-testid='stFileUploader'] section {
padding: 0;
float: left;
}
[data-testid='stFileUploader'] section > input + div {
display: none;
}
[data-testid='stFileUploader'] section + div {
float: right;
padding-top: 0;
}
</style>
"""
st.markdown(css, unsafe_allow_html=True)
st.title("**FirstAid-AI**")

# Landing-page introduction, rendered as Markdown above the uploader.
welcome_markdown = """
### Welcome to FirstAid-AI
This application provides medical advice based on images of wounds and medical equipment.
Upload an image of your wound or medical equipment, and the AI will classify the image and provide relevant advice.
"""
st.markdown(welcome_markdown)

# "How to use" guide: two numbered steps, with an example screenshot and
# caption shown between step 1 and step 2.
step_upload, step_chat = (
    "### 1. Upload an image of a wound and a piece of equipment (if applicable)",
    "### 2. Ask follow-up questions and continue the conversation with the AI assistant!",
)
example_caption = (
    "The AI model will detect the wound or equipment in the image and provide "
    "confidence levels. The AI assistant will then provide treatment or usage advice."
)
st.markdown("## How to Use FirstAid-AI")
st.markdown(step_upload)
st.image("images/example3.png", use_container_width=True)
st.caption(example_caption)
st.markdown(step_chat)
def _upload_key(uploaded_file):
    """Return a hashable identity for an uploaded file, or None if nothing is uploaded.

    Uses the file's name and size; two different files that share both are
    treated as the same upload.
    """
    if uploaded_file is None:
        return None
    return (uploaded_file.name, uploaded_file.size)


def _show_wound_predictions(uploaded_file):
    """Show the wound image beside the classifier scores and return the top label."""
    col_image, col_scores = st.columns(2)
    image = Image.open(uploaded_file)
    col_image.image(image, use_container_width=True)
    # Normalise to 3-channel RGB so RGBA / greyscale / palette uploads can be
    # classified too.
    predictions = pipeline(image.convert("RGB"))
    col_scores.header("Detected Wound")
    for p in predictions:
        col_scores.subheader(f"{p['label']}: {round(p['score'] * 100, 1)}%")
    # Uses the first prediction as the detected wound (presumably the pipeline
    # returns results ordered by score, highest first).
    return predictions[0]['label']


def _show_equipment_detection(uploaded_file):
    """Show the equipment image beside the YOLO detection result and return that result."""
    col_image, col_result = st.columns(2)
    image = Image.open(uploaded_file)
    col_image.image(image, use_container_width=True)
    # Convert to RGB first (2-D greyscale/palette arrays would make cvtColor
    # fail), then to the BGR numpy layout passed to the YOLO helper.
    image_bgr = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)
    detected_equipment = get_detected_objects(yolo_model, image_bgr)
    col_result.header("Detected Equipment")
    col_result.subheader(detected_equipment)
    return detected_equipment


def _render_chat(initial_query):
    """Seed, replay and extend the RAG chat for the current uploads.

    If the history is empty, ``initial_query`` is sent to ``rag_chain`` and the
    answer becomes the first assistant message. The full history is then
    rendered, and a submitted follow-up question is answered and appended.
    """
    if not st.session_state.messages:
        initial_response = rag_chain.invoke(initial_query)
        st.session_state.messages.append({"role": "assistant", "content": initial_response})

    # Display chat messages from history.
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    prompt = st.chat_input("Ask a follow-up question or continue the conversation:")
    if not prompt:
        return

    with st.chat_message("user"):
        st.markdown(prompt)
    st.session_state.messages.append({"role": "user", "content": prompt})

    # The history already ends with the new prompt; the query repeats it
    # explicitly so the model focuses on the latest question.
    conversation_history = "\n".join(
        f"{message['role']}: {message['content']}" for message in st.session_state.messages
    )
    query = f"Context:\n{conversation_history}\n\nAssistant, respond to the user's latest query: {prompt}"
    response = rag_chain.invoke(query)

    with st.chat_message("assistant"):
        st.markdown(response)
    st.session_state.messages.append({"role": "assistant", "content": response})


# Initialize chat history.
if "messages" not in st.session_state:
    st.session_state.messages = []

# Dropdown to select the type of images to provide.
option = st.selectbox(
    "Select the type of images you want to provide:",
    ("Provide just wound image", "Provide both wound and equipment"),
)

# Upload images based on the selected option.
file_wound = None
file_equipment = None
if option == "Provide just wound image":
    file_wound = st.file_uploader("Upload an image of your wound")
elif option == "Provide both wound and equipment":
    file_wound = st.file_uploader("Upload an image of your wound")
    file_equipment = st.file_uploader("Upload an image of your equipment")

# Start a fresh conversation whenever the mode or any uploaded image changes
# (including everything being removed). Otherwise advice generated for a
# previous image or mode would be reused and the new initial advice skipped:
# the wound uploader keeps its file when switching modes.
upload_context = (option, _upload_key(file_wound), _upload_key(file_equipment))
if st.session_state.get("upload_context") != upload_context:
    st.session_state.messages = []
    st.session_state.upload_context = upload_context

if file_wound is not None and option == "Provide just wound image":
    detected_wound = _show_wound_predictions(file_wound)
    _render_chat(f"Provide treatment advice for a {detected_wound} wound")
elif (
    file_wound is not None
    and file_equipment is not None
    and option == "Provide both wound and equipment"
):
    detected_wound = _show_wound_predictions(file_wound)
    detected_equipment = _show_equipment_detection(file_equipment)
    _render_chat(
        f"Provide usage advice for {detected_equipment} when treating a {detected_wound} wound"
    )