import contextlib
import json
import os
import tempfile

import clip
import gradio as gr
import numpy as np
import torch
from groundingdino.util.inference import annotate, load_image, load_model, predict
from huggingface_hub import hf_hub_download
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer

# Every model below (GroundingDINO, CLIP, phi-2) runs on this device.
device = "cuda" if torch.cuda.is_available() else "cpu"

# =========================================================
# 1. Load the Tainan recycling-classification database (JSON)
# =========================================================
with open("tainan_recycle_data.json", "r", encoding="utf-8") as f:
    tainan_db = json.load(f)


def _field(item, key):
    """Return ``item[key]`` as a string, or "" when the key is missing or null."""
    value = item.get(key)
    return "" if value is None else str(value)


def lookup_item(name):
    """Fuzzy-match a detected object name against the Tainan database.

    An entry matches when the lower-cased, stripped ``name`` is a substring
    of the entry's Chinese ``name`` or its ``english_name``; the first
    matching entry in file order wins.

    Args:
        name: Detected object label (e.g. a GroundingDINO phrase).

    Returns:
        The matching database entry, or None when nothing matches or
        ``name`` is empty/blank.
    """
    needle = (name or "").strip().lower()
    # "" is a substring of every string, so without this guard a blank
    # detection phrase would silently "match" the first database entry.
    if not needle:
        return None
    for item in tainan_db:
        if needle in _field(item, "name").lower():
            return item
        if needle in _field(item, "english_name").lower():
            return item
    return None


# =========================================================
# 2. Load GroundingDINO (config + weights from a HF dataset repo)
# =========================================================
# ⚠️ Replace REPO_ID with your own Hugging Face dataset repo.
REPO_ID = "idkWhatToUse/groundingdino-weights"

config_path = hf_hub_download(
    repo_id=REPO_ID,
    filename="GroundingDINO_SwinT_OGC.py",
    repo_type="dataset",
)
checkpoint_path = hf_hub_download(
    repo_id=REPO_ID,
    filename="groundingdino_swint_ogc.pth",
    repo_type="dataset",
)
# GroundingDINO's load_model/predict default to device="cuda"; pass the
# detected device explicitly so the app also works on CPU-only hosts.
dino_model = load_model(config_path, checkpoint_path, device=device)

# =========================================================
# 3. Load CLIP (category classifier) + LLM (explanation text)
# =========================================================
clip_model, preprocess = clip.load("ViT-B/32", device=device)

# Display labels (Chinese) mapped to the English prompts actually scored by
# CLIP. OpenAI CLIP's text encoder was trained on English captions, so
# scoring the Chinese label strings directly gives near-arbitrary results.
_LABEL_PROMPTS = {
    "一般垃圾": "a photo of general household trash",
    "紙類": "a photo of paper waste",
    "塑膠類": "a photo of plastic waste",
    "金屬類": "a photo of metal waste",
    "玻璃類": "a photo of glass waste",
    "食物垃圾": "a photo of food waste",
    "電池": "a photo of a battery",
    "電子產品": "a photo of an electronic device",
}
labels = list(_LABEL_PROMPTS)
# The prompts never change, so tokenize them once instead of on every call.
_clip_text_tokens = clip.tokenize(list(_LABEL_PROMPTS.values())).to(device)

# LLM (phi-2)
tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2")
llm = AutoModelForCausalLM.from_pretrained("microsoft/phi-2").to(device)


def generate_reason(item_name, category):
    """Ask the LLM for a one-sentence explanation of a classification.

    Args:
        item_name: Detected object name.
        category: Category chosen for it (one of ``labels``).

    Returns:
        Only the newly generated text, without the prompt.
    """
    prompt = f"物品「{item_name}」被歸類為「{category}」。請用一句話簡單解釋原因:"
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = llm.generate(
            **inputs,
            max_new_tokens=40,
            # The tokenizer may not define a pad token; pass EOS explicitly
            # so generate() does not have to guess (and warn).
            pad_token_id=tokenizer.eos_token_id,
        )
    # A causal LM's output sequence starts with the prompt tokens; decode only
    # the continuation, otherwise the "reason" echoes the whole prompt.
    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()


# =========================================================
# 4. CLIP classification (fallback when the official DB has no entry)
# =========================================================
def classify_clip(image, obj_name):
    """Zero-shot classify ``image`` into one of ``labels`` using CLIP.

    Args:
        image: PIL image, ideally cropped to the detected object.
        obj_name: Detected object name, used only for the LLM explanation.

    Returns:
        ``(category, reason)``: the best-scoring Chinese label and the LLM's
        explanation for it.
    """
    img_tensor = preprocess(image).unsqueeze(0).to(device)
    with torch.no_grad():
        logits_per_image, _ = clip_model(img_tensor, _clip_text_tokens)
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()[0]
    category = labels[int(probs.argmax())]
    return category, generate_reason(obj_name, category)


