Spaces:
Sleeping
Sleeping
import json
import os
import tempfile

import clip
import gradio as gr
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from transformers import AutoTokenizer, AutoModelForCausalLM
# =========================================================
# 1. Load the Tainan waste-sorting database (JSON)
# =========================================================
# The file is read once at import time; lookups below scan the parsed data.
with open("tainan_recycle_data.json", encoding="utf-8") as db_file:
    tainan_db = json.loads(db_file.read())
def lookup_item(name, db=None):
    """Find the database entry whose Chinese or English name matches `name`.

    Matching is case-insensitive. An entry whose name equals the query is
    preferred; otherwise the first entry whose name *contains* the query is
    returned (so "bottle" can still match "plastic bottle").

    Args:
        name: Query text, e.g. a phrase produced by the object detector.
        db: Optional iterable of entry dicts to search. Defaults to the
            module-level ``tainan_db``.

    Returns:
        The matching entry dict, or ``None`` when nothing matches or the
        query is empty.
    """
    items = tainan_db if db is None else db

    query = (name or "").strip().lower()
    if not query:
        # "" is a substring of every string, so an empty query would
        # otherwise "match" the very first entry.
        return None

    def candidate_names(item):
        # Entries may lack one of the two name fields; treat it as empty
        # instead of raising KeyError.
        return [
            str(item.get(key) or "").lower()
            for key in ("name", "english_name")
        ]

    # Pass 1: exact match, so a short query such as "can" is not captured by
    # an unrelated entry that merely contains it.
    for item in items:
        if any(query == candidate for candidate in candidate_names(item)):
            return item

    # Pass 2: substring match, Chinese name first, then English name.
    for item in items:
        if any(query in candidate for candidate in candidate_names(item)):
            return item

    return None
# =========================================================
# 2. Load the detector: GroundingDINO (weights come from your Dataset repo)
# =========================================================
# NOTE: replace REPO_ID with your own Hugging Face Dataset repo.
REPO_ID = "idkWhatToUse/groundingdino-weights"

config_path = hf_hub_download(
    repo_id=REPO_ID,
    filename="GroundingDINO_SwinT_OGC.py",
    repo_type="dataset",
)
checkpoint_path = hf_hub_download(
    repo_id=REPO_ID,
    filename="groundingdino_swint_ogc.pth",
    repo_type="dataset",
)

from groundingdino.util.inference import load_model, load_image, predict, annotate

# load_model() defaults to device="cuda"; choose the device explicitly so the
# model is also configured correctly on CPU-only hosts.
dino_model = load_model(
    config_path,
    checkpoint_path,
    device="cuda" if torch.cuda.is_available() else "cpu",
)
# =========================================================
# 3. Load CLIP (category decision) + LLM (explanation text)
# =========================================================
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

clip_model, preprocess = clip.load("ViT-B/32", device=device)

# Categories shown to the user; CLIP picks one of these.
labels = [
    "一般垃圾",
    "紙類",
    "塑膠類",
    "金屬類",
    "玻璃類",
    "食物垃圾",
    "電池",
    "電子產品",
]

