idkWhatToUse commited on
Commit
44fdde6
·
verified ·
1 Parent(s): 411a6ae

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +119 -160
app.py CHANGED
@@ -3,188 +3,147 @@ import json
3
  import torch
4
  from PIL import Image
5
  from sentence_transformers import SentenceTransformer, util
6
- from transformers import CLIPProcessor, CLIPModel, AutoTokenizer, AutoModelForCausalLM
 
 
 
7
 
8
- # ========== 1. 載入回收物品資料 ==========
9
- with open("recycle_data.json", "r", encoding="utf-8") as f:
10
- recycle_data = json.load(f)
 
11
 
12
- # 271 個 label 文本(給 CLIP 用)
13
  label_texts = []
14
- id_to_item = []
15
-
16
  for item in recycle_data:
17
- zh = item.get("name", "").strip()
18
- en = (item.get("english_name") or "").strip()
19
- if en:
20
- text = f"{en}, {zh}"
21
- else:
22
- text = zh
23
- label_texts.append(text)
24
- id_to_item.append(item)
25
-
26
- num_labels = len(label_texts)
27
- print(f"Loaded {num_labels} recycle labels")
28
-
29
- # ========== 2. 載入 Q&A 資料 ==========
30
- with open("qas.json", "r", encoding="utf-8") as f:
31
- qas = json.load(f)
32
-
33
  qa_questions = [q["question"] for q in qas]
34
 
35
  embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
36
  qa_embeddings = embedder.encode(qa_questions, convert_to_tensor=True)
37
 
38
- # ========== 3. 載入 CLIP 模型 (Zero-shot 圖片 → 文本) ==========
39
- clip_model_name = "openai/clip-vit-base-patch32"
40
- clip_model = CLIPModel.from_pretrained(clip_model_name)
41
- clip_processor = CLIPProcessor.from_pretrained(clip_model_name)
 
42
 
43
- # 預先把 label 文本 embed(可加速)
44
  with torch.no_grad():
45
- text_inputs = clip_processor(
46
- text=label_texts,
47
- images=None,
48
- return_tensors="pt",
49
- padding=True
 
50
  )
51
- text_embeds = clip_model.get_text_features(**{k: v for k, v in text_inputs.items() if k.startswith("input_ids") or k.startswith("attention_mask")})
52
  text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
53
 
54
- # ========== 4. (可選)小型 LLM 作 fallback ==========
55
- # Space 免費 CPU 撐不住,可以先註解掉這段,或回傳簡單文字
56
- llm_name = "microsoft/phi-2"
57
- tokenizer = AutoTokenizer.from_pretrained(llm_name)
58
- llm_model = AutoModelForCausalLM.from_pretrained(llm_name)
59
-
60
- def llm_fallback(query: str) -> str:
61
- prompt = f"你是一位垃圾分類助理,請用簡單中文回答以下問題,並遵守常見垃圾分類規則:{query}"
62
- inputs = tokenizer(prompt, return_tensors="pt")
63
- outputs = llm_model.generate(**inputs, max_new_tokens=120)
64
- return tokenizer.decode(outputs[0], skip_special_tokens=True)
65
-
66
- # ========== 5. 工具函式 ==========
67
- def classify_image_with_clip(pil_image: Image.Image):
68
- # images + text → CLIP 相似度
69
- inputs = clip_processor(
70
- text=None,
71
- images=pil_image,
72
- return_tensors="pt"
73
  )
 
74
 
 
 
 
 
 
75
  with torch.no_grad():
76
- image_embeds = clip_model.get_image_features(**inputs)
77
- image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
78
-
79
- # cosine similarity
80
- logits = image_embeds @ text_embeds.T # (1, num_labels)
81
  probs = logits.softmax(dim=-1)[0]
82
-
83
  score, idx = torch.max(probs, dim=-1)
84
- score = float(score.item())
85
- idx = int(idx.item())
86
- return idx, score
87
-
88
- def build_recycle_answer(item, score):
89
- name = item.get("name", "")
90
- en = item.get("english_name", "")
91
- notes = item.get("notes", "")
92
- rec = item.get("recyclable", "")
93
-
94
- header = f"🔍 我推測此物品最接近:**{name}**"
95
- if en:
96
- header += f"({en})"
97
- header += f"\n相似度:約 {score:.2f}\n\n"
98
-
99
- body = ""
100
- if rec:
101
- body += f"♻️ 是否可回收 / 類型:{rec}\n\n"
102
- if notes:
103
- body += f"📦 建議回收方式:\n{notes}\n"
104
- else:
105
- body += "目前沒有更詳細的回收說明,可依一般回收原則處理。"
106
-
107
- return header + body
108
-
109
- def generic_recycle_hint():
110
- return (
111
- "❓ 我無法自信地判斷這是資料庫中的哪一項物品。\n\n"
112
- "可以參考以下一般原則:\n"
113
- "1. 乾淨、可分離的紙類、塑膠、金屬、玻璃 → 多半可回收。\n"
114
- "2. 沾滿油污、混合多種材質又不易拆解 → 通常當一般垃圾。\n"
115
- "3. 電器、電池、燈管、農藥容器等 → 應交由清潔隊或指定回收點。\n"
116
- "4. 若不確定,建議詢問當地環保局或 1999 專線。"
117
- )
118
 