# =========================================================
# 5. Main flow: detect objects -> query the DB -> show official info
# =========================================================
# Text prompt listing the object types GroundingDINO should look for.
DETECTION_CAPTION = "bottle, can, box, plastic, cup, phone, battery, appliance, metal"


def _load_for_dino(image):
    """Feed an in-memory PIL image to GroundingDINO's path-based ``load_image``.

    Uses a private temporary file, removed afterwards, rather than a shared
    fixed path, so concurrent requests cannot overwrite each other's input.
    PNG keeps the upload lossless.

    Returns:
        Whatever ``load_image`` returns: ``(image_array, image_tensor)``.
    """
    fd, tmp_path = tempfile.mkstemp(suffix=".png")
    os.close(fd)
    try:
        image.save(tmp_path)
        return load_image(tmp_path)
    finally:
        # Best-effort cleanup; a leftover temp file must not fail the request.
        with contextlib.suppress(OSError):
            os.remove(tmp_path)


def _crop_box(image, box):
    """Crop one detection box out of ``image``.

    Assumes ``box`` is GroundingDINO's normalised ``(cx, cy, w, h)`` in
    [0, 1]. NOTE(review): this matches how groundingdino's annotate() scales
    boxes by the image size; confirm if the library changes.
    Falls back to the full image when the box is degenerate.
    """
    width, height = image.size
    cx, cy, bw, bh = (float(v) for v in box)
    left = max(0, int((cx - bw / 2) * width))
    top = max(0, int((cy - bh / 2) * height))
    right = min(width, int(round((cx + bw / 2) * width)))
    bottom = min(height, int(round((cy + bh / 2) * height)))
    if right <= left or bottom <= top:
        return image
    return image.crop((left, top, right, bottom))


def pipeline(image):
    """Detect items, look each up in the Tainan DB, else fall back to CLIP.

    Args:
        image: Uploaded photo from Gradio (numpy array or PIL image), or None
            when the button is clicked without an upload.

    Returns:
        ``(annotated_image, result_text)``: an RGB numpy image with detection
        boxes drawn (None if there was no input), and the per-object report.
    """
    if image is None:
        return None, "請先上傳一張照片。"

    # Gradio passes a numpy array; convert to PIL and force RGB so RGBA or
    # palette uploads are handled uniformly by saving, cropping and CLIP.
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image.astype("uint8"))
    image = image.convert("RGB")

    # (1) GroundingDINO object detection.
    img_np, img_tensor = _load_for_dino(image)
    boxes, logits, phrases = predict(
        model=dino_model,
        image=img_tensor,
        caption=DETECTION_CAPTION,
        box_threshold=0.3,
        text_threshold=0.25,
        device=device,
    )
    if len(phrases) == 0:
        return img_np, "未偵測到物品,請換一張更清楚的照片再試一次。"

    annotated = annotate(img_np, boxes, logits, phrases)
    # groundingdino's annotate() returns an OpenCV-order (BGR) array; flip it
    # to RGB so Gradio shows correct colours.
    annotated = np.ascontiguousarray(annotated[..., ::-1])

    # (2) Build the per-object report.
    sections = []
    for box, obj in zip(boxes, phrases):
        # (A) Prefer the official Tainan database entry.
        match = lookup_item(obj)
        if match is not None:
            sections.append(f"""
🧩 **物品:{_field(match, 'name')}**
📘 英文名稱:{_field(match, 'english_name')}
♻️ 回收指示:{_field(match, 'recyclable')}
📖 官方說明:{_field(match, 'notes')}
🌐 資料來源:台南市政府環保局
""")
            continue

        # (B) Not in the DB: classify this object's own crop with CLIP and
        # let the LLM explain the category.
        category, reason = classify_clip(_crop_box(image, box), obj.lower())
        sections.append(f"""
🧩 偵測到:{obj}
📦 分類推論:{category}
💡 理由:{reason}
""")

    return annotated, "".join(sections)


# =========================================================
# 6. Gradio UI
# =========================================================
with gr.Blocks() as demo:
    gr.Markdown("# 🗑️ AI 垃圾分類助手(含台南官方資料庫)")
    gr.Markdown("上傳照片 → 物件偵測 → 查台南市環保局資料庫 → 顯示回收建議")

    img_input = gr.Image(label="上傳垃圾照片")
    btn = gr.Button("開始分析 🚀", variant="primary")

    img_output = gr.Image(label="物件偵測(GroundingDINO)")
    text_output = gr.Textbox(label="分類結果", lines=10)

    btn.click(pipeline, inputs=[img_input], outputs=[img_output, text_output])

if __name__ == "__main__":
    demo.launch()