"""Streamlit demo: skin-image classification with an optional local-LLM explanation.

The user uploads a skin photo and types free-text symptoms. A timm
classifier loaded from MODEL_PATH returns the top-k class predictions. The
top label plus the symptom text are mapped to a rule-based severity, and an
optional local Ollama model explains the result in plain language.
Educational demo only — not medical advice.
"""

import io
import json

import numpy as np
import requests
import streamlit as st
import timm
import torch
import torchvision.transforms as T
from PIL import Image

# -----------------------------
# CONFIG
# -----------------------------
MODEL_PATH = "src/skin_model.pth"
CLASSES_PATH = "src/classes.json"
TIMM_MODEL_NAME = "efficientnet_b0"
IMG_SIZE = 224
TOPK = 3

# Ollama (FREE local LLM). USE_LLM is the default state of the
# "Use LLM explanation" checkbox; set it to False to start with the LLM off.
USE_LLM = True
OLLAMA_URL = "http://localhost:11434/api/generate"
OLLAMA_MODEL = "phi3:mini"  # or "mistral:7b", "llama3.1:8b"

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# -----------------------------
# Severity rules (simple demo)
# -----------------------------
# label -> (severity, doctor_consult). Labels not listed fall back to
# ("monitor", False) in severity_from_label.
SEVERITY_RULES = {
    "tumor_malignant": ("urgent", True),
    "bullous": ("urgent", True),
    "systemic": ("urgent", True),
    "bacterial": ("doctor_soon", True),
    "autoimmune": ("doctor_soon", True),
    "infestation_bite": ("doctor_soon", True),
    "drug_exanthem": ("doctor_soon", True),
    "fungal": ("monitor", False),
    "viral": ("monitor", False),
    "eczema_dermatitis": ("monitor", False),
    "psoriasis_lichen": ("monitor", False),
    "tumor_benign": ("monitor", False),
    "hives": ("monitor", False),
    "pigment": ("monitor", False),
    "hair_nail": ("monitor", False),
    "acne_rosacea": ("self_care", False),
}

# Phrases in the symptom text that force an "urgent" result regardless of the
# predicted label. Matching is a case-insensitive substring test, so it errs
# toward escalation (e.g. "black" also matches "blackhead").
RED_FLAGS = (
    "fever",
    "bleeding",
    "pus",
    "spreading fast",
    "severe pain",
    "difficulty breathing",
    "black",
    "rapidly growing",
)


def severity_from_label(label: str, symptoms: str) -> tuple[str, bool]:
    """Map a predicted label plus free-text symptoms to (severity, doctor_consult).

    Args:
        label: class name predicted by the image model.
        symptoms: user-entered symptom text; None or "" is treated as no symptoms.

    Returns:
        (severity, doctor_consult). Severity is one of "urgent", "doctor_soon",
        "monitor" or "self_care". Any red-flag phrase in the symptoms overrides
        the label rule with ("urgent", True).
    """
    severity, consult = SEVERITY_RULES.get(label, ("monitor", False))
    text = (symptoms or "").lower()
    if any(flag in text for flag in RED_FLAGS):
        return "urgent", True
    return severity, consult


# -----------------------------
# Load model + classes (cached)
# -----------------------------
@st.cache_resource
def load_model_and_classes():
    """Load the class list and the fine-tuned timm model (cached by Streamlit).

    Returns:
        (model, classes): the model in eval mode on DEVICE, and the parsed JSON
        from CLASSES_PATH. The output layer gets len(classes) units, and
        predict_image indexes classes by output position.

    Raises:
        FileNotFoundError: if CLASSES_PATH or MODEL_PATH does not exist.
    """
    with open(CLASSES_PATH, "r", encoding="utf-8") as f:
        classes = json.load(f)

    model = timm.create_model(TIMM_MODEL_NAME, pretrained=False, num_classes=len(classes))
    # The checkpoint goes straight into load_state_dict, so it only needs
    # tensors. weights_only=True stops torch.load from unpickling arbitrary
    # objects from the file (requires torch >= 1.13).
    state = torch.load(MODEL_PATH, map_location="cpu", weights_only=True)
    model.load_state_dict(state, strict=True)
    model.to(DEVICE)
    model.eval()
    return model, classes


# EfficientNet preprocessing (same as training)
transform = T.Compose([
    T.Resize((IMG_SIZE, IMG_SIZE)),
    T.ToTensor(),
    T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])


@torch.no_grad()
def predict_image(model, pil_img, classes):
    """Return the top-k predictions for one image, highest confidence first.

    Args:
        model: classifier returned by load_model_and_classes.
        pil_img: PIL image in any mode; it is converted to RGB here.
        classes: class names indexed by model output position.

    Returns:
        A list of {"label": classes[idx], "confidence": float in [0, 1]} dicts,
        with min(TOPK, len(classes)) entries.
    """
    # RGB -> (3, IMG_SIZE, IMG_SIZE) tensor, plus a batch dim -> (1, 3, H, W).
    x = transform(pil_img.convert("RGB")).unsqueeze(0).to(DEVICE)
    logits = model(x)  # raw scores
    probs = torch.softmax(logits, dim=1).squeeze(0)  # convert to probabilities
    top = torch.topk(probs, k=min(TOPK, len(classes)))
    return [
        {"label": classes[idx], "confidence": float(score)}
        for idx, score in zip(top.indices.tolist(), top.values.tolist())
    ]


