Spaces:
Sleeping
Sleeping
import os

import streamlit as st
import torch
from groq import Groq, GroqError
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification
# Set page config. Kept as the first Streamlit call in the script, as older
# Streamlit versions require.
st.set_page_config(page_title="Detection - AI Skin Disease Detector", page_icon="🩺", layout="wide")

# Load model and processor
MODEL_NAME = "Jayanth2002/dinov2-base-finetuned-SkinDisease"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


@st.cache_resource(show_spinner="Loading skin-disease model...")
def _load_model_and_processor(model_name, device):
    """Load the image classifier and its processor once per server process.

    Streamlit re-executes this whole script on every widget interaction, so
    loading at module level would re-read the weights (and re-copy them to
    the GPU) on each click. ``st.cache_resource`` keeps one shared copy keyed
    on ``(model_name, device)``.

    Args:
        model_name: Hugging Face model id to load.
        device: Torch device string the model is moved to.

    Returns:
        A ``(model, processor)`` tuple.
    """
    loaded_model = AutoModelForImageClassification.from_pretrained(model_name).to(device)
    # Inference only: make sure dropout etc. are disabled.
    loaded_model.eval()
    loaded_processor = AutoImageProcessor.from_pretrained(model_name)
    return loaded_model, loaded_processor


model, processor = _load_model_and_processor(MODEL_NAME, DEVICE)
# Groq client used for the LLM-generated explanations and chat answers.
# The key comes from the GROQ_API_KEY environment variable (e.g. a Space
# secret); do not hard-code it in source.
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
if not GROQ_API_KEY:
    # Without a key, Groq() raises at import time and Streamlit would show a
    # raw traceback; stop with an actionable message instead.
    st.error("GROQ_API_KEY is not set. Add it as an environment variable (or Space secret) and restart the app.")
    st.stop()
client = Groq(api_key=GROQ_API_KEY)
# Seed the session-state keys the UI reads further down. This only happens
# on the first run of a session; later reruns keep whatever a previous
# detection stored.
for _state_key in ("disease_name", "disease_info"):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = None
def predict_skin_disease(image):
    """Classify a skin image and return the top-scoring label.

    The image is first converted to 3-channel RGB, because uploads may be
    RGBA, palette or greyscale. It is then preprocessed by the model's image
    processor and scored by the classifier without gradient tracking.

    Args:
        image: A PIL image.

    Returns:
        The ``model.config.id2label`` entry for the highest logit.
    """
    rgb_image = image.convert("RGB")
    batch = processor(images=rgb_image, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        scores = model(**batch).logits
    best_index = int(scores.argmax(dim=-1))
    return model.config.id2label[best_index]
def get_disease_info(disease_name):
    """Ask the Groq-hosted LLM for a write-up on ``disease_name``.

    The prompt requests a description of the condition plus its causes,
    precautions, risks and treatment options.

    Args:
        disease_name: Label produced by the image classifier.

    Returns:
        The text of the first completion choice, or ``""`` if the API
        returned no content.

    Raises:
        Any error raised by the Groq client (network, auth, rate limiting);
        it is not caught here.
    """
    prompt = (
        f"Provide a detailed explanation about the skin disease '{disease_name}', "
        "including a description of the disease, its causes, precautions, risks "
        "and treatment options."
    )
    chat_completion = client.chat.completions.create(
        messages=[{"role": "user", "content": prompt}],
        model="llama-3.3-70b-versatile",
    )
    # message.content may be None; normalize so callers always get a str.
    return chat_completion.choices[0].message.content or ""
def chatbot_response(disease_name, user_query):
    """Answer a free-form user question about the detected disease.

    Args:
        disease_name: The classifier's label, or a falsy value if no
            detection has been run yet.
        user_query: The question typed by the user (may be empty or None).

    Returns:
        A guidance message if no disease has been detected or the question is
        blank. Otherwise, the text of the first completion choice from the
        Groq LLM, or ``""`` if the API returned no content.

    Raises:
        Any error raised by the Groq client; it is not caught here.
    """
    if not disease_name:
        return "Please upload an image and detect the disease first."
    # Don't spend an API call on a blank question: the prompt would carry
    # nothing but the disease name.
    question = (user_query or "").strip()
    if not question:
        return "Please enter a question about the detected disease."
    prompt = f"The detected skin disease is '{disease_name}'. {question}"
    chat_completion = client.chat.completions.create(
        messages=[{"role": "user", "content": prompt}],
        model="llama-3.3-70b-versatile",
    )
    return chat_completion.choices[0].message.content or ""
# Streamlit UI
# NOTE(review): this logo URL looks like an unreplaced template placeholder
# ("your-huggingface-space"); point it at a real asset or drop the call.
st.image("https://huggingface.co/spaces/your-huggingface-space/logo.png", width=200)
st.title("🩺 DermaBot - AI Skin Disease Detector")
st.write("Upload an image of a skin condition to get a diagnosis and ask questions about it.")
st.caption(
    "⚠️ DermaBot is an AI tool for informational purposes only and is not a "
    "substitute for diagnosis by a qualified medical professional."
)

# Step 1: Upload image
uploaded_image = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])

# Streamlit reruns this script on every interaction, and session_state outlives
# the uploaded file. Without a reset, a previous diagnosis would stay on screen
# after the user swaps or removes the image. Remember which upload the stored
# results belong to and clear them when that changes.
if "results_upload" not in st.session_state:
    st.session_state.results_upload = None
current_upload = (uploaded_image.name, uploaded_image.size) if uploaded_image else None
if st.session_state.results_upload != current_upload:
    st.session_state.disease_name = None
    st.session_state.disease_info = None
    st.session_state.results_upload = current_upload

image = None
if uploaded_image:
    try:
        image = Image.open(uploaded_image)
    except OSError:
        # PIL raises UnidentifiedImageError (an OSError subclass) when the bytes
        # are not a decodable image, e.g. a renamed or corrupted file.
        st.error("Could not read the uploaded file as an image. Please upload a valid JPG or PNG.")

if image is not None:
    st.image(image, caption="Uploaded Image", use_container_width=True)

    # Step 2: Detect disease
    if st.button("Detect Disease"):
        with st.spinner("Analyzing..."):
            disease_name = predict_skin_disease(image)
            # Store the local prediction first so a failing LLM call below does
            # not throw away an otherwise successful classification.
            st.session_state.disease_name = disease_name
            try:
                st.session_state.disease_info = get_disease_info(disease_name)
            except GroqError as err:
                st.session_state.disease_info = None
                st.error(f"Detected the condition, but could not fetch details from Groq: {err}")

# Display detected disease information if available
if st.session_state.disease_name:
    st.success(f"**Detected Disease:** {st.session_state.disease_name}")
    if st.session_state.disease_info:
        st.write(f"**Details:** {st.session_state.disease_info}")

# Step 3: Chatbot
st.subheader("💬 Ask DermaBot about this disease")
user_query = st.text_input("Ask about the detected disease:")
if st.button("Ask"):
    with st.spinner("Thinking..."):
        try:
            response = chatbot_response(st.session_state.disease_name, user_query)
        except GroqError as err:
            response = None
            st.error(f"DermaBot could not answer right now: {err}")
    if response:
        st.write(response)

st.markdown("---")
st.write("🔍 Developed by **HS Tech PVT Ltd** | © 2025 Disease Detection")