File size: 15,640 Bytes
c6d6bec
5e22a9c
c6d6bec
c78b365
dad83c8
9c651dd
5e22a9c
c78b365
3e46c4a
13a59c9
5e22a9c
13a59c9
5e22a9c
3e46c4a
a1e42ba
 
c78b365
e39687f
5e22a9c
13a59c9
5e22a9c
c78b365
 
9c651dd
c6d6bec
5e22a9c
13a59c9
5789e4d
5e22a9c
 
e39687f
9c651dd
c78b365
 
 
 
5789e4d
 
5e22a9c
e39687f
 
 
 
c78b365
e39687f
c6d6bec
5e22a9c
c78b365
3e46c4a
4b3e157
3e46c4a
4b3e157
 
 
 
 
 
3e46c4a
4b3e157
 
 
 
 
 
 
 
 
 
 
 
 
 
7a99397
3e46c4a
 
 
 
 
 
 
 
 
 
 
7a99397
3e46c4a
 
 
 
 
 
 
7a99397
3e46c4a
 
 
 
 
7a99397
3e46c4a
7a99397
3e46c4a
 
 
b340e7a
7a99397
 
b340e7a
7a99397
 
 
 
 
 
2cc645f
 
 
7a99397
2cc645f
 
 
103cddb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5e22a9c
9c651dd
5e22a9c
 
 
 
c78b365
13a59c9
 
c6d6bec
c78b365
 
 
 
5e22a9c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c78b365
13a59c9
5e22a9c
9c651dd
 
 
 
 
 
5e22a9c
 
 
c6d6bec
21f5009
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103cddb
 
 
 
21f5009
 
 
 
 
 
 
 
103cddb
 
 
 
21f5009
 
 
 
 
103cddb
21f5009
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7a99397
 
 
 
 
21f5009
7a99397
b340e7a
 
7a99397
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b340e7a
21f5009
5e22a9c
3e46c4a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21f5009
 
 
 
 
c6d6bec
b340e7a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
"""
    PetBull‑7B‑VL demo – ZeroGPU‑ready (Qwen2.5‑VL API)
"""
import os
import json
import spaces
import torch
import gradio as gr
import transformers, accelerate, numpy as np
from PIL import Image
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
from peft import PeftModel
from qwen_vl_utils import process_vision_info  # pip install qwen-vl-utils
from tools.retriever import search as product_search

# Log the versions of the core libraries once at import time so that
# environment problems can be diagnosed from the startup output.
print("VERSIONS:", transformers.__version__, accelerate.__version__, torch.__version__, np.__version__)
# NOTE(review): `accelerate` has already been imported above when this variable
# is set; confirm that it is read lazily, otherwise it has no effect here.
os.environ["ACCELERATE_USE_SLOW_RETRIEVAL"] = "true"

# ---- Config ----
BASE_MODEL   = "Qwen/Qwen2.5-VL-7B-Instruct"  # base checkpoint for processor and model
ADAPTER_REPO = "ColdSlim/PetBull-7B"  # your LoRA
ADAPTER_REV  = "master"               # revision of the adapter repo to load
OFFLOAD_DIR  = "offload"              # passed as `offload_folder` when loading the base model
DTYPE        = torch.float16          # dtype the base weights are loaded in

# ---- Processor (no GPU) ----
# Used below for chat templating, input packing and decoding.
processor = AutoProcessor.from_pretrained(BASE_MODEL, trust_remote_code=True)

# ---- Base model ON CPU (do NOT touch CUDA here) ----
# Everything is mapped to the CPU at import time; the move to CUDA happens
# lazily inside `generate_answer`.
base = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    BASE_MODEL,
    torch_dtype=DTYPE,
    low_cpu_mem_usage=True,
    device_map={"": "cpu"},
    offload_folder=OFFLOAD_DIR,
    trust_remote_code=True,
)

# ---- Attach LoRA ON CPU ----
# Wrap the base model with the adapter weights and switch to eval mode.
model = PeftModel.from_pretrained(
    base,
    ADAPTER_REPO,
    revision=ADAPTER_REV,
    device_map={"": "cpu"},
).eval()

