"""Gradio front-end for the ResNet18 trash classifier (English/German UI)."""

import sys
from pathlib import Path

import gradio as gr
import torch
import yaml
from torchvision import transforms

# Add project root to sys.path so `src` is importable when run as a script.
PROJECT_ROOT = Path(__file__).resolve().parent.parent
sys.path.append(str(PROJECT_ROOT))

from src.model import ResNet18Transfer  # noqa: E402


def _resolve_path(path):
    """Return *path* unchanged if it exists (or is absolute), else anchor it at the project root.

    The CWD is tried first so existing invocations keep working; the project-root
    fallback lets the app start from any directory, mirroring the sys.path setup above.
    """
    candidate = Path(path)
    if candidate.is_absolute() or candidate.exists():
        return candidate
    return PROJECT_ROOT / candidate


# ── Config ───────────────────────────────────────────────────────────────────
def load_config(config_path="config.yaml"):
    """Load and return the YAML configuration as a dict.

    Args:
        config_path: Path to the YAML file, relative to the CWD or the project root.
    """
    with open(_resolve_path(config_path), "r", encoding="utf-8") as f:
        return yaml.safe_load(f)


config = load_config()
CLASSES = config["classes"]


def get_device(cfg_device):
    """Map the configured device string to a torch device name.

    "auto" selects "cuda" when available and "cpu" otherwise; any other value
    is returned unchanged.
    """
    if cfg_device == "auto":
        return "cuda" if torch.cuda.is_available() else "cpu"
    return cfg_device


DEVICE = get_device(config["device"])

# ── Model ─────────────────────────────────────────────────────────────────────
model = ResNet18Transfer(num_classes=len(CLASSES), pretrained=False)
model_path = "models/resnet18_best.pth"
# Fallback to general best_model if the specific name is missing.
MODEL_CANDIDATES = (model_path, "models/best_model.pth")


def _load_checkpoint(net, candidates):
    """Load the first existing checkpoint from *candidates* into *net*.

    Returns the path that was loaded, or None when no candidate exists
    (the network then keeps its untrained weights).
    """
    for candidate in candidates:
        try:
            state = torch.load(
                _resolve_path(candidate), map_location=DEVICE, weights_only=True
            )
        except FileNotFoundError:
            continue
        net.load_state_dict(state)
        print(f"Loaded model from {candidate}")
        return candidate
    print("Warning: Model checkpoints not found. Using untrained model.")
    return None


_load_checkpoint(model, MODEL_CANDIDATES)
model.to(DEVICE)
model.eval()

