# skin_model/src/streamlit_app.py
import io, json
import numpy as np
import streamlit as st
import torch
from PIL import Image
import torchvision.transforms as T
import timm
import requests
# -----------------------------
# CONFIG
# -----------------------------
# Paths to the trained classifier weights and the class-name list (JSON array,
# index-aligned with the model's output logits).
MODEL_PATH = "src/skin_model.pth"
CLASSES_PATH = "src/classes.json"
# timm architecture name — must match the architecture used at training time.
TIMM_MODEL_NAME = "efficientnet_b0"
IMG_SIZE = 224  # input resolution fed to the network
TOPK = 3  # number of predictions shown in the UI
# Ollama (FREE local LLM). If you don't want LLM, set USE_LLM=False
USE_LLM = True
OLLAMA_URL = "http://localhost:11434/api/generate"
OLLAMA_MODEL = "phi3:mini" # or "mistral:7b", "llama3.1:8b"
# Prefer GPU when available, otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# -----------------------------
# Severity rules (simple demo)
# -----------------------------
# Maps a predicted class label to (severity_tier, doctor_consult_needed).
# Tiers, most to least urgent: urgent > doctor_soon > monitor > self_care.
SEVERITY_RULES = {
    "tumor_malignant": ("urgent", True),
    "bullous": ("urgent", True),
    "systemic": ("urgent", True),
    "bacterial": ("doctor_soon", True),
    "autoimmune": ("doctor_soon", True),
    "infestation_bite": ("doctor_soon", True),
    "drug_exanthem": ("doctor_soon", True),
    "fungal": ("monitor", False),
    "viral": ("monitor", False),
    "eczema_dermatitis": ("monitor", False),
    "psoriasis_lichen": ("monitor", False),
    "tumor_benign": ("monitor", False),
    "hives": ("monitor", False),
    "pigment": ("monitor", False),
    "hair_nail": ("monitor", False),
    "acne_rosacea": ("self_care", False),
}

# Symptom keywords that escalate ANY prediction straight to "urgent".
# Hoisted to module level so the tuple is not rebuilt on every call.
RED_FLAGS = (
    "fever", "bleeding", "pus", "spreading fast", "severe pain",
    "difficulty breathing", "black", "rapidly growing",
)


def severity_from_label(label: str, symptoms: str):
    """Return ``(severity, doctor_consult)`` for a predicted class label.

    Args:
        label: predicted class name; unknown labels default to
            ``("monitor", False)``.
        symptoms: free-text symptom description from the user; matched
            case-insensitively against RED_FLAGS. May be None or empty
            (treated as "no symptoms reported").

    Returns:
        Tuple of (severity tier str, bool — whether a doctor consult is
        recommended). Any red-flag keyword forces ``("urgent", True)``.
    """
    sev, consult = SEVERITY_RULES.get(label, ("monitor", False))
    # Guard against None: st.text_area normally yields a str, but be robust.
    s = (symptoms or "").lower()
    if any(keyword in s for keyword in RED_FLAGS):
        sev, consult = "urgent", True
    return sev, consult
# -----------------------------
# Load model + classes (cached)
# -----------------------------
@st.cache_resource
def load_model_and_classes():
    """Build the classifier and load trained weights; cached across reruns.

    Returns:
        (model, classes): the timm model in eval mode, moved to DEVICE, and
        the list of class names loaded from CLASSES_PATH (index-aligned with
        the model's output logits).
    """
    with open(CLASSES_PATH, "r") as f:
        classes = json.load(f)
    model = timm.create_model(
        TIMM_MODEL_NAME, pretrained=False, num_classes=len(classes)
    )
    # weights_only=True refuses arbitrary pickled objects in the checkpoint —
    # safe for a plain state_dict and avoids code execution from an untrusted
    # .pth file. Requires torch >= 1.13.
    state = torch.load(MODEL_PATH, map_location="cpu", weights_only=True)
    model.load_state_dict(state, strict=True)
    model.to(DEVICE)
    model.eval()
    return model, classes
# EfficientNet preprocessing (same as training)
# Resize -> tensor in [0,1] -> normalize with ImageNet mean/std (the standard
# statistics for timm EfficientNet models).
transform = T.Compose([
T.Resize((IMG_SIZE, IMG_SIZE)),
T.ToTensor(),
T.Normalize((0.485,0.456,0.406),(0.229,0.224,0.225)),
])
@torch.no_grad()
def predict_image(model, pil_img, classes):
    """Classify a PIL image and return the top-k predictions.

    Returns a list of ``{"label": str, "confidence": float}`` dicts sorted by
    descending confidence, at most TOPK entries long.
    """
    batch = transform(pil_img.convert("RGB")).unsqueeze(0).to(DEVICE)
    probs = torch.softmax(model(batch), dim=1).squeeze(0)
    scores, indices = torch.topk(probs, k=min(TOPK, len(classes)))
    return [
        {"label": classes[i], "confidence": float(p)}
        for i, p in zip(indices.tolist(), scores.tolist())
    ]