119
- def search_qa(query: str):
120
  q_emb = embedder.encode(query, convert_to_tensor=True)
121
  scores = util.cos_sim(q_emb, qa_embeddings)[0]
122
- best_idx = torch.argmax(scores).item()
123
- best_score = float(scores[best_idx].item())
124
-
125
- if best_score > 0.7:
126
- return qas[best_idx]["answer"]
127
- else:
128
- return None
129
-
130
- # ========== 6. 主助理邏輯 ==========
131
- def waste_assistant(user_text, image):
132
- #圖片的情況(可以同時搭配文字)
133
- if image is not None:
134
- pil_image = Image.fromarray(image)
135
- idx, score = classify_image_with_clip(pil_image)
136
-
137
- # threshold:判斷「是否在 271 類的合理範圍內」
138
- THRESH = 0.25
139
- if score >= THRESH:
140
- item = id_to_item[idx]
141
- ans = build_recycle_answer(item, score)
142
-
143
- # 如果還有文字問題,就順便試著回答
144
- if user_text:
145
- qa_ans = search_qa(user_text)
146
- if qa_ans:
147
- ans += "\n\n---\n\n📚 相關延伸說明:\n" + qa_ans
148
- else:
149
- # 補上一個簡單 LLM 回覆(可註解)
150
- extra = llm_fallback(user_text)
151
- ans += "\n\n---\n\n🤖 額外說明(模型推論):\n" + extra
152
-
153
- return ans
154
- else:
155
- # score 太低:可能不在 271 類中
156
- base = generic_recycle_hint()
157
- if user_text:
158
- # 若有問題,就用 LLM 回答問題內容
159
- extra = llm_fallback(user_text)
160
- base += "\n\n---\n\n🤖 根據你輸入的文字,這是模型的推論:\n" + extra
161
- return base
162
-
163
- # 純文字問答模式
164
- if user_text:
165
- qa_ans = search_qa(user_text)
166
- if qa_ans:
167
- return qa_ans
168
- # 找不到就交給 LLM 硬推
169
- return llm_fallback(user_text)
170
-
171
- return "請上傳圖片或輸入問題。"
172
-
173
- # ========== 7. Gradio 介面 ==========
174
- demo = gr.Interface(
175
- fn=waste_assistant,
176
- inputs=[
177
- gr.Textbox(label="輸入你的問題(可留空,只傳圖片)"),
178
- gr.Image(type="numpy", label="上傳垃圾 / 物品的照片")
179
- ],
180
- outputs=gr.Markdown(),
181
- title="台南垃圾分類智慧助理(CLIP + 271 類回收資料)",
182
- description=(
183
- "● 上傳圖片,我會幫你猜這是什麼,並從回收資料中找最接近的物品,提供回收方式。\n"
184
- "● 可以同時輸入文字,例如「這個要怎麼回收?」或「這個是可回收嗎?」\n"
185
- "● 也可以只輸入文字,查詢常見的垃圾分類 / 回收問答。\n"
186
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
187
  )
188
 
189
- if __name__ == "__main__":
190
- demo.launch()
 
3
  import torch
4
  from PIL import Image
5
  from sentence_transformers import SentenceTransformer, util
6
+ from transformers import (
7
+ CLIPProcessor, CLIPModel,
8
+ AutoTokenizer, AutoModelForCausalLM
9
+ )
10
 
11
+ # =======================================
12
+ # 1. Load recycle data
13
+ # =======================================
14
+ recycle_data = json.load(open("recycle_data.json", "r", encoding="utf-8"))
15
 
 
16
  label_texts = []
17
+ items = []
 
18
  for item in recycle_data:
19
+ zh = item.get("name", "")
20
+ en = item.get("english_name") or ""
21
+ label_texts.append(f"{en}, {zh}" if en else zh)
22
+ items.append(item)
23
+
24
+ # =======================================
25
+ # 2. Load Q&A data
26
+ # =======================================
27
+ qas = json.load(open("qas.json", "r", encoding="utf-8"))
 
 
 
 
 
 
 
28
  qa_questions = [q["question"] for q in qas]
29
 
30
  embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
31
  qa_embeddings = embedder.encode(qa_questions, convert_to_tensor=True)
32
 
33
+ # =======================================
34
+ # 3. Load CLIP for image → text similarity
35
+ # =======================================
36
+ clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
37
+ clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
38
 
 
39
  with torch.no_grad():
40
+ t_inputs = clip_processor(
41
+ text=label_texts, images=None, return_tensors="pt", padding=True
42
+ )
43
+ text_embeds = clip_model.get_text_features(
44
+ input_ids=t_inputs["input_ids"],
45
+ attention_mask=t_inputs["attention_mask"]
46
  )
 
47
  text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
48
 
49
+ # =======================================
50
+ # 4. SUPER-FAST Chat LLM (0.5B)
51
+ # =======================================
52
+ LLM_NAME = "Qwen/Qwen2.5-0.5B-Instruct"
53
+
54
+ tok = AutoTokenizer.from_pretrained(LLM_NAME)
55
+ llm = AutoModelForCausalLM.from_pretrained(
56
+ LLM_NAME,
57
+ torch_dtype=torch.float32,
58
+ device_map="cpu"
59
+ )
60
+
61
+ def llm_chat(prompt):
62
+ inputs = tok(prompt, return_tensors="pt")
63
+ outputs = llm.generate(
64
+ **inputs,
65
+ max_new_tokens=120,
66
+ temperature=0.4
 
67
  )
68
+ return tok.decode(outputs[0], skip_special_tokens=True)
69
 
70
+ # =======================================
71
+ # Helper functions
72
+ # =======================================
73
+ def classify_image(image):
74
+ inputs = clip_processor(images=image, return_tensors="pt")
75
  with torch.no_grad():
76
+ img_emb = clip_model.get_image_features(**inputs)
77
+ img_emb = img_emb / img_emb.norm(p=2, dim=-1, keepdim=True)
78
+ logits = img_emb @ text_embeds.T
 
 
79
  probs = logits.softmax(dim=-1)[0]
 
80
  score, idx = torch.max(probs, dim=-1)
81
+ return idx.item(), float(score.item())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
+ def search_qa(query):
84
  q_emb = embedder.encode(query, convert_to_tensor=True)
85
  scores = util.cos_sim(q_emb, qa_embeddings)[0]
86
+ idx = torch.argmax(scores).item()
87
+ if scores[idx] > 0.70:
88
+ return qas[idx]["answer"]
89
+ return None
90
+
91
+ def general_rules():
92
+ return (
93
+ "以下是一般垃圾分類原則:\n"
94
+ "1. 乾淨可分離材質 → 可回收。\n"
95
+ "2. 污損/混合材質不易拆 → 一般垃圾。\n"
96
+ "3. 電器、電池、害物 → 指定回收。\n"
97
+ "4. 不確定時 1999 或問清潔隊。\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
  )
99
+
100
+ # =======================================
101
+ # 5. Main Chatbot Logic
102
+ # =======================================
103
+ def chatbot(message, history):
104
+ image = None
105
+ if isinstance(message, dict) and "image" in message:
106
+ image = message["image"]
107
+ message = ""
108
+
109
+ final_answer = ""
110
+
111
+ # --- Image mode ---
112
+ if image:
113
+ pil = Image.fromarray(image)
114
+ idx, sim = classify_image(pil)
115
+
116
+ if sim >= 0.25:
117
+ item = items[idx]
118
+ final_answer += (
119
+ f"🔍 推測最接近:**{item['name']}**(相似度 {sim:.2f})\n\n"
120
+ f"♻️ {item.get('recyclable', '')}\n\n"
121
+ f"{item.get('notes', '')}\n\n"
122
+ )
123
+ else:
124
+ final_answer += (
125
+ "❓ 無法確定圖片屬於資料庫中的哪一項物品。\n\n" +
126
+ general_rules()
127
+ )
128
+
129
+ # --- Text mode ---
130
+ if message:
131
+ q_ans = search_qa(message)
132
+ if q_ans:
133
+ final_answer += f"📘 查到官方資料:\n{q_ans}\n"
134
+ else:
135
+ llm_ans = llm_chat(f"請以台灣垃圾分類規則回答問題:{message}")
136
+ final_answer += f"🤖 推論回答:\n{llm_ans}\n"
137
+
138
+ return final_answer or "請輸入問題或上傳圖片。"
139
+
140
+ # =======================================
141
+ # 6. Chat UI
142
+ # =======================================
143
+ chat_ui = gr.ChatInterface(
144
+ fn=chatbot,
145
+ title="垃圾分類聊天助理(CLIP × Qwen × 271 類)",
146
+ description="可上傳圖片,也可直接聊天。"
147
  )
148
 
149
+ chat_ui.launch()