# NOTE(review): removed hosting-page scrape residue that preceded the code
# (page status text, "File size" line, commit hashes, and a line-number
# gutter). None of it was part of the program and it made the file
# syntactically invalid Python.
import io, json
import numpy as np
import streamlit as st
import torch
from PIL import Image
import torchvision.transforms as T
import timm
import requests
# -----------------------------
# CONFIG
# -----------------------------
MODEL_PATH = "src/skin_model.pth"  # fine-tuned checkpoint (torch state_dict)
CLASSES_PATH = "src/classes.json"  # JSON list of class labels, index-aligned with model outputs
TIMM_MODEL_NAME = "efficientnet_b0"  # timm backbone; must match the architecture used at training time
IMG_SIZE = 224  # input resolution the preprocessing resizes to
TOPK = 3  # number of top predictions shown in the UI
# Ollama (FREE local LLM). If you don't want LLM, set USE_LLM=False
USE_LLM = True
OLLAMA_URL = "http://localhost:11434/api/generate"
OLLAMA_MODEL = "phi3:mini" # or "mistral:7b", "llama3.1:8b"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"  # prefer GPU when available
# -----------------------------
# Severity rules (simple demo)
# -----------------------------
# Maps a predicted class label to (severity_level, doctor_consult_needed).
# Labels absent from this table fall back to ("monitor", False).
SEVERITY_RULES = {
    "tumor_malignant": ("urgent", True),
    "bullous": ("urgent", True),
    "systemic": ("urgent", True),
    "bacterial": ("doctor_soon", True),
    "autoimmune": ("doctor_soon", True),
    "infestation_bite": ("doctor_soon", True),
    "drug_exanthem": ("doctor_soon", True),
    "fungal": ("monitor", False),
    "viral": ("monitor", False),
    "eczema_dermatitis": ("monitor", False),
    "psoriasis_lichen": ("monitor", False),
    "tumor_benign": ("monitor", False),
    "hives": ("monitor", False),
    "pigment": ("monitor", False),
    "hair_nail": ("monitor", False),
    "acne_rosacea": ("self_care", False),
}

# Symptom keywords that escalate to "urgent" regardless of the predicted
# label. Hoisted to module level so it is not rebuilt on every call.
RED_FLAGS = (
    "fever",
    "bleeding",
    "pus",
    "spreading fast",
    "severe pain",
    "difficulty breathing",
    "black",
    "rapidly growing",
)


def severity_from_label(label: str, symptoms: str):
    """Return (severity, doctor_consult) for a predicted label + symptom text.

    Looks the label up in SEVERITY_RULES (defaulting to ("monitor", False)),
    then overrides to ("urgent", True) if any red-flag keyword appears in the
    free-text symptoms (case-insensitive substring match).

    Fix: `symptoms` may be None/empty (e.g. the user typed nothing); the
    original crashed on `None.lower()`.
    """
    sev, consult = SEVERITY_RULES.get(label, ("monitor", False))
    s = (symptoms or "").lower()
    if any(keyword in s for keyword in RED_FLAGS):
        sev, consult = "urgent", True
    return sev, consult
# -----------------------------
# Load model + classes (cached)
# -----------------------------
@st.cache_resource
def load_model_and_classes():
    """Build the classifier once per Streamlit process (cached resource).

    Reads the class-name list from CLASSES_PATH, instantiates the timm
    backbone with a matching classification head, loads the fine-tuned
    weights from MODEL_PATH, and returns the eval-mode model on DEVICE
    together with the class names.
    """
    with open(CLASSES_PATH, "r") as fh:
        class_names = json.load(fh)

    net = timm.create_model(
        TIMM_MODEL_NAME,
        pretrained=False,  # weights come entirely from our checkpoint
        num_classes=len(class_names),
    )
    checkpoint = torch.load(MODEL_PATH, map_location="cpu")
    net.load_state_dict(checkpoint, strict=True)
    net.to(DEVICE)
    net.eval()
    return net, class_names
# EfficientNet preprocessing — must mirror the training pipeline exactly.
transform = T.Compose(
    [
        T.Resize((IMG_SIZE, IMG_SIZE)),
        T.ToTensor(),
        # ImageNet channel statistics (standard for timm EfficientNet).
        T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ]
)
@torch.no_grad()
def predict_image(model, pil_img, classes):
    """Classify one PIL image; return the top-k predictions.

    Output is a list of {"label": str, "confidence": float} dicts, ordered
    from most to least likely, with at most TOPK entries.
    """
    batch = transform(pil_img.convert("RGB")).unsqueeze(0).to(DEVICE)
    # Softmax over the raw logits gives per-class probabilities.
    probabilities = torch.softmax(model(batch), dim=1).squeeze(0)
    k = min(TOPK, len(classes))
    scores, indices = torch.topk(probabilities, k=k)
    return [
        {"label": classes[i], "confidence": float(p)}
        for i, p in zip(indices.tolist(), scores.tolist())
    ]
