import logging
import os

import gradio as gr
import numpy as np
from PIL import Image

from model_resnet50 import predict
|
|
| |
# Display names for the classifier's classes. `_pretty` accepts a label either
# as a string code ("G1".."G3") or as an integer index (0..2); both tables map
# to the same human-readable strings.
LABELS_STR = {"G1": "Cyst – G1", "G2": "Cyst – G2", "G3": "Cyst – G3"}
LABELS_INT = {0: "Cyst – G1", 1: "Cyst – G2", 2: "Cyst – G3"}


def _pretty(lbl):
    """Return the human-readable display name for a raw model label.

    String labels are looked up in ``LABELS_STR``; Python or NumPy integer
    labels are normalised with ``int()`` and looked up in ``LABELS_INT``.
    Any label that is not found (or is of another type) is returned as
    ``str(lbl)``.
    """
    if isinstance(lbl, str):
        table, key = LABELS_STR, lbl
    elif isinstance(lbl, (int, np.integer)):
        table, key = LABELS_INT, int(lbl)
    else:
        return str(lbl)
    return table.get(key, str(lbl))
|
|
|
|
def infer(image: "Image.Image | None"):
    """Classify one image and format the result for the UI outputs.

    Args:
        image: PIL image from the input component, or ``None`` when no image
            is present (Gradio passes ``None`` e.g. when the user clears the
            input, which also fires ``img_in.change``).

    Returns:
        ``(markdown_text, confidences)`` for the Markdown and Label outputs.
        ``confidences`` maps display names to float scores, or is ``None``
        when there is no image or ``predict`` reports no ``probs`` dict.
    """
    if image is None:
        # Nothing to classify: reset the outputs to their initial placeholder
        # instead of calling the model with no input.
        return "**Prediction**: –", None

    out = predict(image)
    name = _pretty(out["label"])
    probs = out.get("probs")
    conf = None
    if isinstance(probs, dict):
        # Re-key the scores with display names so gr.Label shows them nicely.
        conf = {_pretty(k): float(v) for k, v in probs.items()}
    return f"**Prediction**: {name}", conf