# ── Transform ─────────────────────────────────────────────────────────────────
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# ── Translations ──────────────────────────────────────────────────────────────
TRANSLATIONS = {
    "en": {
        "title": "🗑️ Trash Classifier Pro",
        "description": "Enterprise-grade waste classification powered by Deep Learning.",  # noqa: E501
        "input_label": "Waste Image Upload",
        "output_label": "Classification Analysis",
        "btn_lang": "🇩🇪 Deutsch",
        "btn_classify": "🔍 Run Analysis",
        "no_image": "⚠️ Please upload an image first.",
        "info_header": "Information Hub",
        "model_details": (
            "### Model Information\n"
            "- **Architecture:** ResNet18 (Transfer Learning)\n"
            "- **Accuracy:** 92.4% on test set\n"
            "- **Framework:** PyTorch 2.x\n"
            "- **Backend:** CPU/GPU automated switching"
        ),
        "instructions": (
            "### How to use\n"
            "1. Upload a clear photo of an item.\n"
            "2. The model will analyze texture and shape.\n"
            "3. View the confidence scores and recycling tips."
        ),
        "tips_header": "Recycling Tip",
        "tips": {
            "glass": "Glass is 100% recyclable. Please remove caps and rinse containers.",  # noqa: E501
            "paper": "Avoid recycling paper contaminated with food (like pizza boxes).",  # noqa: E501
            "cardboard": "Flatten boxes to save space in the recycling bin.",
            "plastic": "Check the recycling code. Rinse to avoid contamination.",
            "metal": "Aluminum and steel cans are highly valuable for recycling.",
            "trash": "This item belongs in general waste. Check local disposal rules.",  # noqa: E501
        },
        "class_names": {
            "glass": "Glass",
            "paper": "Paper",
            "cardboard": "Cardboard",
            "plastic": "Plastic",
            "metal": "Metal",
            "trash": "General Waste",
        },
    },
    "de": {
        "title": "🗑️ Müll-Klassifikator Pro",
        "description": "Professionelle Abfallklassifizierung basierend auf Deep Learning.",  # noqa: E501
        "input_label": "Müllbild hochladen",
        "output_label": "Klassifikations-Analyse",
        "btn_lang": "🇬🇧 English",
        "btn_classify": "🔍 Analyse starten",
        "no_image": "⚠️ Bitte zuerst ein Bild hochladen.",
        "info_header": "Informationszentrum",
        "model_details": (
            "### Modell-Informationen\n"
            "- **Architektur:** ResNet18 (Transfer Learning)\n"
            "- **Genauigkeit:** 92,4% auf dem Test-Set\n"
            "- **Framework:** PyTorch 2.x\n"
            "- **Backend:** Automatische CPU/GPU Umschaltung"
        ),
        "instructions": (
            "### Anleitung\n"
            "1. Lade ein scharfes Foto eines Gegenstands hoch.\n"
            "2. Das Modell analysiert Textur und Form.\n"
            "3. Sieh dir die Konfidenzwerte und Recycling-Tipps an."
        ),
        "tips_header": "Recycling-Tipp",
        "tips": {
            "glass": "Glas ist zu 100% recycelbar. Bitte Deckel entfernen und Behälter ausspülen.",  # noqa: E501
            "paper": "Vermeide das Recycling von verschmutztem Papier (z.B. Pizzakartons).",  # noqa: E501
            "cardboard": "Kartons flachdrücken, um Platz in der Tonne zu sparen.",
            "plastic": "Prüfe den Recycling-Code. Ausspülen verhindert Kontamination.",  # noqa: E501
            "metal": "Alu- und Stahlmüll ist sehr wertvoll für das Recycling.",
            "trash": "Dieser Gegenstand gehört in den Restmüll. Prüfe lokale Regeln.",  # noqa: E501
        },
        "class_names": {
            "glass": "Glas",
            "paper": "Papier",
            "cardboard": "Pappe",
            "plastic": "Plastik",
            "metal": "Metall",
            "trash": "Restmüll",
        },
    },
}


# ── Inference ─────────────────────────────────────────────────────────────────
def predict(image, lang="en"):
    """Classify *image* and return per-class confidences plus a recycling tip.

    Args:
        image: A PIL image, or None when nothing has been uploaded.
        lang: Key into TRANSLATIONS ("en" or "de").

    Returns:
        (confidences, tip_markdown): a {localized class name: probability} dict
        for gr.Label and a Markdown tip for the most likely class. For a missing
        image, ({}, localized "no image" message).
    """
    t = TRANSLATIONS[lang]
    if image is None:
        return {}, t["no_image"]

    # Normalize expects exactly 3 channels; RGBA, grayscale or palette images
    # would otherwise fail inside the transform.
    img_tensor = transform(image.convert("RGB")).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        outputs = model(img_tensor)
    probs = torch.nn.functional.softmax(outputs[0], dim=0)

    # Dictionary for gr.Label, keyed by the localized class name.
    confidences = {
        t["class_names"].get(class_key, class_key): float(prob)
        for class_key, prob in zip(CLASSES, probs.tolist())
    }

    # Recycling tip for the top class.
    top_class_key = CLASSES[int(torch.argmax(probs).item())]
    tip = t["tips"].get(top_class_key, "")
    tip_md = f"### {t['tips_header']}\n{tip}"

    return confidences, tip_md


