xmasdaday's picture
Update app_lite.py
9ce66e5 verified
import logging
import os

import gradio as gr
import numpy as np
from PIL import Image

# your existing backend (loads ResNet50 + SVM)
from model_resnet50 import predict
# ---------- Labels & inference ----------
# Display names for the three cyst grades. _pretty() accepts a label either as
# a string code ("G1".."G3") or as an integer class index (0..2).
LABELS_STR = {"G1": "Cyst – G1", "G2": "Cyst – G2", "G3": "Cyst – G3"}
LABELS_INT = {0: "Cyst – G1", 1: "Cyst – G2", 2: "Cyst – G3"}


def _pretty(lbl):
    """Return the human-readable display name for a raw classifier label.

    String codes are looked up in ``LABELS_STR``; Python or NumPy integers are
    converted with ``int()`` and looked up in ``LABELS_INT``. Any label that is
    not recognised (unknown code, out-of-range index, other type) falls back to
    ``str(lbl)``.
    """
    if isinstance(lbl, str):
        display = LABELS_STR.get(lbl)
        if display is not None:
            return display
    elif isinstance(lbl, (int, np.integer)):
        # int() normalises NumPy integer scalars so they match the dict keys.
        display = LABELS_INT.get(int(lbl))
        if display is not None:
            return display
    return str(lbl)
def infer(image: Image.Image):
    """Classify one image and format the result for the two output components.

    Args:
        image: PIL image from the input component. May be ``None``: the input's
            change listener can fire with no image (e.g. after the user clears
            it), and ``predict`` must not be called in that case.

    Returns:
        ``(markdown_text, confidences)``. ``confidences`` maps display names to
        floats when ``predict`` returns a ``"probs"`` dict, otherwise ``None``.
        With no image, the outputs are reset to their placeholder state.
    """
    if image is None:
        # Cleared/empty input: reset the outputs instead of calling predict(None).
        return "**Prediction**: –", None

    out = predict(image)
    name = _pretty(out["label"])

    probs = out.get("probs")
    conf = None
    if isinstance(probs, dict):
        # Re-key the probabilities by display name so gr.Label shows readable classes.
        conf = {_pretty(k): float(v) for k, v in probs.items()}

    return f"**Prediction**: {name}", conf