def call_ollama(prompt: str) -> str:
    """Send *prompt* to the local Ollama server and return its reply text.

    Raises ``requests.HTTPError`` on a non-2xx response and the usual
    ``requests`` connection/timeout errors when the server is unreachable.
    """
    resp = requests.post(
        OLLAMA_URL,
        json={
            "model": OLLAMA_MODEL,
            "prompt": prompt,
            "stream": False,  # single JSON body instead of a token stream
            "options": {"temperature": 0.3},
        },
        timeout=60,
    )
    resp.raise_for_status()
    return resp.json().get("response", "").strip()
def build_prompt(symptoms, top3, severity, doctor_consult):
    """Compose the instruction prompt sent to the local LLM.

    Interpolates the user's symptom text, the model's top-3 predictions and
    the rule-based severity decision into a fixed instruction template.
    """
    lines = [
        "You are a health assistant for a university hackathon demo.",
        "Be careful and do NOT diagnose with certainty.",
        "User symptoms:",
        f"{symptoms}",
        "Image model top-3 predictions:",
        f"{top3}",
        "Severity decision:",
        f"severity={severity}, doctor_consult={doctor_consult}",
        "Explain in simple English:",
        "- What top prediction means",
        "- What to do now (safe steps)",
        "- When to see a doctor (based on severity + red flags)",
        "- Ask 2 follow-up questions",
        'Add: "Not medical advice"',
    ]
    return "\n".join(lines)
# -----------------------------
# Streamlit UI
# -----------------------------
st.set_page_config(page_title="Skin Disease Demo", page_icon="🧴", layout="centered")
st.title("🧴 Skin Disease Prediction Demo")
st.write("Upload a skin image + type symptoms text. The model shows **Top-3 predictions** and a simple severity suggestion.")

# Model is cached by @st.cache_resource, so this is cheap on reruns.
model, classes = load_model_and_classes()
st.caption(f"Running on **{DEVICE.upper()}** | Model: {TIMM_MODEL_NAME} | Classes: {len(classes)}")

# Inputs: one image + free-text symptoms.
img_file = st.file_uploader("Upload skin image (jpg/png)", type=["jpg", "jpeg", "png"])
symptoms = st.text_area("Symptoms (example: itchy red patch, burning, spreading, fever?)", height=100)

colA, colB = st.columns(2)
with colA:
    use_llm = st.checkbox("Use LLM explanation (Ollama)", value=USE_LLM)
with colB:
    st.write("")  # spacer to balance the columns

if img_file is not None:
    pil_img = Image.open(io.BytesIO(img_file.read()))
    st.image(pil_img, caption="Uploaded Image", use_container_width=True)
    if st.button("Predict"):
        top3 = predict_image(model, pil_img, classes)
        top1 = top3[0]
        severity, doctor_consult = severity_from_label(top1["label"], symptoms)
        st.subheader("✅ Prediction")
        st.write(f"**Top-1:** `{top1['label']}` — **Confidence:** `{top1['confidence']*100:.2f}%`")
        # Confidence bar (st.progress accepts an int 0-100)
        st.progress(min(int(top1["confidence"] * 100), 100))
        st.subheader("Top-3 (recommended in demo)")
        for i, item in enumerate(top3, start=1):
            st.write(f"**{i}.** `{item['label']}` — `{item['confidence']*100:.2f}%`")
        st.subheader("⚠️ Severity suggestion (rule-based)")
        st.write(f"**Severity:** `{severity}`")
        st.write(f"**Doctor consult needed?** `{doctor_consult}`")
        st.info("This is a demo/education tool. Not medical advice.")
        # LLM explanation — optional; requires a running local Ollama server.
        if use_llm:
            st.subheader("🧠 LLM Explanation (simple language)")
            try:
                prompt = build_prompt(symptoms, top3, severity, doctor_consult)
                explanation = call_ollama(prompt)
                st.write(explanation)
            except Exception as e:
                # Best-effort: the demo still works without the LLM.
                st.warning(f"LLM not available. Reason: {e}")
                st.write("Tip: Start Ollama + pull a model (phi3:mini).")
else:
    st.warning("Upload an image to start.")
# NOTE(review): a second, unused `st.file_uploader` ("Upload skin image")
# previously rendered here, unconditionally, at the bottom of the page.
# Its return value (`uploaded`) was never read anywhere, so it only confused
# users with a duplicate upload widget — removed.