File size: 4,243 Bytes
4e80910
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import os
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

# Model locations, overridable via environment variables (useful on HF Spaces).
BASE_MODEL = os.environ.get("BASE_MODEL", "Qwen/Qwen2.5-0.5B-Instruct")
ADAPTER_RL_PATH = os.environ.get("ADAPTER_RL_PATH", "./adapter_dpo")  # commit folder into Space repo

SYSTEM_PROMPT = """You are a retail recommendation assistant.
You recommend at most 3 items that complement the user's cart and intent.
You must be:
- Relevant to the cart + intent
- Constraint-aware (budget, urgency, compatibility, brand preferences)
- Non-pushy and honest (no made-up specs or guarantees)
- Concise and structured

Output format:
Recommendations:
1) <item> — <one-line reason>
2) ...
Why these:
- ...
Compatibility / checks:
- ...
Optional next step:
- (only if helpful)
"""

def build_prompt(user_intent, cart, budget, urgency, brand_avoid):
    """Assemble the chat-formatted prompt string fed to the model.

    `cart` and `brand_avoid` are comma-separated strings; empty fragments
    are dropped after stripping whitespace. The constraints dict is
    rendered via its repr inside the user block.
    """
    items = [piece.strip() for piece in cart.split(",") if piece.strip()]
    avoided = [piece.strip() for piece in brand_avoid.split(",") if piece.strip()]
    constraints = {
        "budget_usd": budget,
        "shipping_urgency": urgency,
        "brand_avoid": avoided,
    }
    user_block = "\n".join([
        f"User intent: {user_intent}",
        f"Cart: {', '.join(items)}",
        f"Constraints: {constraints}",
        "Generate recommendations following the required format.",
    ])
    return f"<|system|>\n{SYSTEM_PROMPT}\n<|user|>\n{user_block}\n<|assistant|>\n"

@torch.inference_mode()
def generate(model, tok, prompt, max_new_tokens=220, temperature=0.7):
    """Generate a completion for `prompt` and return the assistant portion.

    Args:
        model: a causal LM exposing `.device` and `.generate(**kwargs)`.
        tok: tokenizer used to encode the prompt and decode the output.
        max_new_tokens: generation length cap.
        temperature: >0 enables sampling at that temperature; <=0 uses
            greedy decoding (temperature is then omitted entirely, since
            transformers rejects/warns on temperature=0.0 with sampling off).
    """
    inputs = tok(prompt, return_tensors="pt").to(model.device)
    gen_kwargs = {
        "max_new_tokens": max_new_tokens,
        "pad_token_id": tok.eos_token_id,
    }
    if temperature > 0:
        gen_kwargs["do_sample"] = True
        gen_kwargs["temperature"] = temperature
    else:
        # Greedy path: do not pass temperature at all to avoid
        # "temperature must be strictly positive" errors/warnings.
        gen_kwargs["do_sample"] = False
    out = model.generate(**inputs, **gen_kwargs)
    text = tok.decode(out[0], skip_special_tokens=True)

    # Best-effort: return only the assistant completion, not the echoed prompt.
    if "<|assistant|>" in text:
        return text.split("<|assistant|>", 1)[-1].strip()
    return text.strip()

def load_models():
    """Load the tokenizer, base model, and (if present) the DPO LoRA adapter.

    Returns a (tokenizer, base_model, rl_model) triple; rl_model is None
    when the adapter directory does not exist on disk.
    """
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=True)
    if tokenizer.pad_token is None:
        # Some instruct tokenizers ship without a pad token; reuse EOS.
        tokenizer.pad_token = tokenizer.eos_token

    base = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        device_map="auto",
        torch_dtype="auto",
    )

    adapter = (
        PeftModel.from_pretrained(base, ADAPTER_RL_PATH)
        if os.path.exists(ADAPTER_RL_PATH)
        else None
    )
    return tokenizer, base, adapter

# Load once at import time so every Gradio callback reuses the same models.
tok, base_model, rl_model = load_models()

def run_both(user_intent, cart, budget, urgency, brand_avoid, max_new_tokens, temperature):
    """Produce base-model and RL-adapter completions for the same prompt.

    Returns a (before, after) pair of strings; `after` is a human-readable
    notice when the adapter was not loaded.
    """
    prompt = build_prompt(user_intent, cart, budget, urgency, brand_avoid)
    before = generate(base_model, tok, prompt, max_new_tokens=max_new_tokens, temperature=temperature)

    if rl_model is not None:
        after = generate(rl_model, tok, prompt, max_new_tokens=max_new_tokens, temperature=temperature)
    else:
        after = (
            "RL adapter not found. Ensure adapter_dpo/ is included in the Space repo "
            "or ADAPTER_RL_PATH is correct."
        )

    return before, after

# --- Gradio UI: side-by-side comparison of base model vs DPO-tuned adapter ---
with gr.Blocks() as demo:
    gr.Markdown("# Retail Recommendation Explainer — Before vs After RL (DPO)")

    # Scenario inputs that feed build_prompt().
    user_intent = gr.Textbox(
        label="User intent",
        value="I’m starting to run regularly and want to avoid blisters."
    )
    cart = gr.Textbox(label="Cart items (comma-separated)", value="running shoes, socks")
    with gr.Row():
        budget = gr.Slider(10, 150, value=40, step=5, label="Budget (USD)")
        urgency = gr.Dropdown(["fast", "normal"], value="fast", label="Shipping urgency")
    brand_avoid = gr.Textbox(label="Brands/materials to avoid (comma-separated)", value="")

    # Decoding controls passed straight through to generate().
    with gr.Row():
        max_new_tokens = gr.Slider(80, 400, value=220, step=10, label="Max new tokens")
        temperature = gr.Slider(0.0, 1.2, value=0.7, step=0.05, label="Temperature")

    btn = gr.Button("Generate (Before vs After)")
    with gr.Row():
        out_before = gr.Textbox(label="Before (Base)", lines=18)
        out_after = gr.Textbox(label="After (RL / DPO LoRA)", lines=18)

    # One click produces both completions so they can be compared directly.
    btn.click(
        fn=run_both,
        inputs=[user_intent, cart, budget, urgency, brand_avoid, max_new_tokens, temperature],
        outputs=[out_before, out_after]
    )

demo.launch()