# LLM (phi-2) used to word the reason for a category.
tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2")
llm = AutoModelForCausalLM.from_pretrained("microsoft/phi-2")
llm = llm.to(device)
def generate_reason(item_name, category):
    """Ask the LLM for a one-sentence reason why an item belongs to a category.

    Args:
        item_name: Name of the detected object.
        category: Category label chosen for that object.

    Returns:
        Only the newly generated text (whitespace-stripped). The prompt
        itself is not included in the result.
    """
    prompt = f"物品「{item_name}」被歸類為「{category}」。請用一句話簡單解釋原因:"
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    prompt_length = inputs["input_ids"].shape[1]

    outputs = llm.generate(
        **inputs,
        max_new_tokens=40,
        # Set explicitly so generate() does not have to fall back to a
        # default padding token.
        pad_token_id=tokenizer.eos_token_id,
    )

    # generate() returns prompt tokens followed by the continuation; decoding
    # the whole sequence would echo the prompt back as part of the "reason".
    new_tokens = outputs[0][prompt_length:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
# =========================================================
# 4. CLIP classification (fallback when the official database has no match)
# =========================================================
# English prompt for each user-facing label. The OpenAI CLIP text encoder is
# trained on English captions, so the Chinese labels are scored through an
# English description; labels without an entry are passed through unchanged.
_CLIP_PROMPTS = {
    "一般垃圾": "a photo of general trash",
    "紙類": "a photo of paper or cardboard waste",
    "塑膠類": "a photo of plastic waste",
    "金屬類": "a photo of metal waste",
    "玻璃類": "a photo of glass waste",
    "食物垃圾": "a photo of food waste",
    "電池": "a photo of a battery",
    "電子產品": "a photo of an electronic device",
}


def classify_clip(image, obj_name):
    """Pick the most likely waste category for an image and explain it.

    Args:
        image: PIL image (a numpy array is converted first).
        obj_name: Name of the detected object; used only for the explanation.

    Returns:
        A ``(category, reason)`` tuple: ``category`` is one entry of the
        module-level ``labels`` list and ``reason`` is the text returned by
        ``generate_reason``.
    """
    if isinstance(image, np.ndarray):
        # CLIP's preprocess pipeline expects a PIL image.
        image = Image.fromarray(image.astype("uint8"))

    prompts = [_CLIP_PROMPTS.get(label, label) for label in labels]
    text_inputs = clip.tokenize(prompts).to(device)
    img_tensor = preprocess(image).unsqueeze(0).to(device)

    with torch.no_grad():
        logits_per_image, _ = clip_model(img_tensor, text_inputs)

    probs = logits_per_image.softmax(dim=-1)
    best_index = int(probs[0].argmax())

    # Report the original (Chinese) label, not the English prompt.
    category = labels[best_index]
    reason = generate_reason(obj_name, category)
    return category, reason
# =========================================================
# 5. Main flow: detect objects -> look up the database -> show official info
# =========================================================
def _crop_to_box(image, box):
    """Return the part of a PIL image covered by a normalized (cx, cy, w, h) box.

    Falls back to the full image when the box collapses to an empty region.
    """
    width, height = image.size
    cx, cy, box_w, box_h = box.tolist()
    left = max(0, int((cx - box_w / 2) * width))
    top = max(0, int((cy - box_h / 2) * height))
    right = min(width, int((cx + box_w / 2) * width))
    bottom = min(height, int((cy + box_h / 2) * height))
    if right <= left or bottom <= top:
        return image
    return image.crop((left, top, right, bottom))


def pipeline(image):
    """Detect waste items in a photo and describe how each should be sorted.

    Each detected phrase is first looked up in the Tainan database; when
    there is no match, CLIP picks a category for the detected region and the
    LLM supplies a short reason.

    Args:
        image: The uploaded picture as a numpy array or PIL image, or
            ``None`` when nothing was uploaded.

    Returns:
        A ``(annotated, result_text)`` tuple: the picture with detection
        boxes drawn on it (RGB numpy array) and the text report. When no
        image was supplied, ``(None, <hint message>)``.
    """
    if image is None:
        return None, "請先上傳一張照片。"

    # Gradio hands over a numpy array by default; convert it to PIL first.
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image.astype("uint8"))
    # JPEG cannot store alpha/palette modes.
    image = image.convert("RGB")

    # 1. Detect objects with GroundingDINO. load_image() only accepts a path,
    #    so write to a per-call temporary file instead of a shared
    #    "temp.jpg" that concurrent requests would overwrite.
    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_path = os.path.join(tmp_dir, "input.jpg")
        image.save(tmp_path)
        img_np, img_tensor = load_image(tmp_path)

    boxes, logits, phrases = predict(
        model=dino_model,
        image=img_tensor,
        caption="bottle, can, box, plastic, cup, phone, battery, appliance, metal",
        box_threshold=0.3,
        text_threshold=0.25,
        # predict() defaults to "cuda", which fails on CPU-only hosts.
        device=device,
    )

    # annotate() draws on a BGR copy of the image (OpenCV convention);
    # Gradio expects RGB, so flip the channel order back.
    annotated = annotate(img_np, boxes, logits, phrases)
    annotated = np.ascontiguousarray(annotated[:, :, ::-1])

    if not phrases:
        return annotated, "未偵測到可分類的物品,請換個角度或更清楚的照片再試一次。"

    # 2. Build the report, one section per detection.
    sections = []
    for box, obj in zip(boxes, phrases):
        obj_clean = obj.strip().lower()

        # (A) Try the official Tainan database first. An empty phrase is
        #     never looked up: it would match an arbitrary entry.
        match = lookup_item(obj_clean) if obj_clean else None
        if match:
            sections.append(
                "\n".join(
                    [
                        "",
                        f"🧩 **物品:{match['name']}**",
                        f"📘 英文名稱:{match['english_name']}",
                        f"♻️ 回收指示:{match['recyclable']}",
                        f"📖 官方說明:{match['notes']}",
                        "🌐 資料來源:台南市政府環保局",
                        "",
                    ]
                )
            )
            continue

        # (B) No match -> CLIP inference + LLM reason. Classify only the
        #     detected region; using the whole photo would give every object
        #     in the picture the same category.
        category, reason = classify_clip(_crop_to_box(image, box), obj_clean)
        sections.append(
            "\n".join(
                [
                    "",
                    f"🧩 偵測到:{obj}",
                    f"📦 分類推論:{category}",
                    f"💡 理由:{reason}",
                    "",
                ]
            )
        )

    return annotated, "".join(sections)
# =========================================================
# 6. Gradio interface
# =========================================================
with gr.Blocks() as demo:
    # Page header and a one-line description of the flow.
    gr.Markdown(value="# 🗑️ AI 垃圾分類助手(含台南官方資料庫)")
    gr.Markdown(value="上傳照片 → 物件偵測 → 查台南市環保局資料庫 → 顯示回收建議")

    # Input photo and the button that starts the analysis.
    img_input = gr.Image(label="上傳垃圾照片")
    btn = gr.Button(value="開始分析 🚀", variant="primary")

    # Outputs: annotated detections and the text report.
    img_output = gr.Image(label="物件偵測(GroundingDINO)")
    text_output = gr.Textbox(label="分類結果", lines=10)

    btn.click(
        fn=pipeline,
        inputs=[img_input],
        outputs=[img_output, text_output],
    )

demo.launch()