# Flipped to True by `generate_answer` after its first `model.to("cuda")`.
_model_on_gpu = False  # once-per-session move

def format_candidates_for_llm(cands, budget_twd=None):
    """Trim retrieved products to a fixed field set and serialize them.

    A candidate is dropped only when a truthy ``budget_twd`` is given, its
    ``price_currency`` is ``"TWD"``, and its truthy ``price_value`` is
    greater than the budget. Candidates in other currencies or without a
    price are always kept.

    Args:
        cands: Iterable of candidate dicts.
        budget_twd: Optional upper price bound in TWD; falsy disables
            the filter.

    Returns:
        A ``(json_text, rows)`` tuple where ``rows`` is the list of trimmed
        dicts and ``json_text`` is that list dumped as indented JSON with
        non-ASCII characters left unescaped.
    """
    exported_fields = (
        "id",
        "brand_en",
        "brand_zh",
        "product_name_en",
        "product_name_zh",
        "category_en",
        "category_zh",
        "price_value",
        "price_currency",
        "source_url",
        "image_url",
        "score",
    )

    def exceeds_budget(candidate):
        # Without a budget nothing can exceed it.
        if not budget_twd:
            return False
        # Only TWD prices are comparable with a TWD budget.
        if candidate.get("price_currency") != "TWD":
            return False
        price = candidate.get("price_value")
        return bool(price) and price > budget_twd

    rows = [
        {field: candidate.get(field) for field in exported_fields}
        for candidate in cands
        if not exceeds_budget(candidate)
    ]
    return json.dumps(rows, ensure_ascii=False, indent=2), rows

# Fixed safety disclaimer. It is embedded in the user prompt built by
# `recommend_products` and appended to the final answer by
# `pet_answer_with_recs` whenever product suggestions are included.
DERMA_SAFETY = (
    "Safety notes: For broken/infected skin, pregnancy/lactation, infants, "
    "or if symptoms worsen—seek a qualified dermatologist. Patch-test first."
)

def recommend_products(query_text: str, budget_twd: int | None = None, k: int = 8):
    """Retrieve product candidates for a query and return a placeholder answer.

    This is a skeleton: the system and user prompts are assembled but no
    model is called yet, so the returned text is a fixed marker followed by
    the JSON candidate context.

    Args:
        query_text: The user's free-text need, forwarded to the retriever.
        budget_twd: Optional TWD budget forwarded to the candidate filter.
        k: Number of candidates requested from the retriever.

    Returns:
        The placeholder answer string containing the candidate JSON.
    """
    # Retrieval followed by trimming to a compact, grounded JSON context.
    hits = product_search(query_text, k=k)
    context_json = format_candidates_for_llm(hits, budget_twd=budget_twd)[0]

    # Prompts prepared for the future generation call (currently unused).
    system = (
        "You are DermalCare’s assistant. Recommend up to 3 products strictly "
        "from the provided list. Include a one-line why-it-helps and a brief how-to-use. "
        "Respect budget and do not invent products."
    )
    user = f"User need: {query_text}\nCandidate products (JSON array):\n{context_json}\n{DERMA_SAFETY}"

    # Placeholder for the real text-generation helper; the error branch
    # mirrors the message a failed generation is meant to produce.
    try:
        return f"(LLM picks here)\n\nContext:\n{context_json}"
    except Exception as err:
        return f"❌ Generation error: {err}\n\nHere are candidates:\n{context_json}"


def _parse_recommendation_json(raw: str):
    """Extract the first JSON object embedded in a model reply.

    Markdown code-fence lines are removed when the reply starts with a
    fence. Decoding then begins at the first ``{`` and stops at the end of
    that object, so explanatory text after the JSON (even text containing
    further braces) no longer makes the whole reply unparseable.

    Args:
        raw: Raw text produced by the model; may be empty or ``None``.

    Returns:
        The decoded ``dict``, or ``None`` when the reply is empty, contains
        no ``{``, or the text starting at the first ``{`` is not valid JSON.
    """
    if not raw:
        return None

    cleaned = raw.strip()
    if cleaned.startswith("```"):
        cleaned = "\n".join(
            line
            for line in cleaned.splitlines()
            if not line.strip().startswith("```")
        )

    start = cleaned.find("{")
    if start == -1:
        return None

    try:
        # raw_decode parses exactly one JSON value beginning at `start` and
        # ignores whatever follows it.
        parsed, _ = json.JSONDecoder().raw_decode(cleaned, start)
    except (ValueError, RecursionError):
        # Malformed or pathologically nested output is treated as "no data".
        return None
    return parsed


