Spaces:
Sleeping
Sleeping
| import io, json | |
| import numpy as np | |
| import streamlit as st | |
| import torch | |
| from PIL import Image | |
| import torchvision.transforms as T | |
| import timm | |
| import requests | |
# -----------------------------
# CONFIG
# -----------------------------
# Trained weights (a state_dict) and the JSON list of class labels.
MODEL_PATH: str = "src/skin_model.pth"
CLASSES_PATH: str = "src/classes.json"

# Backbone rebuilt with timm before the weights are loaded, plus its input size.
TIMM_MODEL_NAME: str = "efficientnet_b0"
IMG_SIZE: int = 224

# Number of ranked predictions returned and displayed.
TOPK: int = 3

# Optional explanation from a local Ollama server (free, runs locally).
# USE_LLM is only the default state of the "Use LLM explanation" checkbox.
USE_LLM: bool = True
OLLAMA_URL: str = "http://localhost:11434/api/generate"
OLLAMA_MODEL: str = "phi3:mini"  # alternatives: "mistral:7b", "llama3.1:8b"

# Run on the GPU when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
# -----------------------------
# Severity rules (simple demo)
# -----------------------------
# label -> (severity tier, doctor_consult). Labels are grouped by tier; the
# merged dict keeps the labels in exactly the order listed here.
SEVERITY_RULES = {
    **dict.fromkeys(
        ("tumor_malignant", "bullous", "systemic"),
        ("urgent", True),
    ),
    **dict.fromkeys(
        ("bacterial", "autoimmune", "infestation_bite", "drug_exanthem"),
        ("doctor_soon", True),
    ),
    **dict.fromkeys(
        (
            "fungal",
            "viral",
            "eczema_dermatitis",
            "psoriasis_lichen",
            "tumor_benign",
            "hives",
            "pigment",
            "hair_nail",
        ),
        ("monitor", False),
    ),
    **dict.fromkeys(
        ("acne_rosacea",),
        ("self_care", False),
    ),
}
# Symptom keywords that escalate any prediction to ("urgent", True). Each is
# matched as a case-insensitive substring of the user's symptom text.
RED_FLAGS = (
    "fever",
    "bleeding",
    "pus",
    "spreading fast",
    "severe pain",
    "difficulty breathing",
    "black",
    "rapidly growing",
)


def severity_from_label(label: str, symptoms: str):
    """Return ``(severity, doctor_consult)`` for a predicted label.

    The label's entry in ``SEVERITY_RULES`` is the baseline; labels missing
    from that table fall back to ``("monitor", False)``. If the symptom text
    contains any ``RED_FLAGS`` keyword, the result is escalated to
    ``("urgent", True)`` regardless of the label.

    Args:
        label: Top-1 class name from the image model.
        symptoms: Free-text symptoms. ``None`` is treated as empty text.

    Returns:
        A ``(severity, doctor_consult)`` tuple such as ``("monitor", False)``.
    """
    severity, consult = SEVERITY_RULES.get(label, ("monitor", False))
    # Treat a missing symptom string as empty instead of raising AttributeError.
    text = (symptoms or "").lower()
    # NOTE(review): substring matching also fires inside other words, e.g.
    # "blackhead" matches "black" and "pustule" matches "pus". Consider
    # whole-word matching if those escalations are unwanted.
    if any(flag in text for flag in RED_FLAGS):
        severity, consult = "urgent", True
    return severity, consult
