|
|
import os |
|
|
import gradio as gr |
|
|
import torch |
|
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
|
from peft import PeftModel |
|
|
|
|
|
# Checkpoint name handed to the transformers `from_pretrained` loaders;
# overridable through the BASE_MODEL environment variable.
BASE_MODEL = os.getenv("BASE_MODEL", "Qwen/Qwen2.5-0.5B-Instruct")

# Location of the DPO LoRA adapter, checked with os.path.exists() before loading;
# overridable through the ADAPTER_RL_PATH environment variable.
ADAPTER_RL_PATH = os.getenv("ADAPTER_RL_PATH", "./adapter_dpo")
|
|
|
|
|
# System instructions placed at the top of every prompt built by build_prompt().
# The same text is sent to both the base model and the adapter-wrapped model, so
# any difference in their outputs comes from the adapter, not the instructions.
# The "Output format" section is the response skeleton both models are asked to follow.
SYSTEM_PROMPT = """You are a retail recommendation assistant.
You recommend at most 3 items that complement the user's cart and intent.
You must be:
- Relevant to the cart + intent
- Constraint-aware (budget, urgency, compatibility, brand preferences)
- Non-pushy and honest (no made-up specs or guarantees)
- Concise and structured

Output format:
Recommendations:
1) <item> — <one-line reason>
2) ...
Why these:
- ...
Compatibility / checks:
- ...
Optional next step:
- (only if helpful)
"""
|
|
|
|
|
def build_prompt(user_intent, cart, budget, urgency, brand_avoid):
    """Assemble the raw text prompt that is sent to both models.

    Args:
        user_intent: Free-text description of what the shopper wants.
        cart: Comma-separated cart items; whitespace is trimmed and empty
            entries are dropped.
        budget: Value forwarded under the ``budget_usd`` constraint key.
        urgency: Value forwarded under the ``shipping_urgency`` constraint key.
        brand_avoid: Comma-separated brands/materials to avoid; trimmed and
            empty entries dropped, like ``cart``.

    Returns:
        ``SYSTEM_PROMPT`` and the user block wrapped in ``<|system|>``,
        ``<|user|>`` and ``<|assistant|>`` tags. The string ends with the
        assistant tag so the model continues from there.
    """

    def _csv_items(raw):
        # Split on commas, trim each piece, and discard pieces that end up empty.
        return [piece.strip() for piece in raw.split(",") if piece.strip()]

    items = _csv_items(cart)
    # Rendered with the dict's repr inside the prompt; key order is preserved.
    constraint_map = {
        "budget_usd": budget,
        "shipping_urgency": urgency,
        "brand_avoid": _csv_items(brand_avoid),
    }

    user_block = "\n".join(
        [
            f"User intent: {user_intent}",
            "Cart: " + ", ".join(items),
            f"Constraints: {constraint_map}",
            "Generate recommendations following the required format.",
        ]
    )

    # NOTE(review): these role tags are hand-rolled text and probably not the
    # model's native chat template (tok.apply_chat_template). Keep them only if
    # the DPO adapter was trained on this exact format — confirm.
    return "<|system|>\n" + SYSTEM_PROMPT + "\n<|user|>\n" + user_block + "\n<|assistant|>\n"
|
|
|
|
|
@torch.inference_mode()
def generate(model, tok, prompt, max_new_tokens=220, temperature=0.7):
    """Run one generation for ``prompt`` and return only the model's continuation.

    Args:
        model: Causal LM exposing ``.device`` and ``.generate``.
        tok: Tokenizer matching ``model``; its EOS id is used as the pad id.
        prompt: Full prompt text (as produced by ``build_prompt``).
        max_new_tokens: Upper bound on generated tokens. It is coerced to ``int``
            because UI sliders may deliver it as a float.
        temperature: Sampling temperature. Values ``<= 0`` select greedy decoding.

    Returns:
        The decoded continuation, with the prompt excluded and surrounding
        whitespace stripped.
    """
    inputs = tok(prompt, return_tensors="pt").to(model.device)
    prompt_len = inputs["input_ids"].shape[-1]

    temperature = float(temperature)
    sampling = temperature > 0
    gen_kwargs = {
        "max_new_tokens": int(max_new_tokens),
        "do_sample": sampling,
        "pad_token_id": tok.eos_token_id,
    }
    # temperature only matters when sampling. Passing it under greedy decoding
    # makes transformers emit a warning and ignore it.
    if sampling:
        gen_kwargs["temperature"] = temperature

    out = model.generate(**inputs, **gen_kwargs)

    # The returned sequence starts with the prompt tokens, so decode only the tail.
    # The old approach split the full decoded text on "<|assistant|>". It returned
    # prompt text whenever user-supplied input (e.g. the intent) contained that tag.
    new_tokens = out[0][prompt_len:]
    return tok.decode(new_tokens, skip_special_tokens=True).strip()