def _build_recommendation_sections(rec_data, candidate_lookup):
    """Turn parsed recommendation JSON into display text and a product payload.

    Entries are validated first and only then capped at three, so invalid
    entries neither consume one of the three slots nor leave gaps in the
    numbering of the rendered list.

    Args:
        rec_data: Parsed model output; expected to be a dict with a
            ``recommend`` flag and a ``recommendations`` list of dicts that
            carry at least an ``id``.
        candidate_lookup: Mapping from stripped string product id to the
            candidate dict. Fields found here take precedence over fields
            supplied in the recommendation entry itself.

    Returns:
        ``(suggestion_text, product_json_payload)``: a Markdown-style list
        and a JSON string of the form ``{"version": 1, "products": [...]}``.
        ``(None, None)`` is returned when ``rec_data`` is missing or not a
        dict, the flag is falsy, ``recommendations`` is not a list, or no
        entry has a usable id.
    """
    if not rec_data or not isinstance(rec_data, dict):
        return None, None

    # The flag may arrive as a bool, a number, or a string such as "yes".
    recommend_flag = rec_data.get("recommend")
    if isinstance(recommend_flag, str):
        recommend_flag = recommend_flag.strip().lower() in {"yes", "true", "1"}
    elif isinstance(recommend_flag, (int, float)):
        recommend_flag = bool(recommend_flag)
    if not recommend_flag:
        return None, None

    recommendations = rec_data.get("recommendations", [])
    if not isinstance(recommendations, list):
        return None, None

    lookup = candidate_lookup or {}

    # Collect up to three usable (id, entry) pairs, skipping invalid entries.
    usable = []
    for item in recommendations:
        if not isinstance(item, dict):
            continue
        raw_id = item.get("id")
        if raw_id is None:
            continue
        pid = str(raw_id).strip()  # same normalization as the lookup keys
        if not pid:
            continue
        usable.append((pid, item))
        if len(usable) == 3:
            break

    if not usable:
        return None, None

    lines = ["### Suggested Products", ""]
    products_payload = []

    for idx, (pid, item) in enumerate(usable, start=1):
        candidate = lookup.get(pid, {})

        brand = (
            candidate.get("brand_en")
            or candidate.get("brand_zh")
            or item.get("brand")
            or ""
        )
        name = (
            candidate.get("product_name_en")
            or candidate.get("product_name_zh")
            or item.get("name")
            or f"Product {idx}"
        )
        category = (
            candidate.get("category_en")
            or candidate.get("category_zh")
            or item.get("category")
            or None
        )
        why = item.get("why") or "Supports the user’s concern."
        how = item.get("how") or "Use as directed on the product label."

        lines.extend([
            f"{idx}. **{name}**",
            f"- **Why it helps:** {why}",
            f"- **How to use:** {how}",
            "",
        ])

        products_payload.append({
            "id": pid,
            "brand": brand,
            "name": name,
            "category": category,
            "price_value": candidate.get("price_value"),
            "price_currency": candidate.get("price_currency"),
            "why": why,
            "how": how,
            "url": candidate.get("source_url") or item.get("url"),
            "image_url": candidate.get("image_url") or item.get("image_url"),
        })

    suggestion_text = "\n".join(lines).strip()
    product_json_payload = json.dumps(
        {"version": 1, "products": products_payload},
        ensure_ascii=False,
    )
    return suggestion_text, product_json_payload