def call_ollama(prompt: str) -> str:
    """Send a single non-streaming generate request to Ollama and return its text.

    Raises:
        requests.RequestException: on connection errors, timeouts (60 s) or
            non-2xx responses (via raise_for_status).
    """
    payload = {
        "model": OLLAMA_MODEL,
        "prompt": prompt,
        "stream": False,
        "options": {"temperature": 0.3},
    }
    r = requests.post(OLLAMA_URL, json=payload, timeout=60)
    r.raise_for_status()
    return r.json().get("response", "").strip()


def build_prompt(symptoms, top3, severity, doctor_consult):
    """Build the LLM prompt from the symptoms, predictions and severity decision."""
    symptoms_text = (symptoms or "").strip() or "(none provided)"
    return f"""
You are a health assistant for a university hackathon demo.
Be careful and do NOT diagnose with certainty.

User symptoms: {symptoms_text}
Image model top-3 predictions: {top3}
Severity decision: severity={severity}, doctor_consult={doctor_consult}

Explain in simple English:
- What top prediction means
- What to do now (safe steps)
- When to see a doctor (based on severity + red flags)
- Ask 2 follow-up questions
Add: "Not medical advice"
""".strip()


def _open_uploaded_image(img_file):
    """Decode a Streamlit upload into a PIL image, or return None if it cannot be read."""
    try:
        # getvalue() returns the whole buffer regardless of stream position,
        # unlike read(), which depends on where the cursor currently is.
        pil_img = Image.open(io.BytesIO(img_file.getvalue()))
        # Image.open is lazy; force a full decode so corrupt files fail here.
        pil_img.load()
    except (OSError, Image.DecompressionBombError):
        # PIL.UnidentifiedImageError is a subclass of OSError.
        return None
    return pil_img


# -----------------------------
# Streamlit UI
# -----------------------------
st.set_page_config(page_title="Skin Disease Demo", page_icon="🧴", layout="centered")
st.title("🧴 Skin Disease Prediction Demo")
st.write(
    "Upload a skin image + type symptoms text. "
    "The model shows **Top-3 predictions** and a simple severity suggestion."
)

try:
    model, classes = load_model_and_classes()
except FileNotFoundError as e:
    st.error(f"Model files not found: {e}")
    st.stop()

st.caption(f"Running on **{DEVICE.upper()}** | Model: {TIMM_MODEL_NAME} | Classes: {len(classes)}")

img_file = st.file_uploader("Upload skin image (jpg/png)", type=["jpg", "jpeg", "png"])
symptoms = st.text_area("Symptoms (example: itchy red patch, burning, spreading, fever?)", height=100)

colA, colB = st.columns(2)
with colA:
    use_llm = st.checkbox("Use LLM explanation (Ollama)", value=USE_LLM)
with colB:
    st.write("")

if img_file is None:
    st.warning("Upload an image to start.")
    st.stop()

pil_img = _open_uploaded_image(img_file)
if pil_img is None:
    st.error("Could not read that file as an image. Please upload a valid JPG or PNG.")
    st.stop()

st.image(pil_img, caption="Uploaded Image", use_container_width=True)

if st.button("Predict"):
    top3 = predict_image(model, pil_img, classes)
    top1 = top3[0]
    severity, doctor_consult = severity_from_label(top1["label"], symptoms)

    st.subheader("✅ Prediction")
    st.write(f"**Top-1:** `{top1['label']}` — **Confidence:** `{top1['confidence']*100:.2f}%`")

    # Confidence bar
    st.progress(min(int(top1["confidence"] * 100), 100))

    st.subheader("Top-3 (recommended in demo)")
    for i, item in enumerate(top3, start=1):
        st.write(f"**{i}.** `{item['label']}` — `{item['confidence']*100:.2f}%`")

    st.subheader("⚠️ Severity suggestion (rule-based)")
    st.write(f"**Severity:** `{severity}`")
    st.write(f"**Doctor consult needed?** `{doctor_consult}`")
    st.info("This is a demo/education tool. Not medical advice.")

    # LLM explanation
    if use_llm:
        st.subheader("🧠 LLM Explanation (simple language)")
        prompt = build_prompt(symptoms, top3, severity, doctor_consult)
        try:
            explanation = call_ollama(prompt)
        except Exception as e:  # UI boundary: an LLM failure degrades to a hint, never a crash
            st.warning(f"LLM not available. Reason: {e}")
            st.write(f"Tip: Start Ollama + pull a model ({OLLAMA_MODEL}).")
        else:
            st.write(explanation or "_(The LLM returned an empty response.)_")