# -----------------------------
# Load model + classes (cached)
# -----------------------------
@st.cache_resource
def load_model_and_classes():
    """Build the classifier, load its trained weights, and read the labels.

    Streamlit re-executes the whole script on every widget interaction.
    Without ``st.cache_resource`` the checkpoint would be re-read from disk and
    moved to ``DEVICE`` on every rerun. The decorator returns the same
    ``(model, classes)`` objects on later calls, shared across reruns and
    sessions.

    Returns:
        ``(model, classes)``: the timm ``TIMM_MODEL_NAME`` network with weights
        from ``MODEL_PATH``, moved to ``DEVICE`` and set to eval mode, and the
        parsed contents of ``CLASSES_PATH``. ``len(classes)`` sets the size of
        the classifier head.

    Raises:
        FileNotFoundError: If ``CLASSES_PATH`` or ``MODEL_PATH`` does not exist.
        RuntimeError: If the checkpoint's keys/shapes do not match the
            architecture (loading uses ``strict=True``).
    """
    with open(CLASSES_PATH, "r", encoding="utf-8") as f:
        # Assumes a JSON list of label strings indexed by output id — TODO confirm.
        classes = json.load(f)

    model = timm.create_model(
        TIMM_MODEL_NAME, pretrained=False, num_classes=len(classes)
    )
    # Load tensors onto the CPU first so a GPU-saved checkpoint also loads on
    # CPU-only hosts; the model is moved to DEVICE afterwards.
    state = torch.load(MODEL_PATH, map_location="cpu")
    model.load_state_dict(state, strict=True)
    model.to(DEVICE)
    model.eval()
    return model, classes
# EfficientNet preprocessing (same as training): resize to a fixed
# IMG_SIZE x IMG_SIZE square, convert to a float tensor, then normalize each
# channel with the ImageNet mean/std.
transform = T.Compose(
    [
        T.Resize((IMG_SIZE, IMG_SIZE)),  # aspect ratio is not preserved
        T.ToTensor(),
        T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ]
)
def predict_image(model, pil_img, classes):
    """Classify one PIL image and return its top-``TOPK`` labels.

    Args:
        model: Classifier to run, expected to already be on ``DEVICE`` and in
            eval mode (as returned by ``load_model_and_classes``).
        pil_img: PIL image. It is converted to RGB before ``transform``.
        classes: Label list indexed by the model's output index.

    Returns:
        Up to ``TOPK`` dicts ``{"label": str, "confidence": float}``, ordered by
        descending softmax probability. Fewer are returned if there are fewer
        classes.
    """
    # unsqueeze(0) adds the batch dimension the model expects.
    x = transform(pil_img.convert("RGB")).unsqueeze(0).to(DEVICE)
    # Inference only: disable autograd so no graph or activations are kept
    # around for a backward pass that never happens.
    with torch.inference_mode():
        logits = model(x)  # raw scores
        probs = torch.softmax(logits, dim=1).squeeze(0)  # probabilities
        top = torch.topk(probs, k=min(TOPK, len(classes)))
    return [
        {"label": classes[idx], "confidence": float(score)}
        for idx, score in zip(top.indices.tolist(), top.values.tolist())
    ]
def call_ollama(prompt: str) -> str:
    """Send ``prompt`` to the Ollama generate endpoint and return the reply text.

    Makes a single non-streaming POST to ``OLLAMA_URL`` for ``OLLAMA_MODEL``
    with temperature 0.3 and a 60-second timeout.

    Returns:
        The ``"response"`` field of the JSON reply with surrounding whitespace
        stripped, or ``""`` if that field is absent.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.RequestException: On connection failure or timeout.
    """
    response = requests.post(
        OLLAMA_URL,
        json={
            "model": OLLAMA_MODEL,
            "prompt": prompt,
            "stream": False,
            "options": {"temperature": 0.3},
        },
        timeout=60,
    )
    response.raise_for_status()
    body = response.json()
    return body.get("response", "").strip()