# ---- Inference on GPU (ZeroGPU pattern) ----
@spaces.GPU(duration=120)
def generate_answer(image, question, temperature=0.7, top_p=0.95, max_tokens=256):
    """Answer a question about an image with the LoRA-adapted Qwen2.5-VL model.

    The request is formatted with the processor's chat template and
    ``process_vision_info``, run through ``model.generate`` on CUDA, and the
    prompt tokens are cut off before decoding.

    Args:
        image: PIL image, or ``None`` to use a blank white 224x224 image.
        question: User text; a falsy value is replaced by
            ``"Describe this image."``.
        temperature: Sampling temperature. A value that is ``None`` or not
            greater than zero switches to greedy decoding.
        top_p: Nucleus-sampling threshold, used only when sampling.
        max_tokens: Maximum number of new tokens; coerced to ``int``.

    Returns:
        The decoded model reply as a string.
    """
    global _model_on_gpu
    if image is None:
        image = Image.new("RGB", (224, 224), color="white")

    # The model is loaded on the CPU at import time; move it once, lazily.
    if not _model_on_gpu:
        model.to("cuda")
        _model_on_gpu = True

    # Build chat messages in Qwen format
    messages = [{
        "role": "user",
        "content": [
            {"type": "image", "image": image},
            {"type": "text",  "text": question or "Describe this image."},
        ],
    }]

    # Processor helpers
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)

    # Pack tensors on GPU
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    inputs = {k: (v.to("cuda") if hasattr(v, "to") else v) for k, v in inputs.items()}

    # Request sampling explicitly so temperature/top_p do not depend on the
    # checkpoint's generation-config default, and fall back to greedy
    # decoding when the temperature cannot be used for sampling.
    sampling = temperature is not None and temperature > 0
    gen_kwargs = {
        "max_new_tokens": int(max_tokens),  # UI sliders may deliver floats
        "do_sample": sampling,
    }
    if sampling:
        gen_kwargs["temperature"] = temperature
        gen_kwargs["top_p"] = top_p

    with torch.no_grad():
        out = model.generate(**inputs, **gen_kwargs)

    # Trim prompt tokens before decode (Qwen style)
    trimmed = [o[len(i):] for i, o in zip(inputs["input_ids"], out)]
    return processor.batch_decode(trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]

