idkWhatToUse committed on
Commit
bea780a
·
verified ·
1 Parent(s): d3d3727

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +107 -0
app.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torch.nn as nn
4
+ from torchvision import transforms
5
+ from PIL import Image
6
+ import json
7
+ from sentence_transformers import SentenceTransformer, util
8
+ from openai import OpenAI
9
+ import os
10
+
11
+ # ========== 1. Load Image Classification Model ==========
12
class MobileNetClassifier(nn.Module):
    """MobileNetV2 wrapper whose last classifier layer emits ``num_classes`` logits.

    The backbone is obtained through ``torch.hub`` from ``pytorch/vision:v0.10.0``
    with ``pretrained=False``; trained weights are expected to be loaded
    afterwards via ``load_state_dict``.

    Args:
        num_classes: Number of output logits produced by the replaced layer.
    """

    def __init__(self, num_classes=12):
        super().__init__()
        self.model = torch.hub.load("pytorch/vision:v0.10.0", "mobilenet_v2", pretrained=False)
        # Size the new head from the layer it replaces instead of hard-coding
        # the backbone's feature width.
        original_head = self.model.classifier[1]
        self.model.classifier[1] = nn.Linear(original_head.in_features, num_classes)

    def forward(self, x):
        """Run ``x`` through the wrapped network and return its output."""
        return self.model(x)
20
+
21
# Class names, indexed by the position of the classifier's output logit.
# NOTE(review): the order must match the label order used when
# mobilenet_trash.pth was trained — confirm against the training script.
labels = [
    "cardboard", "glass", "metal", "paper", "plastic",
    "trash", "battery", "shoes", "clothes", "green-glass",
    "brown-glass", "white-glass",
]

# Load model
device = "cpu"
# The head is sized from the label list so the two cannot drift apart.
model = MobileNetClassifier(num_classes=len(labels)).to(device)
model.load_state_dict(torch.load("mobilenet_trash.pth", map_location=device))
model.eval()

# Preprocessing applied to every image before inference.
# NOTE(review): no mean/std normalisation is applied — confirm this matches
# the preprocessing used during training.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])
37
+
38
# ========== 2. Load QAS + Recycle Database ==========
# Context managers close the JSON files as soon as they have been parsed.
with open("qas.json", "r", encoding="utf-8") as qas_file:
    qas = json.load(qas_file)
with open("recycle_data.json", "r", encoding="utf-8") as recycle_file:
    recycle = json.load(recycle_file)

# Index recycle entries by their "name" field for direct lookup.
recycle_dict = {item["name"]: item for item in recycle}

# Embedding Model
embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
# Question embeddings are computed once at start-up and reused for every query.
qas_questions = [item["question"] for item in qas]
qas_embeddings = embedder.encode(qas_questions, convert_to_tensor=True)
47
+
48
+ # ========== 3. Helper Functions ==========
49
def classify_image(image):
    """Return the label whose logit is highest for ``image``.

    Args:
        image: A PIL image. It is converted to RGB first so that grayscale or
            RGBA inputs still yield the three channels the network consumes.

    Returns:
        The entry of ``labels`` at the index of the largest model output.
    """
    batch = transform(image.convert("RGB")).unsqueeze(0)  # add batch dimension
    with torch.no_grad():
        logits = model(batch)[0]
    best_index = torch.argmax(logits).item()
    return labels[best_index]
55
+
56
def search_recycle(label):
    """Return a user-facing recycling message for ``label``.

    Args:
        label: Name to look up in ``recycle_dict``.

    Returns:
        The entry's ``notes`` formatted as recycling instructions, or a
        generic fallback message when the label is unknown or the entry has
        no notes.
    """
    item = recycle_dict.get(label)
    notes = item.get("notes") if item else None
    if not notes:
        return f"⚠ 找不到「{label}」的回收資料,請依一般原則:可分離材質可回收。"
    return f"♻ {label}\n\n回收方式:{notes}"
61
+
62
def rag_question(query, threshold=0.7):
    """Return the stored answer whose question is most similar to ``query``.

    Args:
        query: Free-text user question.
        threshold: Minimum cosine similarity (exclusive) the best match must
            exceed for its answer to be returned.

    Returns:
        The matching entry's ``answer``, or ``None`` when the query is blank,
        no QA entries are loaded, or the best score does not exceed
        ``threshold``.
    """
    # Nothing meaningful to embed; skip the encoder entirely.
    if not query or not query.strip():
        return None
    # argmax over an empty score vector would raise.
    if not qas:
        return None

    query_embedding = embedder.encode(query, convert_to_tensor=True)
    scores = util.cos_sim(query_embedding, qas_embeddings)[0]
    best_index = torch.argmax(scores).item()
    best_score = scores[best_index].item()

    if best_score > threshold:
        return qas[best_index]["answer"]
    return None
71
+
72
# Reads the key from the OPENAI_API_KEY environment variable; a missing
# variable raises KeyError when this module is loaded.
client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])


def llm_fallback(query):
    """Ask the chat model to answer ``query`` as a Taiwan waste-sorting assistant.

    Args:
        query: The user's question, embedded into the prompt text.

    Returns:
        The text of the first completion choice, or an empty string when the
        response carries no content.
    """
    msg = f"你是台灣垃圾分類助理,請回答:{query}"
    resp = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": msg}],
    )
    # The openai>=1.0 client returns typed objects, so the message is not
    # subscriptable; the text lives on the ``content`` attribute.
    return resp.choices[0].message.content or ""
81
+
82
+ # ========== 4. Master Function ==========
83
def assistant(text, image):
    """Route a request to image classification or to question answering.

    An uploaded image takes priority over any text. Text questions are first
    matched against the stored QA set and fall back to the LLM when no stored
    answer is returned.

    Args:
        text: The user's question; may be empty or ``None``.
        image: The uploaded image as an array accepted by
            ``Image.fromarray``, or ``None`` when nothing was uploaded.

    Returns:
        The text to display to the user.
    """
    if image is not None:
        label = classify_image(Image.fromarray(image))
        result = search_recycle(label)
        return f"🔍 辨識結果:{label}\n\n{result}"

    # Treat None and whitespace-only input as "no question" so it never
    # reaches retrieval or the LLM call.
    question = (text or "").strip()
    if question:
        rag_answer = rag_question(question)
        if rag_answer:
            return rag_answer
        return llm_fallback(question)

    return "請輸入問題或上傳圖片。"
97
+
98
# ========== 5. Gradio UI ==========
# The two inputs are passed positionally to assistant(text, image); the image
# component is configured to deliver a numpy array.
ui = gr.Interface(
    fn=assistant,
    inputs=[gr.Textbox(label="輸入問題"), gr.Image(type="numpy")],
    outputs="text",
    title="垃圾分類智慧助理",
    description="上傳圖片或輸入問題,協助你判斷台灣垃圾分類方式",
)

ui.launch()