|
|
|
|
|
def load_models():
    """Load the tokenizer, the base model and (optionally) the DPO adapter.

    Returns:
        ``(tokenizer, base_model, adapted_model)``. ``adapted_model`` is ``None``
        when nothing exists at ``ADAPTER_RL_PATH``.
    """
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=True)
    # If the tokenizer defines no pad token, reuse EOS for padding.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    base_lm = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        torch_dtype="auto",
        device_map="auto",
    )

    # NOTE(review): PeftModel.from_pretrained injects the LoRA layers into
    # base_lm's modules in place, so base_lm and the returned adapter model share
    # weights. Callers that want pure base output must disable the adapter
    # (e.g. `with adapted.disable_adapter(): ...`).
    adapted_lm = (
        PeftModel.from_pretrained(base_lm, ADAPTER_RL_PATH)
        if os.path.exists(ADAPTER_RL_PATH)
        else None
    )
    return tokenizer, base_lm, adapted_lm
|
|
|
|
|
# Load the tokenizer and models once, when the module runs, so every run_both()
# call reuses the same objects. rl_model is None when no adapter exists at
# ADAPTER_RL_PATH.
tok, base_model, rl_model = load_models()
|
|
|
|
|
def run_both(user_intent, cart, budget, urgency, brand_avoid, max_new_tokens, temperature):
    """Generate the "before" (base) and "after" (DPO adapter) answers for one request.

    Both generations use the same prompt and the same decoding settings.

    Args:
        user_intent, cart, budget, urgency, brand_avoid: Forwarded to ``build_prompt``.
        max_new_tokens, temperature: Forwarded to ``generate``.

    Returns:
        A ``(before, after)`` pair of strings. ``after`` is an explanatory message
        when no adapter was loaded.
    """
    prompt = build_prompt(user_intent, cart, budget, urgency, brand_avoid)
    gen_kwargs = {"max_new_tokens": max_new_tokens, "temperature": temperature}

    if rl_model is None:
        before = generate(base_model, tok, prompt, **gen_kwargs)
        after = "RL adapter not found. Ensure adapter_dpo/ is included in the Space repo or ADAPTER_RL_PATH is correct."
        return before, after

    # PeftModel.from_pretrained injected the LoRA layers into base_model's modules
    # in place. Calling base_model.generate() on its own would therefore still
    # apply the adapter, and the "Before" column would silently show adapted
    # output. Switch the adapter off for the baseline run; the context manager
    # re-enables it on exit.
    with rl_model.disable_adapter():
        before = generate(base_model, tok, prompt, **gen_kwargs)
    after = generate(rl_model, tok, prompt, **gen_kwargs)
    return before, after
|
|
|
|
|
with gr.Blocks() as demo:
    gr.Markdown("# Retail Recommendation Explainer — Before vs After RL (DPO)")

    # Shopper context passed to run_both() and from there into build_prompt().
    user_intent = gr.Textbox(
        label="User intent",
        value="I’m starting to run regularly and want to avoid blisters.",
    )
    cart = gr.Textbox(label="Cart items (comma-separated)", value="running shoes, socks")
    with gr.Row():
        budget = gr.Slider(10, 150, value=40, step=5, label="Budget (USD)")
        urgency = gr.Dropdown(["fast", "normal"], value="fast", label="Shipping urgency")
        brand_avoid = gr.Textbox(label="Brands/materials to avoid (comma-separated)", value="")

    # Decoding controls; the same values are used for both generations.
    with gr.Row():
        max_new_tokens = gr.Slider(80, 400, value=220, step=10, label="Max new tokens")
        temperature = gr.Slider(0.0, 1.2, value=0.7, step=0.05, label="Temperature")

    btn = gr.Button("Generate (Before vs After)")
    with gr.Row():
        out_before = gr.Textbox(label="Before (Base)", lines=18)
        out_after = gr.Textbox(label="After (RL / DPO LoRA)", lines=18)

    # The order of `inputs` must match run_both's positional parameters.
    btn.click(
        fn=run_both,
        inputs=[user_intent, cart, budget, urgency, brand_avoid, max_new_tokens, temperature],
        outputs=[out_before, out_after],
    )

# Enable the request queue explicitly (its default concurrency is 1). Both
# generations then run one request at a time on the single shared model, whatever
# the installed Gradio version's defaults are. Without this, concurrent requests
# could toggle the shared adapter on and off underneath each other.
demo.queue()
demo.launch()
|
|
|