def build_prompt(symptoms, top3, severity, doctor_consult):
    """Assemble the instruction prompt sent to the LLM.

    The symptom text, the top-3 prediction list (embedded via its ``str()``
    form), and the rule-based severity decision are inserted verbatim,
    followed by the answer instructions for the model.

    Returns:
        The prompt as newline-separated lines, without leading or trailing
        blank lines.
    """
    lines = [
        "You are a health assistant for a university hackathon demo.",
        "Be careful and do NOT diagnose with certainty.",
        "User symptoms:",
        f"{symptoms}",
        "Image model top-3 predictions:",
        f"{top3}",
        "Severity decision:",
        f"severity={severity}, doctor_consult={doctor_consult}",
        "Explain in simple English:",
        "- What top prediction means",
        "- What to do now (safe steps)",
        "- When to see a doctor (based on severity + red flags)",
        "- Ask 2 follow-up questions",
        'Add: "Not medical advice"',
    ]
    return "\n".join(lines)
# -----------------------------
# Streamlit UI
# -----------------------------
st.set_page_config(page_title="Skin Disease Demo", page_icon="🧴", layout="centered")
st.title("🧴 Skin Disease Prediction Demo")
st.write("Upload a skin image + type symptoms text. The model shows **Top-3 predictions** and a simple severity suggestion.")

model, classes = load_model_and_classes()
st.caption(f"Running on **{DEVICE.upper()}** | Model: {TIMM_MODEL_NAME} | Classes: {len(classes)}")

img_file = st.file_uploader("Upload skin image (jpg/png)", type=["jpg", "jpeg", "png"])
symptoms = st.text_area("Symptoms (example: itchy red patch, burning, spreading, fever?)", height=100)

colA, colB = st.columns(2)
with colA:
    use_llm = st.checkbox("Use LLM explanation (Ollama)", value=USE_LLM)
with colB:
    st.write("")  # empty right-hand column (layout spacer)

if img_file is not None:
    try:
        # getvalue() returns the full upload regardless of the stream position;
        # read() only returns what remains after the current position.
        pil_img = Image.open(io.BytesIO(img_file.getvalue()))
        # Image.open is lazy (header only); decode now so a corrupt or
        # truncated file is reported here rather than failing later.
        pil_img.load()
    except OSError as err:  # PIL.UnidentifiedImageError subclasses OSError
        pil_img = None
        st.error(f"Could not read the uploaded file as an image: {err}")

    if pil_img is not None:
        st.image(pil_img, caption="Uploaded Image", use_container_width=True)

        if st.button("Predict"):
            top3 = predict_image(model, pil_img, classes)
            top1 = top3[0]
            severity, doctor_consult = severity_from_label(top1["label"], symptoms)

            st.subheader("✅ Prediction")
            st.write(f"**Top-1:** `{top1['label']}` — **Confidence:** `{top1['confidence']*100:.2f}%`")
            # Confidence bar (st.progress takes an int percentage, capped at 100)
            st.progress(min(int(top1["confidence"] * 100), 100))

            st.subheader("Top-3 (recommended in demo)")
            for i, item in enumerate(top3, start=1):
                st.write(f"**{i}.** `{item['label']}` — `{item['confidence']*100:.2f}%`")

            st.subheader("⚠️ Severity suggestion (rule-based)")
            st.write(f"**Severity:** `{severity}`")
            st.write(f"**Doctor consult needed?** `{doctor_consult}`")
            st.info("This is a demo/education tool. Not medical advice.")

            # LLM explanation. A broad except is acceptable at this UI boundary:
            # any failure is reported to the user instead of crashing the page.
            if use_llm:
                st.subheader("🧠 LLM Explanation (simple language)")
                try:
                    prompt = build_prompt(symptoms, top3, severity, doctor_consult)
                    explanation = call_ollama(prompt)
                    st.write(explanation)
                except Exception as e:
                    st.warning(f"LLM not available. Reason: {e}")
                    st.write("Tip: Start Ollama + pull a model (phi3:mini).")
else:
    st.warning("Upload an image to start.")
# A second st.file_uploader used to be created here. It rendered an extra
# upload widget whose file was never read, which confused users. The page
# already has one uploader (img_file), so the name is kept only as an alias
# in case code beyond this point refers to it.
uploaded = img_file