pixel3user commited on
Commit
7a99397
·
1 Parent(s): 2cc645f

added json

Browse files
Files changed (1) hide show
  1. app.py +103 -115
app.py CHANGED
@@ -70,7 +70,7 @@ def format_candidates_for_llm(cands, budget_twd=None):
70
  "image_url": c.get("image_url"),
71
  "score": c.get("score"),
72
  })
73
- return json.dumps(filtered, ensure_ascii=False, indent=2)
74
 
75
  DERMA_SAFETY = (
76
  "Safety notes: For broken/infected skin, pregnancy/lactation, infants, "
@@ -82,7 +82,7 @@ def recommend_products(query_text: str, budget_twd: int | None = None, k: int =
82
  cands = product_search(query_text, k=k)
83
 
84
  # 2) Build short grounded context
85
- context = format_candidates_for_llm(cands, budget_twd=budget_twd)
86
 
87
  # 3) Ask your LLM to pick & explain (plug into your existing generation path)
88
  system = (
@@ -90,119 +90,34 @@ def recommend_products(query_text: str, budget_twd: int | None = None, k: int =
90
  "from the provided list. Include a one-line why-it-helps and a brief how-to-use. "
91
  "Respect budget and do not invent products."
92
  )
93
- user = f"User need: {query_text}\nCandidate products (JSON array):\n{context}\n{DERMA_SAFETY}"
94
 
95
  # --- if you already have Qwen2-VL loaded as text generator, reuse it.
96
  # Example skeleton (pseudo—replace with your app’s generate() function):
97
  try:
98
  # Replace this with your actual text-generation helper:
99
- answer = f"(LLM picks here)\n\nContext:\n{context}"
100
  except Exception as e:
101
- answer = f"❌ Generation error: {e}\n\nHere are candidates:\n{context}"
102
 
103
  return answer
104
 
105
 
106
- def _count_products_from_tagged_json(text: str) -> int | None:
107
- start_tag = "<DERMACARE_PRODUCTS_JSON>"
108
- end_tag = "</DERMACARE_PRODUCTS_JSON>"
109
- start = text.find(start_tag)
110
- end = text.find(end_tag)
111
- if start == -1 or end == -1 or end <= start:
112
- return None
113
- json_str = text[start + len(start_tag):end]
114
- try:
115
- payload = json.loads(json_str)
116
- products = payload.get("products", [])
117
- if isinstance(products, list):
118
- return len(products)
119
- except Exception:
120
  return None
121
- return None
122
-
123
-
124
- def _has_valid_suggestions(raw_suggestions: str) -> bool:
125
- cleaned = raw_suggestions.strip()
126
- if not cleaned:
127
- return False
128
-
129
- lower = cleaned.lower()
130
- normalized = " ".join(lower.split())
131
- if "no relevant products" in normalized:
132
- return False
133
-
134
- count = _count_products_from_tagged_json(cleaned)
135
- if count is not None:
136
- return count > 0
137
-
138
- return True
139
-
140
- # ---- JSON block helpers ----
141
- def _extract_products_json_block(text: str) -> str | None:
142
- start_tag = "<DERMACARE_PRODUCTS_JSON>"
143
- end_tag = "</DERMACARE_PRODUCTS_JSON>"
144
- start = text.find(start_tag)
145
- end = text.find(end_tag)
146
  if start == -1 or end == -1 or end <= start:
147
  return None
148
- json_str = text[start + len(start_tag):end]
149
  try:
150
- payload = json.loads(json_str)
151
- products = payload.get("products", [])
152
- if isinstance(products, list) and len(products) > 0:
153
- return f"{start_tag}{json.dumps(payload, ensure_ascii=False)}{end_tag}"
154
  except Exception:
155
  return None
156
- return None
157
-
158
-
159
- def _filter_candidates_by_budget(cands: list[dict], budget_twd: int | None) -> list[dict]:
160
- if not budget_twd:
161
- return cands
162
- filtered: list[dict] = []
163
- for c in cands:
164
- currency = c.get("price_currency")
165
- value = c.get("price_value")
166
- if currency == "TWD" and isinstance(value, (int, float)) and value is not None:
167
- if value <= budget_twd:
168
- filtered.append(c)
169
- else:
170
- filtered.append(c)
171
- return filtered
172
-
173
-
174
- def _build_products_json_block_from_candidates(cands: list[dict], max_items: int = 3) -> str | None:
175
- if not cands:
176
- return None
177
- ranked = sorted(cands, key=lambda x: x.get("_score", 0.0), reverse=True)
178
- picked = ranked[:max_items]
179
- products = []
180
- for c in picked:
181
- products.append({
182
- "id": c.get("id"),
183
- "brand": c.get("brand_en") or c.get("brand_zh"),
184
- "name": c.get("product_name_en") or c.get("product_name_zh"),
185
- "category": c.get("category_en") or c.get("category_zh"),
186
- "price_value": c.get("price_value"),
187
- "price_currency": c.get("price_currency"),
188
- "why": None,
189
- "how": None,
190
- "url": c.get("source_url"),
191
- "image_url": c.get("image_url"),
192
- })
193
- valid_count = sum(1 for p in products if p.get("name"))
194
- if valid_count == 0:
195
- return None
196
- payload = {"version": 1, "products": products}
197
- return f"<DERMACARE_PRODUCTS_JSON>{json.dumps(payload, ensure_ascii=False)}</DERMACARE_PRODUCTS_JSON>"
198
-
199
-
200
- def _ensure_products_json_block(suggestions: str, cands: list[dict], budget_twd: int | None) -> str | None:
201
- existing = _extract_products_json_block(suggestions)
202
- if existing:
203
- return existing
204
- filtered = _filter_candidates_by_budget(cands, budget_twd)
205
- return _build_products_json_block_from_candidates(filtered)
206
 
207
  # ---- Inference on GPU (ZeroGPU pattern) ----
208
  @spaces.GPU(duration=120)
@@ -311,25 +226,98 @@ def pet_answer_with_recs(image, question, temperature=0.7, top_p=0.95, max_token
311
  top_p=0.95,
312
  )
313
  trimmed = [o[len(i):] for i, o in zip(inputs["input_ids"], out)]
314
- suggestions = processor.batch_decode(trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
 
 
 
 
315
 
316
- # Final combined message
317
- safety = (
318
- "Safety notes: For broken/infected skin, pregnancy/lactation, infants, "
319
- "or if symptoms worsen—seek a qualified dermatologist. Patch-test first."
320
- )
321
- suggestions = suggestions.strip()
322
- include_products = _has_valid_suggestions(suggestions)
323
 
324
  sections = [base.strip()]
325
- if include_products:
326
- sections.append(f"Suggested products:\n{suggestions}")
327
- if "<DERMACARE_PRODUCTS_JSON>" not in suggestions:
328
- json_block = _ensure_products_json_block(suggestions, cands, budget_twd)
329
- if json_block:
330
- sections.append(json_block)
331
- sections.append(safety)
332
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
333
  return "\n\n".join([s for s in sections if s])
334
 
335
  # ---- UI ----
 
70
  "image_url": c.get("image_url"),
71
  "score": c.get("score"),
72
  })
73
+ return json.dumps(filtered, ensure_ascii=False, indent=2), filtered
74
 
75
  DERMA_SAFETY = (
76
  "Safety notes: For broken/infected skin, pregnancy/lactation, infants, "
 
82
  cands = product_search(query_text, k=k)
83
 
84
  # 2) Build short grounded context
85
+ context_json, _ = format_candidates_for_llm(cands, budget_twd=budget_twd)
86
 
87
  # 3) Ask your LLM to pick & explain (plug into your existing generation path)
88
  system = (
 
90
  "from the provided list. Include a one-line why-it-helps and a brief how-to-use. "
91
  "Respect budget and do not invent products."
92
  )
93
+ user = f"User need: {query_text}\nCandidate products (JSON array):\n{context_json}\n{DERMA_SAFETY}"
94
 
95
  # --- if you already have Qwen2-VL loaded as text generator, reuse it.
96
  # Example skeleton (pseudo—replace with your app’s generate() function):
97
  try:
98
  # Replace this with your actual text-generation helper:
99
+ answer = f"(LLM picks here)\n\nContext:\n{context_json}"
100
  except Exception as e:
101
+ answer = f"❌ Generation error: {e}\n\nHere are candidates:\n{context_json}"
102
 
103
  return answer
104
 
105
 
106
+ def _parse_recommendation_json(raw: str):
107
+ if not raw:
 
 
 
 
 
 
 
 
 
 
 
 
108
  return None
109
+ cleaned = raw.strip()
110
+ if cleaned.startswith("```"):
111
+ lines = [line for line in cleaned.splitlines() if not line.strip().startswith("```")]
112
+ cleaned = "\n".join(lines)
113
+ start = cleaned.find('{')
114
+ end = cleaned.rfind('}')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
  if start == -1 or end == -1 or end <= start:
116
  return None
 
117
  try:
118
+ return json.loads(cleaned[start:end + 1])
 
 
 
119
  except Exception:
120
  return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
 
122
  # ---- Inference on GPU (ZeroGPU pattern) ----
123
  @spaces.GPU(duration=120)
 
226
  top_p=0.95,
227
  )
228
  trimmed = [o[len(i):] for i, o in zip(inputs["input_ids"], out)]
229
+ raw_response = processor.batch_decode(
230
+ trimmed,
231
+ skip_special_tokens=True,
232
+ clean_up_tokenization_spaces=False,
233
+ )[0]
234
 
235
+ rec_data = _parse_recommendation_json(raw_response)
 
 
 
 
 
 
236
 
237
  sections = [base.strip()]
238
+ suggestion_text = None
239
+ product_json_payload = None
240
+
241
+ if rec_data:
242
+ recommend_flag = rec_data.get("recommend")
243
+ if isinstance(recommend_flag, str):
244
+ recommend_flag = recommend_flag.strip().lower() in {"yes", "true", "1"}
245
+ elif isinstance(recommend_flag, (int, float)):
246
+ recommend_flag = bool(recommend_flag)
247
+
248
+ recs = []
249
+ for item in rec_data.get("recommendations", []):
250
+ if isinstance(item, dict) and item.get("id"):
251
+ recs.append(item)
252
+
253
+ if recommend_flag and recs:
254
+ suggestion_lines = ["### Suggested Products", ""]
255
+ products_payload = []
256
+
257
+ for idx, rec in enumerate(recs[:3], start=1):
258
+ pid = rec.get("id")
259
+ candidate = candidate_lookup.get(pid, {})
260
+
261
+ brand = (
262
+ candidate.get("brand_en")
263
+ or candidate.get("brand_zh")
264
+ or rec.get("brand")
265
+ or ""
266
+ )
267
+ name = (
268
+ candidate.get("product_name_en")
269
+ or candidate.get("product_name_zh")
270
+ or rec.get("name")
271
+ or f"Product {idx}"
272
+ )
273
+ category = (
274
+ candidate.get("category_en")
275
+ or candidate.get("category_zh")
276
+ or rec.get("category")
277
+ or None
278
+ )
279
+ price_value = candidate.get("price_value")
280
+ price_currency = candidate.get("price_currency")
281
+ why = rec.get("why") or "Supports the user’s concern."
282
+ how = rec.get("how") or "Use as directed on the product label."
283
+ url = candidate.get("source_url") or rec.get("url")
284
+ image_url = candidate.get("image_url") or rec.get("image_url")
285
+
286
+ suggestion_lines.extend([
287
+ f"{idx}. **{name}**",
288
+ f"- **Why it helps:** {why}",
289
+ f"- **How to use:** {how}",
290
+ "",
291
+ ])
292
+
293
+ products_payload.append({
294
+ "id": pid,
295
+ "brand": brand,
296
+ "name": name,
297
+ "category": category,
298
+ "price_value": price_value,
299
+ "price_currency": price_currency,
300
+ "why": why,
301
+ "how": how,
302
+ "url": url,
303
+ "image_url": image_url,
304
+ })
305
+
306
+ if products_payload:
307
+ suggestion_text = "\n".join(suggestion_lines).strip()
308
+ product_json_payload = json.dumps(
309
+ {"version": 1, "products": products_payload},
310
+ ensure_ascii=False,
311
+ )
312
+
313
+ if suggestion_text and product_json_payload:
314
+ sections.append(
315
+ "Suggested products:\n"
316
+ f"{suggestion_text}\n\n"
317
+ f"<DERMACARE_PRODUCTS_JSON>{product_json_payload}</DERMACARE_PRODUCTS_JSON>"
318
+ )
319
+ sections.append(DERMA_SAFETY)
320
+
321
  return "\n\n".join([s for s in sections if s])
322
 
323
  # ---- UI ----