# ---------- Mobile-first CSS ----------
# Custom stylesheet passed to gr.Blocks(css=...). In order, it:
#   - hides the footer;
#   - caps the app width at 460px, widening to 880px on viewports >= 900px;
#   - styles the masthead (.mast) and the white rounded "card" boxes (.card);
#   - fixes the input preview (#input-img) to a 280x280 contained image;
#   - rounds the sample-gallery thumbnails (.sample-gallery);
#   - sets a 46px height / 1rem font on buttons;
#   - makes the overview diagram (#overview-img) full width.
# NOTE(review): the `button{...}` selector is global, so it applies to every
# <button> on the page, not just "Analyze" — confirm this is intended.
css = """
footer{visibility:hidden;}
/* narrow by default for phones; expand on desktop */
.gradio-container{max-width: 460px !important; margin:auto;}
@media (min-width: 900px){ .gradio-container{max-width: 880px !important;} }
/* simple masthead */
.mast {padding:8px 0 2px 0; text-align:center;}
.mast h1{margin:.2rem 0; font-size:1.15rem;}
.mast p{margin:0; color:#555; font-size:.9rem;}
/* cards */
.card{background:#fff;border:1px solid #eee;border-radius:14px;box-shadow:0 6px 18px rgba(0,0,0,.06);padding:14px;margin-top:10px;}
/* make input preview a neat square */
#input-img img{
width: 280px !important; height: 280px !important;
object-fit: contain !important; border-radius:12px;
}
#input-img .wrap{display:flex; justify-content:center;}
/* sample gallery tiles */
.sample-gallery img{
border-radius:10px;
}
/* space buttons a bit */
button{height:46px !important; font-size:1rem !important;}
#overview-img img {
width: 100% !important;
height: auto !important;
object-fit: contain !important;
border-radius: 10px;
margin-top: 10px;
}
"""
# ---------- Build UI ----------
# Layout: masthead, an input card (upload/capture + sample gallery), a result
# card, then a collapsed details accordion. The cards are real layout
# containers (gr.Column carrying the "card" CSS class). Emitting '<div class=
# "card">' and '</div>' from two separate gr.HTML components does not wrap the
# components in between, because each gr.HTML renders its own isolated markup.
with gr.Blocks(title="Acanthamoeba – Lite", theme=gr.themes.Soft(), css=css) as demo:
    gr.HTML("""
<div class="mast">
<h1>Acanthamoeba Cyst Classifier (Lite)</h1>
<p>Upload / Capture → Predict G1–G3</p>
</div>
""")

    # --- INPUT CARD: Upload (left) + 3 Samples (right) ---
    with gr.Column(elem_classes=["card"]):
        gr.Markdown("### Input / อินพุต")
        with gr.Row():
            # LEFT: Upload / Capture
            with gr.Column(scale=1, min_width=200):
                img_in = gr.Image(
                    type="pil",
                    sources=["upload", "webcam", "clipboard"],
                    label="Upload / Capture",
                    height=280, width=280,
                    image_mode="RGB",
                    elem_id="input-img",
                )
                go = gr.Button("Analyze", variant="primary")

            # RIGHT: 3 Sample Images (responsive)
            with gr.Column(scale=1, min_width=200):
                gr.Markdown("**Samples / ตัวอย่าง**")
                sample_paths = [
                    "examples/sample1.jpg",
                    "examples/sample5.jpg",
                    "examples/sample9.jpg",
                ]
                sample_gallery = gr.Gallery(
                    value=sample_paths,
                    columns=[1, 3],  # 1 column on phones, 3 on desktop
                    height=280,
                    object_fit="contain",
                    elem_classes=["sample-gallery"],
                )

                def load_sample(evt: gr.SelectData):
                    """Load the clicked gallery sample into the input image.

                    Returns ``None`` (clearing the input) when the event index
                    does not point at an entry of ``sample_paths``.
                    """
                    idx = evt.index
                    if idx is None or idx < 0 or idx >= len(sample_paths):
                        return None
                    # Context manager releases the file handle deterministically.
                    with Image.open(sample_paths[idx]) as im:
                        return im.convert("RGB")

                sample_gallery.select(
                    fn=load_sample, inputs=None, outputs=[img_in])

    # --- RESULT CARD (immediately under inputs) ---
    with gr.Column(elem_classes=["card"]):
        out_text = gr.Markdown("**Prediction**: –")
        out_conf = gr.Label(num_top_classes=3, label="Confidence (Top-3)")

    # --- (Optional) Details accordion (NO card wrapper) ---
    with gr.Accordion("Details / รายละเอียดระบบ", open=False):
        gr.Markdown(
            "- **Pipeline:** Resize 224×224 → ResNet50 (GAP 2048-D) → SVM (RBF)\n"
            "- **Note:** Prototype demo; no images stored.\n"
        )
        # The overview diagram is optional; only show it if the file is present.
        if os.path.exists("overviewsystem.png"):
            gr.Image(
                value="overviewsystem.png",
                show_label=False,
                height=250,
                elem_id="overview-img",
            )

    # Events: run on click AND whenever the input image changes (upload,
    # capture, or a sample picked from the gallery).
    go.click(fn=infer, inputs=img_in, outputs=[out_text, out_conf])
    img_in.change(fn=infer, inputs=img_in, outputs=[out_text, out_conf])
# ---------- Warmup (faster first request) ----------
# Run one prediction on a black 224x224 RGB image at import time so the first
# real request does not pay the model's first-call cost. This is best-effort:
# a failure must not stop the app from starting. It is logged (with traceback)
# instead of silently swallowed so a broken backend is visible in the server
# log before a user hits it.
try:
    dummy = Image.fromarray(np.zeros((224, 224, 3), dtype=np.uint8))
    _ = predict(dummy)
except Exception:
    logging.getLogger(__name__).warning(
        "Model warmup failed; continuing without it.", exc_info=True
    )
def main():
    """Start the Gradio server and open the app in the default web browser."""
    demo.launch(inbrowser=True)


if __name__ == "__main__":
    main()