def call_ollama(prompt: str) -> str:
    """Send *prompt* to the local Ollama server and return its text reply.

    Raises requests.HTTPError on a non-2xx response; returns "" if the
    response JSON carries no "response" field.
    """
    body = {
        "model": OLLAMA_MODEL,
        "prompt": prompt,
        "stream": False,  # one complete JSON reply, not a chunked stream
        "options": {"temperature": 0.3},  # keep explanations fairly deterministic
    }
    resp = requests.post(OLLAMA_URL, json=body, timeout=60)
    resp.raise_for_status()
    return resp.json().get("response", "").strip()
def build_prompt(symptoms, top3, severity, doctor_consult):
    """Compose the guardrailed instruction prompt sent to the local LLM.

    `top3` is interpolated with its repr (a list of label/confidence dicts),
    which is what the LLM sees.
    """
    lines = [
        "You are a health assistant for a university hackathon demo.",
        "Be careful and do NOT diagnose with certainty.",
        "User symptoms:",
        f"{symptoms}",
        "Image model top-3 predictions:",
        f"{top3}",
        "Severity decision:",
        f"severity={severity}, doctor_consult={doctor_consult}",
        "Explain in simple English:",
        "- What top prediction means",
        "- What to do now (safe steps)",
        "- When to see a doctor (based on severity + red flags)",
        "- Ask 2 follow-up questions",
        'Add: "Not medical advice"',
    ]
    return "\n".join(lines)
# -----------------------------
# Streamlit UI
# -----------------------------
st.set_page_config(page_title="Skin Disease Demo", page_icon="🧴", layout="centered")
st.title("🧴 Skin Disease Prediction Demo")
st.write("Upload a skin image + type symptoms text. The model shows **Top-3 predictions** and a simple severity suggestion.")
model, classes = load_model_and_classes()
st.caption(f"Running on **{DEVICE.upper()}** | Model: {TIMM_MODEL_NAME} | Classes: {len(classes)}")

# Inputs: one image plus free-text symptoms.
img_file = st.file_uploader("Upload skin image (jpg/png)", type=["jpg", "jpeg", "png"])
symptoms = st.text_area("Symptoms (example: itchy red patch, burning, spreading, fever?)", height=100)

colA, colB = st.columns(2)
with colA:
    use_llm = st.checkbox("Use LLM explanation (Ollama)", value=USE_LLM)
with colB:
    st.write("")  # layout placeholder to balance the two columns

if img_file is not None:
    pil_img = Image.open(io.BytesIO(img_file.read()))
    st.image(pil_img, caption="Uploaded Image", use_container_width=True)
    if st.button("Predict"):
        top3 = predict_image(model, pil_img, classes)
        top1 = top3[0]
        severity, doctor_consult = severity_from_label(top1["label"], symptoms)
        st.subheader("✅ Prediction")
        st.write(f"**Top-1:** `{top1['label']}` — **Confidence:** `{top1['confidence']*100:.2f}%`")
        # Confidence bar (st.progress expects an int percentage, clamped to 100).
        st.progress(min(int(top1["confidence"] * 100), 100))
        st.subheader("Top-3 (recommended in demo)")
        for i, item in enumerate(top3, start=1):
            st.write(f"**{i}.** `{item['label']}` — `{item['confidence']*100:.2f}%`")
        st.subheader("⚠️ Severity suggestion (rule-based)")
        st.write(f"**Severity:** `{severity}`")
        st.write(f"**Doctor consult needed?** `{doctor_consult}`")
        st.info("This is a demo/education tool. Not medical advice.")
        # LLM explanation (optional — requires a running Ollama server).
        if use_llm:
            st.subheader("🧠 LLM Explanation (simple language)")
            try:
                prompt = build_prompt(symptoms, top3, severity, doctor_consult)
                explanation = call_ollama(prompt)
                st.write(explanation)
            except Exception as e:
                # Best-effort: the demo still works without the LLM.
                st.warning(f"LLM not available. Reason: {e}")
                st.write("Tip: Start Ollama + pull a model (phi3:mini).")
else:
    st.warning("Upload an image to start.")
# NOTE(review): removed a stray second st.file_uploader("Upload skin image")
# that rendered unconditionally at the bottom of the page; its result
# ("uploaded") was never used — leftover dead code.