DermaBot / app.py
amasood's picture
Update app.py
afe00a8 verified
raw
history blame
3.51 kB
import os

import streamlit as st
import torch
from groq import APIError, Groq
from PIL import Image, UnidentifiedImageError
from transformers import AutoImageProcessor, AutoModelForImageClassification
# Page configuration: browser-tab title/icon and a full-width layout.
st.set_page_config(
    page_title="DermaBot - AI Skin Disease Detector",
    page_icon="🩺",
    layout="wide",
)
# Image-classification model (Hugging Face Hub id) and the device it runs on.
MODEL_NAME = "Jayanth2002/dinov2-base-finetuned-SkinDisease"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


@st.cache_resource
def _load_classifier(model_name, device):
    """Load the classifier and its image processor once and reuse them.

    Streamlit re-executes the whole script on every widget interaction; without
    caching, the model weights would be reloaded (and moved to ``device``) on
    each rerun.

    Returns:
        A ``(model, processor)`` tuple.
    """
    loaded_model = AutoModelForImageClassification.from_pretrained(model_name).to(device)
    loaded_model.eval()  # inference only: make sure training-only behavior is off
    loaded_processor = AutoImageProcessor.from_pretrained(model_name)
    return loaded_model, loaded_processor


# Module-level names kept so the rest of the script can use them unchanged.
model, processor = _load_classifier(MODEL_NAME, DEVICE)
# Groq API client. The key is read ONLY from the environment (e.g. a Hugging
# Face Space secret) and is never hard-coded in source.
# NOTE(review): a real key was previously committed here as a fallback default;
# it is exposed in the repository history and must be revoked/rotated.
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
if not GROQ_API_KEY:
    st.error(
        "GROQ_API_KEY is not set. Add it as an environment variable "
        "(or Space secret) and restart the app."
    )
    st.stop()  # nothing below can work without an API client
client = Groq(api_key=GROQ_API_KEY)
# Remember the most recent prediction across Streamlit reruns; only create the
# key on the first run so an existing value is never overwritten.
st.session_state.setdefault("disease_name", None)
def predict_skin_disease(image):
    """Classify a skin-condition image and return the top predicted label.

    Args:
        image: A PIL image; it is converted to RGB before preprocessing.

    Returns:
        The label string that ``model.config.id2label`` maps the highest-scoring
        class index to.
    """
    rgb_image = image.convert("RGB")
    batch = processor(images=rgb_image, return_tensors="pt").to(DEVICE)
    # Pure inference: no gradients needed.
    with torch.no_grad():
        scores = model(**batch).logits
    best_idx = scores.argmax(-1).item()
    return model.config.id2label[best_idx]
# Groq chat model used for both the disease summary and follow-up questions.
_GROQ_MODEL = "llama-3.3-70b-versatile"


def _ask_groq(prompt):
    """Send *prompt* as a single user message to Groq and return the reply text.

    Errors raised by the Groq client (network/API failures) propagate to the
    caller, which decides how to surface them in the UI.
    """
    chat_completion = client.chat.completions.create(
        messages=[{"role": "user", "content": prompt}],
        model=_GROQ_MODEL,
    )
    # Guard against an empty/None content field so callers always get a str.
    return chat_completion.choices[0].message.content or ""


def get_disease_info(disease_name):
    """Return an LLM-written explanation of *disease_name*.

    The prompt asks for causes, symptoms, and treatment options.
    """
    prompt = f"Provide a detailed explanation about the skin disease '{disease_name}', including causes, symptoms, and treatment options."
    return _ask_groq(prompt)


def chatbot_response(disease_name, user_query):
    """Answer *user_query* in the context of the detected *disease_name*.

    Returns a user-facing hint instead of calling the API when no disease has
    been detected yet or when the question is empty/whitespace-only.
    """
    if not disease_name:
        return "Please upload an image and detect the disease first."
    if not user_query or not user_query.strip():
        # Avoid spending an API call on an empty question.
        return "Please enter a question about the detected disease."
    prompt = f"The detected skin disease is '{disease_name}'. {user_query}"
    return _ask_groq(prompt)
# Streamlit UI
# The logo is optional: set DERMABOT_LOGO_URL to an image URL to show one.
# (A hard-coded placeholder URL such as ".../your-huggingface-space/logo.png"
# renders as a broken image.)
LOGO_URL = os.getenv("DERMABOT_LOGO_URL")
if LOGO_URL:
    st.image(LOGO_URL, width=200)
st.title("🩺 DermaBot - AI Skin Disease Detector")
st.write("Upload an image of a skin condition to get a diagnosis and ask questions about it.")
# The app presents model output as a "diagnosis"; make its limits explicit.
st.caption(
    "⚠️ DermaBot is an AI demo for informational purposes only and is not a "
    "substitute for professional medical advice. Please consult a dermatologist "
    "about any skin concern."
)
# Upload image section
def _reset_detection():
    """Clear the stored prediction when the uploaded file changes or is removed.

    Without this, the chatbot would keep answering about the previous image.
    """
    st.session_state.disease_name = None
    st.session_state.disease_info = None


uploaded_image = st.file_uploader(
    "Upload an image", type=["jpg", "jpeg", "png"], on_change=_reset_detection
)

image = None
if uploaded_image:
    try:
        image = Image.open(uploaded_image)
    except UnidentifiedImageError:
        st.error("The uploaded file could not be read as an image. Please upload a valid JPG or PNG.")

if image is not None:
    st.image(image, caption="Uploaded Image", use_container_width=True)
    if st.button("Detect Disease"):
        with st.spinner("Analyzing..."):
            disease_name = predict_skin_disease(image)
            st.session_state.disease_name = disease_name  # read by the chatbot section
            try:
                st.session_state.disease_info = get_disease_info(disease_name)
            except APIError:
                st.session_state.disease_info = None
                st.error(
                    "The disease was detected, but details could not be fetched "
                    "from the Groq API. Please try again."
                )
    # Render from session state (not only inside the button branch) so the
    # result stays visible on later reruns, e.g. while the user types a question.
    if st.session_state.get("disease_name"):
        st.success(f"**Detected Disease:** {st.session_state.disease_name}")
        if st.session_state.get("disease_info"):
            st.write(f"**Details:** {st.session_state.disease_info}")
# Chatbot section
st.subheader("💬 Ask DermaBot")
user_query = st.text_input("Ask about the detected disease:")
if st.button("Ask"):
    if not user_query.strip():
        # Don't send an empty question to the API.
        st.warning("Please type a question first.")
    else:
        response = None
        with st.spinner("Thinking..."):
            try:
                response = chatbot_response(st.session_state.get("disease_name"), user_query)
            except APIError:
                st.error("DermaBot could not reach the Groq API. Please try again in a moment.")
        if response:
            st.write(response)
# Footer: a horizontal rule followed by the attribution line.
_FOOTER_TEXT = "🔍 Powered by **AI & Groq API** | © 2025 DermaBot"
st.markdown("---")
st.write(_FOOTER_TEXT)