# ---- PetCare answer + product suggestions (ONE output) ----
@spaces.GPU(duration=120)
def pet_answer_with_recs(image, question, temperature=0.7, top_p=0.95, max_tokens=256, budget_twd=None):
    """Return the PetBull answer, optionally followed by product suggestions.

    Steps:
      1) Get the normal image + text answer from ``generate_answer``.
      2) Run the product search on the user's question and trim the hits.
      3) Ask the same model, in text-only mode, for a JSON verdict on which
         candidates (if any) are relevant.
      4) When products are recommended, append a suggestions section, a
         machine-readable ``<DERMACARE_PRODUCTS_JSON>`` block and the safety
         notes. When nothing is recommended, the answer is returned as is.

    Retrieval and the second generation are skipped when the question is
    empty or the search yields no candidates, because there is then nothing
    to ground a recommendation on.

    Args:
        image: Optional PIL image forwarded to ``generate_answer``.
        question: User text; also used as the product-search query.
        temperature, top_p, max_tokens: Forwarded to ``generate_answer``.
        budget_twd: Optional TWD budget forwarded to the candidate filter.

    Returns:
        A single string whose sections are separated by blank lines.
    """
    # Step 1: normal PetBull answer. (Named `answer_text` rather than `base`
    # so the module-level base model is not shadowed.)
    answer_text = generate_answer(image, question, temperature, top_p, max_tokens).strip()

    if not (question or "").strip():
        return answer_text

    # Step 2: retrieve product candidates (humans/skincare; model will decide relevance)
    cands = product_search(question, k=8)
    cand_block_json, cand_list = format_candidates_for_llm(cands, budget_twd=budget_twd)
    if not cand_list:
        return answer_text
    # Keys are normalized to stripped strings; the section builder applies
    # the same normalization to ids coming back from the model.
    candidate_lookup = {
        str(c.get("id")).strip(): c for c in cand_list if c.get("id") is not None
    }

    # Step 3: build a small, text-only prompt for suggestions
    # IMPORTANT: we use the same Qwen2.5-VL model in text mode
    messages = [{
        "role": "user",
        "content": [
            {"type": "text", "text":
             "You are DermalCare's assistant.\n"
             "Respond ONLY with valid JSON (no markdown, no explanations).\n"
             "Expected schema: {\"recommend\": bool, \"recommendations\": [ {\"id\": str, \"why\": str, \"how\": str } ], \"notes\": str }.\n"
             "Use candidate_products as the exclusive source of items. If a product is recommended, its id must exist in candidate_products.\n"
             "If no products are relevant, return {\"recommend\": false, \"recommendations\": [], \"notes\": \"No relevant products.\"}."}
        ]
    },{
        "role": "user",
        "content": [
            {"type": "text", "text": f"User message:\n{question}"},
            {"type": "text", "text": f"candidate_products = {cand_block_json}"}
        ]
    }]

    # Prepare inputs on GPU (text-only). The model itself was already moved
    # to CUDA by the generate_answer call above.
    text_prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = processor(
        text=[text_prompt], images=None, videos=None, padding=True, return_tensors="pt",
    )
    inputs = {k: (v.to("cuda") if hasattr(v, "to") else v) for k, v in inputs.items()}

    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=200,
            do_sample=True,       # make the low-temperature sampling explicit
            temperature=0.2,      # keep precise/grounded
            top_p=0.95,
        )
    trimmed = [o[len(i):] for i, o in zip(inputs["input_ids"], out)]
    raw_response = processor.batch_decode(
        trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0]

    # Step 4: parse the verdict and render it through the shared helper
    # instead of duplicating its logic here.
    rec_data = _parse_recommendation_json(raw_response)
    suggestion_text, product_json_payload = _build_recommendation_sections(
        rec_data, candidate_lookup
    )

    sections = [answer_text]
    if suggestion_text and product_json_payload:
        sections.append(
            "Suggested products:\n"
            f"{suggestion_text}\n\n"
            f"<DERMACARE_PRODUCTS_JSON>{product_json_payload}</DERMACARE_PRODUCTS_JSON>"
        )
        sections.append(DERMA_SAFETY)

    return "\n\n".join([s for s in sections if s])

# ---- UI ----
# Build the Gradio interface: one "Pet Care" tab with the inputs and
# generation controls in the left column and the answer box in the right.
with gr.Blocks(title="DermalCare - Pet & Skincare Assistant") as demo:
    gr.Markdown("# DermalCare - Your AI Assistant for Pet Care and Skincare")
    
    with gr.Tabs():
        with gr.Tab("Pet Care"):
            gr.Markdown("## PetBull‑7B‑VL – Ask a Vet\nUpload a photo and/or type a question.")
            with gr.Row():
                with gr.Column():
                    # Inputs, then the trigger button, then generation controls.
                    img_in  = gr.Image(type="pil", label="Pet photo (optional)")
                    txt_in  = gr.Textbox(lines=3, placeholder="Describe the issue…")
                    ask     = gr.Button("Ask PetBull")
                    # Slider positional args are presumably (minimum, maximum,
                    # initial value) — confirm against the installed Gradio version.
                    temp    = gr.Slider(0.1, 1.5, 0.7, label="Temperature")
                    topp    = gr.Slider(0.1, 1.0, 0.95, label="Top‑p")
                    max_tok = gr.Slider(32, 512, 256, step=8, label="Max tokens")
                with gr.Column():
                    # Read-only box that receives the combined answer text.
                    answer = gr.Textbox(lines=12, label="Assistant", interactive=False)

            # Only five inputs are wired, so `budget_twd` keeps its default of
            # None and no budget filtering is applied from the UI.
            ask.click(
                pet_answer_with_recs,
                inputs=[img_in, txt_in, temp, topp, max_tok],   # (budget optional: add at the end if you want)
                outputs=answer
            )

# Start serving with request queuing enabled.
# NOTE(review): share=True asks for a public share link; this is presumably
# unnecessary when the app is hosted on Spaces — confirm.
demo.queue().launch(show_api=False, share=True)