import os

import groq
import streamlit as st
import torch
from groq import Groq
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification


# Browser-tab title/icon and a wide layout for the whole app. Placed before
# any other Streamlit output (older Streamlit releases require this to be
# the first st.* call in the script).
st.set_page_config(
    page_title="DermaBot - AI Skin Disease Detector",
    page_icon="🩺",
    layout="wide",
)


MODEL_NAME = "Jayanth2002/dinov2-base-finetuned-SkinDisease"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


@st.cache_resource(show_spinner="Loading skin-disease model...")
def _load_classifier(model_name, device):
    """Load the image classifier and its preprocessor, cached across reruns.

    Streamlit re-executes this entire script on every widget interaction.
    Without caching, the model weights would be loaded and moved to
    *device* again on every click. ``st.cache_resource`` keeps a single
    instance per distinct ``(model_name, device)`` argument pair.

    Args:
        model_name: Hugging Face model id to load.
        device: Torch device string the model is moved to ("cuda" or "cpu").

    Returns:
        A ``(model, processor)`` tuple.
    """
    classifier = AutoModelForImageClassification.from_pretrained(model_name).to(device)
    # from_pretrained normally returns the model in eval mode already; made
    # explicit because this model is only ever used for inference.
    classifier.eval()
    image_processor = AutoImageProcessor.from_pretrained(model_name)
    return classifier, image_processor


model, processor = _load_classifier(MODEL_NAME, DEVICE)


# The Groq API key is read from the environment only. Never commit a real
# key to source control; any key that has been committed should be treated
# as leaked and revoked in the Groq console.
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
if not GROQ_API_KEY:
    st.error(
        "GROQ_API_KEY is not set. Export it as an environment variable "
        "(or configure it as a secret of your deployment) and restart the app."
    )
    st.stop()  # halt this script run instead of failing later inside the Groq client
client = Groq(api_key=GROQ_API_KEY)


def predict_skin_disease(image):
    """Classify a skin image and return the top-1 predicted label.

    Args:
        image: A PIL image. It is converted to RGB first, so the processor
            always receives a three-channel image.

    Returns:
        The class name that ``model.config.id2label`` maps to the index of
        the highest logit.
    """
    rgb_image = image.convert("RGB")
    batch = processor(images=rgb_image, return_tensors="pt").to(DEVICE)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        logits = model(**batch).logits

    best_idx = logits.argmax(dim=-1).item()
    return model.config.id2label[best_idx]


def get_disease_info(disease_name):
    """Ask the Groq-hosted LLM for an explanation of a skin disease.

    Args:
        disease_name: Condition name, e.g. the label returned by
            ``predict_skin_disease``.

    Returns:
        The assistant's reply text, or an empty string if the response
        carried no text content.

    Errors raised by the Groq client are not caught here; they propagate
    to the caller.
    """
    prompt = (
        f"Provide a detailed explanation about the skin disease '{disease_name}', "
        "including a description of the disease, its causes, precautions, risks "
        "and treatment options."
    )

    chat_completion = client.chat.completions.create(
        messages=[{"role": "user", "content": prompt}],
        model="llama-3.3-70b-versatile",
    )

    # Normalise a missing content field to "" so callers always get a str.
    return chat_completion.choices[0].message.content or ""


def chatbot_response(disease_name, user_query):
    """Answer a free-form question about the detected skin disease.

    The disease name is prepended to the user's question so the LLM knows
    which condition is being discussed. Each call sends a single user
    message; earlier questions and answers are not included.

    Args:
        disease_name: Condition the question refers to.
        user_query: The user's question text.

    Returns:
        The content of the first choice returned by the Groq chat API.
    """
    conversation = [
        {"role": "user", "content": f"The detected skin disease is '{disease_name}'. {user_query}"},
    ]
    completion = client.chat.completions.create(
        model="llama-3.3-70b-versatile",
        messages=conversation,
    )
    first_choice = completion.choices[0]
    return first_choice.message.content


# Optional logo, shown only when DERMABOT_LOGO_URL points at a real image
# (URL or local path), so the page never renders a broken-image placeholder.
LOGO_URL = os.getenv("DERMABOT_LOGO_URL")
if LOGO_URL:
    st.image(LOGO_URL, width=200)
st.title("🩺 DermaBot - AI Skin Disease Detector")
st.write("Upload an image of a skin condition to get a diagnosis and ask questions about it.")
# The prediction comes from an image classifier and the explanations from an
# LLM; both can be wrong, so make that explicit to the user.
st.caption(
    "⚠️ DermaBot is an AI demo and can be wrong. Its output is not a medical "
    "diagnosis; consult a qualified dermatologist about any skin concern."
)


uploaded_image = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])

# Streamlit re-executes this whole script on every widget interaction, and
# st.button() returns True only during the run triggered by its own click.
# The diagnosis is therefore kept in st.session_state. A plain variable set
# under the "Detect Disease" button would not exist in the later run that a
# click on "Ask" triggers.
_upload_key = None if uploaded_image is None else (uploaded_image.name, uploaded_image.size)
if st.session_state.get("upload_key") != _upload_key:
    # The image was removed or replaced, so any stored diagnosis is stale.
    st.session_state.pop("disease_name", None)
    st.session_state.pop("disease_info", None)
    st.session_state["upload_key"] = _upload_key

if uploaded_image:
    try:
        image = Image.open(uploaded_image)
        image.load()  # decode now so a corrupt file fails here, not during inference
    except OSError:  # PIL.UnidentifiedImageError is a subclass of OSError
        st.error("The uploaded file could not be read as an image.")
    else:
        # NOTE: use_column_width is deprecated in recent Streamlit releases in
        # favour of use_container_width; kept for compatibility with older ones.
        st.image(image, caption="Uploaded Image", use_column_width=True)

        if st.button("Detect Disease"):
            with st.spinner("Analyzing..."):
                disease_name = predict_skin_disease(image)
                st.session_state["disease_name"] = disease_name
                try:
                    st.session_state["disease_info"] = get_disease_info(disease_name)
                except groq.APIError as exc:
                    # The classification is still useful without the LLM details.
                    st.session_state["disease_info"] = None
                    st.warning(f"Could not fetch details for this condition: {exc}")

        # Re-display the stored result on every run, not only on the click run.
        if "disease_name" in st.session_state:
            st.success(f"**Detected Disease:** {st.session_state['disease_name']}")
            if st.session_state.get("disease_info"):
                st.write(f"**Details:** {st.session_state['disease_info']}")

st.subheader("💬 Ask DermaBot")
user_query = st.text_input("Ask about the detected disease:")

if st.button("Ask"):
    if not uploaded_image:
        st.warning("Please upload an image first.")
    elif "disease_name" not in st.session_state:
        st.warning('Please click "Detect Disease" before asking a question.')
    elif not user_query.strip():
        st.warning("Please type a question first.")
    else:
        with st.spinner("Thinking..."):
            try:
                response = chatbot_response(st.session_state["disease_name"], user_query)
            except groq.APIError as exc:
                st.error(f"DermaBot could not answer right now: {exc}")
            else:
                st.write(response)


# Footer: a horizontal rule followed by the attribution line.
_FOOTER_TEXT = "🔍 Powered by **AI & Groq API** | © 2025 DermaBot"
st.markdown("---")
st.write(_FOOTER_TEXT)