"""DermaBot: a Streamlit app that classifies a skin-condition photo and chats about it.

An image classifier from the Hugging Face Hub predicts a label for the uploaded
image; a Groq-hosted LLM then explains the predicted condition and answers
follow-up questions about it.

Configuration:
    GROQ_API_KEY  (environment variable, required) - Groq API key.
"""

import os

import streamlit as st
import torch
from groq import APIError, Groq
from PIL import Image, UnidentifiedImageError
from transformers import AutoImageProcessor, AutoModelForImageClassification

# st.set_page_config must be the first Streamlit call executed by the script.
st.set_page_config(
    page_title="DermaBot - AI Skin Disease Detector",
    page_icon="🩺",
    layout="wide",
)

MODEL_NAME = "Jayanth2002/dinov2-base-finetuned-SkinDisease"
LLM_MODEL = "llama-3.3-70b-versatile"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


@st.cache_resource
def _load_classifier(model_name: str, device: str):
    """Load the image classifier and its processor once per server process.

    Streamlit re-executes this whole script on every widget interaction;
    without caching, the model would be re-loaded and moved to *device*
    on every click.
    """
    loaded_model = AutoModelForImageClassification.from_pretrained(model_name).to(device)
    loaded_processor = AutoImageProcessor.from_pretrained(model_name)
    return loaded_model, loaded_processor


@st.cache_resource
def _get_groq_client(api_key: str) -> Groq:
    """Create one Groq client per API key and reuse it across reruns."""
    return Groq(api_key=api_key)


model, processor = _load_classifier(MODEL_NAME, DEVICE)

# Credentials must come from the environment, never from source code.
# NOTE: the key previously hard-coded here is exposed and should be revoked.
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
if not GROQ_API_KEY:
    st.error("GROQ_API_KEY is not set. Export it as an environment variable and restart the app.")
    st.stop()
client = _get_groq_client(GROQ_API_KEY)

# Results are kept in session state so they survive the reruns triggered by
# other widgets (e.g. clicking "Ask" after "Detect Disease").
for _state_key in ("disease_name", "disease_info", "image_key"):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = None


def predict_skin_disease(image: Image.Image) -> str:
    """Classify *image* and return the predicted label.

    Args:
        image: A PIL image. It is converted to RGB first so that RGBA,
            palette or grayscale uploads are fed to the processor as RGB.

    Returns:
        ``model.config.id2label`` entry for the highest-scoring logit.
    """
    image = image.convert("RGB")
    inputs = processor(images=image, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        logits = model(**inputs).logits
    predicted_class_idx = logits.argmax(-1).item()
    return model.config.id2label[predicted_class_idx]


def _ask_llm(prompt: str) -> str:
    """Send a single-turn *prompt* to the Groq LLM and return the reply text.

    Raises:
        groq.APIError: if the request fails (network, auth, rate limit, ...).
    """
    chat_completion = client.chat.completions.create(
        messages=[{"role": "user", "content": prompt}],
        model=LLM_MODEL,
    )
    return chat_completion.choices[0].message.content


def get_disease_info(disease_name: str) -> str:
    """Ask the LLM for causes, symptoms and treatment of *disease_name*.

    Raises:
        groq.APIError: if the Groq request fails.
    """
    prompt = (
        f"Provide a detailed explanation about the skin disease '{disease_name}', "
        "including causes, symptoms, and treatment options."
    )
    return _ask_llm(prompt)


def chatbot_response(disease_name, user_query) -> str:
    """Answer *user_query* in the context of the detected *disease_name*.

    Returns a guidance message instead of calling the LLM when no disease
    has been detected yet or the question is blank.

    Raises:
        groq.APIError: if the Groq request fails.
    """
    if not disease_name:
        return "Please upload an image and detect the disease first."
    if not user_query or not user_query.strip():
        return "Please type a question about the detected disease first."
    prompt = f"The detected skin disease is '{disease_name}'.\n{user_query}"
    return _ask_llm(prompt)


# ---------------------------------------------------------------- Streamlit UI
# TODO: replace this placeholder URL with the real logo location.
st.image("https://huggingface.co/spaces/your-huggingface-space/logo.png", width=200)
st.title("🩺 DermaBot - AI Skin Disease Detector")
st.write("Upload an image of a skin condition to get a diagnosis and ask questions about it.")
st.warning(
    "DermaBot is an AI demo, not a medical device. Its predictions can be wrong; "
    "please consult a qualified dermatologist for diagnosis and treatment."
)

# Upload image section
uploaded_image = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])

# Forget any previous result when the upload changes (new file or cleared),
# so the chatbot never answers about a disease detected in an earlier image.
image_key = (uploaded_image.name, uploaded_image.size) if uploaded_image else None
if st.session_state.image_key != image_key:
    st.session_state.image_key = image_key
    st.session_state.disease_name = None
    st.session_state.disease_info = None

image = None
if uploaded_image:
    try:
        image = Image.open(uploaded_image)
    except UnidentifiedImageError:
        st.error("Could not read this file as an image. Please upload a valid JPG or PNG.")

if image is not None:
    st.image(image, caption="Uploaded Image", use_container_width=True)

    if st.button("Detect Disease"):
        with st.spinner("Analyzing..."):
            detected = predict_skin_disease(image)
            st.session_state.disease_name = detected
            try:
                st.session_state.disease_info = get_disease_info(detected)
            except APIError as err:
                st.session_state.disease_info = None
                st.error(f"Could not fetch disease details from Groq: {err}")

    # Render from session state so the result stays visible on later reruns.
    if st.session_state.disease_name:
        st.success(f"**Detected Disease:** {st.session_state.disease_name}")
        if st.session_state.disease_info:
            st.write(f"**Details:** {st.session_state.disease_info}")

# Chatbot section
st.subheader("💬 Ask DermaBot")
user_query = st.text_input("Ask about the detected disease:")
if st.button("Ask"):
    response = None
    with st.spinner("Thinking..."):
        try:
            response = chatbot_response(st.session_state.disease_name, user_query)
        except APIError as err:
            st.error(f"Could not get an answer from Groq: {err}")
    if response:
        st.write(response)

st.markdown("---")
st.write("🔍 Powered by **AI & Groq API** | © 2025 DermaBot")