|
|
|
|
| |
# Custom stylesheet passed to gr.Blocks(css=...). Its selectors target markup
# and elem_id / elem_classes set in the Blocks layout below: `.mast` (masthead
# HTML), `.card`, `#input-img` (input image), `.sample-gallery` (sample
# gallery) and `#overview-img` (overview diagram).
# NOTE: the bare `button{...}` rule is global, so it resizes every button in
# the app, not just "Analyze".
css = """
footer{visibility:hidden;}
/* narrow by default for phones; expand on desktop */
.gradio-container{max-width: 460px !important; margin:auto;}
@media (min-width: 900px){ .gradio-container{max-width: 880px !important;} }

/* simple masthead */
.mast {padding:8px 0 2px 0; text-align:center;}
.mast h1{margin:.2rem 0; font-size:1.15rem;}
.mast p{margin:0; color:#555; font-size:.9rem;}

/* cards */
.card{background:#fff;border:1px solid #eee;border-radius:14px;box-shadow:0 6px 18px rgba(0,0,0,.06);padding:14px;margin-top:10px;}

/* make input preview a neat square */
#input-img img{
width: 280px !important; height: 280px !important;
object-fit: contain !important; border-radius:12px;
}
#input-img .wrap{display:flex; justify-content:center;}

/* sample gallery tiles */
.sample-gallery img{
border-radius:10px;
}

/* space buttons a bit */
button{height:46px !important; font-size:1rem !important;}

#overview-img img {
width: 100% !important;
height: auto !important;
object-fit: contain !important;
border-radius: 10px;
margin-top: 10px;
}
"""
|
|
| |
# Build the whole UI: masthead, input card (image + samples), result card,
# collapsible details, and the event wiring that runs `infer`.
with gr.Blocks(title="Acanthamoeba – Lite", theme=gr.themes.Soft(), css=css) as demo:
    gr.HTML("""
<div class="mast">
<h1>Acanthamoeba Cyst Classifier (Lite)</h1>
<p>Upload / Capture → Predict G1–G3</p>
</div>
""")

    # Input card. A real layout container carries the `.card` class: each
    # gr.HTML component renders into its own DOM node, so an opening
    # '<div class="card">' in one gr.HTML and '</div>' in another never wrap
    # the components placed between them.
    with gr.Column(elem_classes=["card"]):
        gr.Markdown("### Input / อินพุต")

        with gr.Row():
            # Left: the image to classify plus the explicit "Analyze" trigger.
            with gr.Column(scale=1, min_width=200):
                img_in = gr.Image(
                    type="pil",
                    sources=["upload", "webcam", "clipboard"],
                    label="Upload / Capture",
                    height=280,
                    width=280,
                    image_mode="RGB",
                    elem_id="input-img",
                )
                go = gr.Button("Analyze", variant="primary")

            # Right: clickable sample images.
            with gr.Column(scale=1, min_width=200):
                gr.Markdown("**Samples / ตัวอย่าง**")
                # Leave out sample files that do not exist (same guard as the
                # overview image below) rather than handing a missing path to
                # the Gallery.
                sample_paths = [
                    path
                    for path in (
                        "examples/sample1.jpg",
                        "examples/sample5.jpg",
                        "examples/sample9.jpg",
                    )
                    if os.path.exists(path)
                ]
                sample_gallery = gr.Gallery(
                    value=sample_paths,
                    columns=[1, 3],
                    height=280,
                    object_fit="contain",
                    elem_classes=["sample-gallery"],
                )

        def load_sample(evt: gr.SelectData):
            """Return the clicked sample as an RGB PIL image, or None if the
            selected index is outside ``sample_paths``."""
            idx = evt.index
            if idx is None or idx < 0 or idx >= len(sample_paths):
                return None
            # Close the file handle; convert() returns an independent image.
            with Image.open(sample_paths[idx]) as im:
                return im.convert("RGB")

        sample_gallery.select(fn=load_sample, inputs=None, outputs=[img_in])

    # Result card: predicted class and top-3 confidences.
    with gr.Column(elem_classes=["card"]):
        out_text = gr.Markdown("**Prediction**: –")
        out_conf = gr.Label(num_top_classes=3, label="Confidence (Top-3)")

    with gr.Accordion("Details / รายละเอียดระบบ", open=False):
        gr.Markdown(
            "- **Pipeline:** Resize 224×224 → ResNet50 (GAP 2048-D) → SVM (RBF)\n"
            "- **Note:** Prototype demo; no images stored.\n"
        )
        # The overview diagram is optional; only show it if the file exists.
        if os.path.exists("overviewsystem.png"):
            gr.Image(
                value="overviewsystem.png",
                show_label=False,
                height=250,
                elem_id="overview-img",
            )

    # Classify on an explicit click, and also automatically whenever the input
    # image changes (upload, capture, paste, or a gallery sample loaded above).
    go.click(fn=infer, inputs=img_in, outputs=[out_text, out_conf])
    img_in.change(fn=infer, inputs=img_in, outputs=[out_text, out_conf])
|
|
| |
# Best-effort warm-up: run one prediction on a blank 224×224 RGB image at
# start-up so any one-time initialisation inside `predict` is not paid by the
# first user request. A failure here must not stop the UI from launching, but
# it is logged (with traceback) instead of being silently discarded.
try:
    dummy = Image.fromarray(np.zeros((224, 224, 3), dtype=np.uint8))
    _ = predict(dummy)
except Exception:
    logging.getLogger(__name__).warning(
        "Model warm-up prediction failed; the first request may be slow or fail.",
        exc_info=True,
    )
|
|
if __name__ == "__main__":
    # Only when run as a script (not on import): start the Gradio server and
    # open the app in the default web browser.
    open_browser = True
    demo.launch(inbrowser=open_browser)
|
|