# ── UI ────────────────────────────────────────────────────────────────────────
def build_app():
    """Build and return the Gradio Blocks UI (not launched)."""
    with gr.Blocks() as app:
        lang_state = gr.State("en")

        with gr.Column(elem_classes="container"):
            with gr.Row():
                with gr.Column(scale=8):
                    pass  # spacer pushing the language button to the right
                with gr.Column(scale=2):
                    lang_btn = gr.Button(
                        TRANSLATIONS["en"]["btn_lang"], variant="secondary", size="sm"
                    )

            # Custom Header
            with gr.Column(elem_classes="header"):
                title_md = gr.Markdown(f"# {TRANSLATIONS['en']['title']}")
                desc_md = gr.Markdown(TRANSLATIONS["en"]["description"])

            with gr.Row(variant="panel"):
                with gr.Column(scale=1):
                    image_input = gr.Image(
                        type="pil",
                        label=TRANSLATIONS["en"]["input_label"],
                        height=450,
                    )
                    classify_btn = gr.Button(
                        TRANSLATIONS["en"]["btn_classify"], variant="primary", size="lg"
                    )

                    with gr.Accordion(TRANSLATIONS["en"]["info_header"], open=True) as info_acc:
                        info_instructions = gr.Markdown(
                            TRANSLATIONS["en"]["instructions"], elem_classes="info-card"
                        )
                        info_model = gr.Markdown(TRANSLATIONS["en"]["model_details"])

                with gr.Column(scale=1):
                    result_label_md = gr.Markdown(f"## {TRANSLATIONS['en']['output_label']}")
                    result_output = gr.Label(
                        num_top_classes=3,
                        label="",
                    )
                    tip_output = gr.Markdown("", elem_classes="tip-card")

        # ── Language toggle ──────────────────────────────────────────────────
        def toggle_language(current_lang, image):
            """Flip between "en" and "de" and re-render every localized component.

            If an image is loaded, the prediction is recomputed so the class names
            and tip follow the new language; otherwise results and tip are cleared.
            """
            new_lang = "de" if current_lang == "en" else "en"
            t = TRANSLATIONS[new_lang]
            if image is None:
                result, tip = {}, ""  # Reset result and tip
            else:
                result, tip = predict(image, new_lang)
            return (
                new_lang,
                t["btn_lang"],
                f"# {t['title']}",
                t["description"],
                gr.update(label=t["input_label"]),
                t["btn_classify"],
                f"## {t['output_label']}",
                gr.update(label=t["info_header"]),
                t["instructions"],
                t["model_details"],
                result,
                tip,
            )

        lang_btn.click(
            fn=toggle_language,
            inputs=[lang_state, image_input],
            outputs=[
                lang_state,
                lang_btn,
                title_md,
                desc_md,
                image_input,
                classify_btn,
                result_label_md,
                info_acc,
                info_instructions,
                info_model,
                result_output,
                tip_output,
            ],
        )

        # ── Classify ─────────────────────────────────────────────────────────
        classify_btn.click(
            fn=predict,
            inputs=[image_input, lang_state],
            outputs=[result_output, tip_output],
        )
        image_input.change(
            fn=predict,
            inputs=[image_input, lang_state],
            outputs=[result_output, tip_output],
        )

    return app


if __name__ == "__main__":
    app = build_app()

    # Gradio 6.0 Styling Parameters
    theme = gr.themes.Soft(primary_hue="emerald", spacing_size="lg", radius_size="lg")

    css = """
    .container { max-width: 1200px; margin: auto; padding: 20px; }
    .header {
        text-align: center;
        padding: 40px 20px;
        background: linear-gradient(135deg, #065f46 0%, #059669 100%);
        color: white !important;
        border-radius: 20px;
        margin-bottom: 30px;
        box-shadow: 0 10px 15px -3px rgba(0, 0, 0, 0.1);
    }
    .header h1, .header p { color: white !important; }
    /* Adapt cards to theme colors */
    .info-card {
        background-color: var(--background-fill-secondary);
        border-left: 5px solid #10b981;
        padding: 20px;
        border-radius: 10px;
        color: var(--body-text-color);
    }
    .tip-card {
        background-color: var(--warning-100);
        border-left: 5px solid #f59e0b;
        padding: 20px;
        border-radius: 10px;
        margin-top: 20px;
        color: #92400e;
    }
    /* Dark mode overrides for cards (Gradio marks dark mode with a .dark class) */
    .dark .tip-card, [data-theme='dark'] .tip-card {
        background-color: #451a03;
        color: #fef3c7;
        border-left-color: #d97706;
    }
    .gr-label-text { font-weight: bold; }
    """  # noqa: E501

    # inbrowser=True opens the browser automatically
    # share=True provides a public URL (exposes the app on the internet)
    app.launch(inbrowser=True, theme=theme, css=css, share=True)