Spaces:
Sleeping
Sleeping
| import sys | |
| from pathlib import Path | |
| import gradio as gr | |
| import torch | |
| import yaml | |
| from torchvision import transforms | |
| # Add project root to sys.path | |
| sys.path.append(str(Path(__file__).parent.parent)) | |
| from src.model import ResNet18Transfer # noqa: E402 | |
# ── Config ───────────────────────────────────────────────────────────────────
def load_config(config_path="config.yaml"):
    """Load and parse the project's YAML configuration file.

    Args:
        config_path: Path to the YAML file. A relative path is resolved against
            the current working directory, not against this file's location.

    Returns:
        Whatever ``yaml.safe_load`` produces for the document (the rest of this
        module indexes it like a mapping, e.g. ``config["classes"]``).
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's locale encoding (e.g. cp1252 on Windows).
    with open(config_path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)
# Parsed configuration shared by the rest of this module.
config = load_config("config.yaml")
# Ordered class keys: predict() pairs position i with the model's i-th output.
CLASSES = config["classes"]
def get_device(cfg_device):
    """Resolve a configured device string to the device name to use.

    The special value ``"auto"`` picks ``"cuda"`` when PyTorch reports CUDA
    as available and ``"cpu"`` otherwise. Any other value is returned as-is.
    """
    if cfg_device != "auto":
        return cfg_device
    has_cuda = torch.cuda.is_available()
    return "cuda" if has_cuda else "cpu"
DEVICE = get_device(config["device"])
# ── Model ────────────────────────────────────────────────────────────────────
model = ResNet18Transfer(num_classes=len(CLASSES), pretrained=False)
model_path = "models/resnet18_best.pth"
# Checkpoints are tried in order; the second is a generic fallback filename.
# A missing file moves on to the next candidate; any other load error propagates.
for _ckpt_path in (model_path, "models/best_model.pth"):
    try:
        _state_dict = torch.load(_ckpt_path, map_location=DEVICE, weights_only=True)
    except FileNotFoundError:
        continue
    model.load_state_dict(_state_dict)
    print(f"Loaded model from {_ckpt_path}")
    break
else:
    # No candidate existed: keep the randomly initialised weights.
    print("Warning: Model checkpoints not found. Using untrained model.")
model.to(DEVICE)
model.eval()
# ── Transform ────────────────────────────────────────────────────────────────
# Inference preprocessing: fixed 224x224 resize (aspect ratio not preserved),
# conversion to a float tensor, then per-channel normalisation with the
# standard ImageNet mean/std values.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]
)
# ── Translations ─────────────────────────────────────────────────────────────
# UI strings per language code. Both languages must define the same keys; the
# "tips" and "class_names" sub-dicts are keyed by the class keys in CLASSES.
TRANSLATIONS = {
    "en": {
        "title": "🗑️ Trash Classifier Pro",
        "description": "Enterprise-grade waste classification powered by Deep Learning.",  # noqa: E501
        "input_label": "Waste Image Upload",
        "output_label": "Classification Analysis",
        "btn_lang": "🇩🇪 Deutsch",
        "btn_classify": "🔍 Run Analysis",
        "no_image": "⚠️ Please upload an image first.",
        "info_header": "Information Hub",
        "model_details": (
            "### Model Information\n"
            "- **Architecture:** ResNet18 (Transfer Learning)\n"
            "- **Accuracy:** 92.4% on test set\n"
            "- **Framework:** PyTorch 2.x\n"
            "- **Backend:** CPU/GPU automated switching"
        ),
        "instructions": (
            "### How to use\n"
            "1. Upload a clear photo of an item.\n"
            "2. The model will analyze texture and shape.\n"
            "3. View the confidence scores and recycling tips."
        ),
        "tips_header": "Recycling Tip",
        "tips": {
            "glass": "Glass is 100% recyclable. Please remove caps and rinse containers.",  # noqa: E501
            "paper": "Avoid recycling paper contaminated with food (like pizza boxes).",  # noqa: E501
            "cardboard": "Flatten boxes to save space in the recycling bin.",
            "plastic": "Check the recycling code. Rinse to avoid contamination.",
            "metal": "Aluminum and steel cans are highly valuable for recycling.",
            "trash": "This item belongs in general waste. Check local disposal rules.",  # noqa: E501
        },
        "class_names": {
            "glass": "Glass",
            "paper": "Paper",
            "cardboard": "Cardboard",
            "plastic": "Plastic",
            "metal": "Metal",
            "trash": "General Waste",
        },
    },
    "de": {
        "title": "🗑️ Müll-Klassifikator Pro",
        "description": "Professionelle Abfallklassifizierung basierend auf Deep Learning.",  # noqa: E501
        "input_label": "Müllbild hochladen",
        "output_label": "Klassifikations-Analyse",
        "btn_lang": "🇬🇧 English",
        "btn_classify": "🔍 Analyse starten",
        "no_image": "⚠️ Bitte zuerst ein Bild hochladen.",
        "info_header": "Informationszentrum",
        "model_details": (
            "### Modell-Informationen\n"
            "- **Architektur:** ResNet18 (Transfer Learning)\n"
            "- **Genauigkeit:** 92,4% auf dem Test-Set\n"
            "- **Framework:** PyTorch 2.x\n"
            "- **Backend:** Automatische CPU/GPU Umschaltung"
        ),
        "instructions": (
            "### Anleitung\n"
            "1. Lade ein scharfes Foto eines Gegenstands hoch.\n"
            "2. Das Modell analysiert Textur und Form.\n"
            "3. Sieh dir die Konfidenzwerte und Recycling-Tipps an."
        ),
        "tips_header": "Recycling-Tipp",
        "tips": {
            "glass": "Glas ist zu 100% recycelbar. Bitte Deckel entfernen und Behälter ausspülen.",  # noqa: E501
            "paper": "Vermeide das Recycling von verschmutztem Papier (z.B. Pizzakartons).",  # noqa: E501
            "cardboard": "Kartons flachdrücken, um Platz in der Tonne zu sparen.",
            "plastic": "Prüfe den Recycling-Code. Ausspülen verhindert Kontamination.",  # noqa: E501
            "metal": "Alu- und Stahlmüll ist sehr wertvoll für das Recycling.",
            "trash": "Dieser Gegenstand gehört in den Restmüll. Prüfe lokale Regeln.",  # noqa: E501
        },
        "class_names": {
            "glass": "Glas",
            "paper": "Papier",
            "cardboard": "Pappe",
            "plastic": "Plastik",
            "metal": "Metall",
            "trash": "Restmüll",
        },
    },
}
# ── Inference ────────────────────────────────────────────────────────────────
def predict(image, lang="en"):
    """Classify an uploaded waste image and build a localized recycling tip.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when nothing is uploaded (e.g. after the image is cleared).
        lang: Key into ``TRANSLATIONS`` (``"en"`` or ``"de"``). Unknown keys
            fall back to English instead of raising ``KeyError``.

    Returns:
        ``(confidences, tip_markdown)``. ``confidences`` maps each localized
        class name to its softmax probability (the format ``gr.Label``
        accepts). ``tip_markdown`` is the recycling tip for the top class, or
        the localized "no image" message when ``image`` is ``None``.
    """
    t = TRANSLATIONS.get(lang, TRANSLATIONS["en"])
    if image is None:
        return {}, t["no_image"]
    # The preprocessing normalises exactly 3 channels; RGBA, grayscale or
    # palette images would otherwise fail inside Normalize.
    if image.mode != "RGB":
        image = image.convert("RGB")
    img_tensor = transform(image).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        logits = model(img_tensor)
    # Single-image batch: softmax over the class dimension of the first row.
    probs = torch.nn.functional.softmax(logits[0], dim=0)
    names = t["class_names"]
    # Output position i corresponds to CLASSES[i]; unknown keys show raw.
    confidences = {
        names.get(class_key, class_key): float(prob)
        for class_key, prob in zip(CLASSES, probs)
    }
    top_class_key = CLASSES[torch.argmax(probs).item()]
    tip = t["tips"].get(top_class_key, "")
    return confidences, f"### {t['tips_header']}\n{tip}"
# ── UI ───────────────────────────────────────────────────────────────────────
def build_app():
    """Build the bilingual (English/German) Gradio Blocks interface.

    The UI starts in English. The language button toggles between ``"en"``
    and ``"de"``, relabels every static component and clears any previous
    result. Classification runs when the classify button is clicked and
    whenever the uploaded image changes.

    Returns:
        The constructed ``gr.Blocks`` app (not launched; theme and CSS are
        passed at launch time).
    """
    with gr.Blocks() as app:
        # Current UI language code; fed to predict() and toggle_language().
        lang_state = gr.State("en")
        with gr.Column(elem_classes="container"):
            with gr.Row():
                with gr.Column(scale=8):
                    pass  # empty spacer pushing the language button right
                with gr.Column(scale=2):
                    lang_btn = gr.Button(
                        TRANSLATIONS["en"]["btn_lang"], variant="secondary", size="sm"
                    )
            # Custom Header
            with gr.Column(elem_classes="header"):
                title_md = gr.Markdown(f"# {TRANSLATIONS['en']['title']}")
                desc_md = gr.Markdown(TRANSLATIONS["en"]["description"])
            with gr.Row(variant="panel"):
                with gr.Column(scale=1):
                    image_input = gr.Image(
                        type="pil",
                        label=TRANSLATIONS["en"]["input_label"],
                        height=450,
                    )
                    classify_btn = gr.Button(
                        TRANSLATIONS["en"]["btn_classify"], variant="primary", size="lg"
                    )
                    with gr.Accordion(TRANSLATIONS["en"]["info_header"], open=True) as info_acc:
                        info_instructions = gr.Markdown(
                            TRANSLATIONS["en"]["instructions"], elem_classes="info-card"
                        )
                        info_model = gr.Markdown(TRANSLATIONS["en"]["model_details"])
                with gr.Column(scale=1):
                    result_label_md = gr.Markdown(f"## {TRANSLATIONS['en']['output_label']}")
                    result_output = gr.Label(
                        num_top_classes=3,
                        label="",
                    )
                    tip_output = gr.Markdown("", elem_classes="tip-card")

        # ── Language toggle ──────────────────────────────────────────────────
        def toggle_language(current_lang):
            """Flip the language and return new values for each localized output.

            The previous result (labels and tip) is cleared: its class names
            and tip text are in the old language and would otherwise linger.
            """
            new_lang = "de" if current_lang == "en" else "en"
            t = TRANSLATIONS[new_lang]
            return (
                new_lang,
                t["btn_lang"],
                f"# {t['title']}",
                t["description"],
                gr.update(label=t["input_label"]),
                t["btn_classify"],
                f"## {t['output_label']}",
                gr.update(label=t["info_header"]),
                t["instructions"],
                t["model_details"],
                "",  # Reset tip
                None,  # Clear stale classification labels
            )

        lang_btn.click(
            fn=toggle_language,
            inputs=[lang_state],
            outputs=[
                lang_state,
                lang_btn,
                title_md,
                desc_md,
                image_input,
                classify_btn,
                result_label_md,
                info_acc,
                info_instructions,
                info_model,
                tip_output,
                result_output,
            ],
        )

        # ── Classify ─────────────────────────────────────────────────────────
        classify_btn.click(
            fn=predict,
            inputs=[image_input, lang_state],
            outputs=[result_output, tip_output],
        )
        image_input.change(
            fn=predict,
            inputs=[image_input, lang_state],
            outputs=[result_output, tip_output],
        )
    return app
if __name__ == "__main__":
    app = build_app()
    # Gradio 6.0 Styling Parameters: theme and CSS are passed to launch().
    theme = gr.themes.Soft(primary_hue="emerald", spacing_size="lg", radius_size="lg")
    # Notes on the CSS below:
    # - Gradio signals dark mode with a `.dark` class. The `[data-theme='dark']`
    #   selector is kept alongside it, but on its own it never matched, so the
    #   dark tip-card override was dead.
    # - `--warning-100` gets a literal fallback in case the theme does not
    #   define that variable.
    css = """
    .container { max-width: 1200px; margin: auto; padding: 20px; }
    .header {
        text-align: center;
        padding: 40px 20px;
        background: linear-gradient(135deg, #065f46 0%, #059669 100%);
        color: white !important;
        border-radius: 20px;
        margin-bottom: 30px;
        box-shadow: 0 10px 15px -3px rgba(0, 0, 0, 0.1);
    }
    .header h1, .header p { color: white !important; }
    /* Adapt cards to theme colors */
    .info-card {
        background-color: var(--background-fill-secondary);
        border-left: 5px solid #10b981;
        padding: 20px;
        border-radius: 10px;
        color: var(--body-text-color);
    }
    .tip-card {
        background-color: var(--warning-100, #fef3c7);
        border-left: 5px solid #f59e0b;
        padding: 20px;
        border-radius: 10px;
        margin-top: 20px;
        color: #92400e;
    }
    /* Dark mode overrides for cards */
    .dark .tip-card,
    [data-theme='dark'] .tip-card {
        background-color: #451a03;
        color: #fef3c7;
        border-left-color: #d97706;
    }
    .gr-label-text { font-weight: bold; }
    """  # noqa: E501
    # inbrowser=True opens the browser automatically
    # share=True provides a public URL
    app.launch(inbrowser=True, theme=theme, css=css, share=True)