Spaces:
Sleeping
Sleeping
File size: 7,510 Bytes
747d60a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 |
"""
DPO Recipe Generation API - HuggingFace Spaces
Generates personalized recipes using DPO-trained persona models.
"""
import os
import json
import re
import torch
import gradio as gr
from typing import Optional
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
# Configuration
BASE_MODEL_ID = "meta-llama/Llama-3.2-3B-Instruct"
# Hub token for gated model/adapter downloads; None means anonymous access.
HF_TOKEN = os.environ.get("HF_TOKEN", None)
# Available personas
# Maps persona id -> adapter repo plus the metadata consumed by
# build_prompt() ("cuisine", optional "dietary_restrictions") and by the
# UI ("name"). "flavor" is currently unused by the visible code.
PERSONAS = {
    "korean_spicy": {
        "hf_adapter": "Hunjun/korean-spicy-dpo-adapter",
        "name": "Korean Food Lover (Spicy)",
        "cuisine": "korean",
        "flavor": "spicy, umami, savory",
    },
    "mexican_vegan": {
        "hf_adapter": "Hunjun/mexican-vegan-dpo-adapter",
        "name": "Mexican Vegan",
        "cuisine": "mexican",
        "flavor": "spicy, bold, savory",
        "dietary_restrictions": "vegan",
    }
}
# Global model cache
# Populated lazily on first request and reused across Gradio callbacks.
_base_model = None           # shared base causal LM (loaded once)
_tokenizer = None            # tokenizer paired with the base model
_current_persona = None      # persona id whose adapter is currently active
_model_with_adapter = None   # PeftModel wrapping _base_model with the active adapter
def get_device() -> str:
    """Return the preferred torch device string: "cuda" when available, else "cpu"."""
    return "cuda" if torch.cuda.is_available() else "cpu"
def load_base_model():
    """Load the shared base model and tokenizer into the module-level cache.

    Idempotent: returns immediately if the base model is already loaded.
    Downloads from the Hugging Face Hub on first call (using HF_TOKEN if set).
    """
    global _base_model, _tokenizer
    if _base_model is not None:
        return
    print("Loading base model and tokenizer...")
    device = get_device()
    _tokenizer = AutoTokenizer.from_pretrained(
        BASE_MODEL_ID,
        token=HF_TOKEN
    )
    # Llama tokenizers ship without a pad token; reuse EOS for padding.
    _tokenizer.pad_token = _tokenizer.eos_token
    _base_model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL_ID,
        torch_dtype=torch.float32,  # presumably fp32 for CPU-only Spaces — confirm
        low_cpu_mem_usage=True,
        token=HF_TOKEN
    )
    # NOTE(review): `device` is only reported here — the model is never moved
    # with .to(device), so on a CUDA machine this log line is misleading and
    # inference would still run on CPU. Confirm intended deployment target.
    print(f"Base model loaded on {device}")
def load_adapter(persona_id: str):
    """Attach the DPO LoRA adapter for `persona_id` to the cached base model.

    No-op when the requested persona is already active; loads the base model
    first if it has not been loaded yet.
    """
    global _model_with_adapter, _current_persona
    if _current_persona == persona_id:
        return
    load_base_model()
    print(f"Loading adapter for {persona_id}...")
    adapter_repo = PERSONAS[persona_id]["hf_adapter"]
    # NOTE(review): every persona switch wraps _base_model in a fresh
    # PeftModel; PEFT injects adapter modules into the base model in place,
    # so repeated switching may accumulate state — consider a single
    # PeftModel with load_adapter()/set_adapter() instead. TODO confirm.
    _model_with_adapter = PeftModel.from_pretrained(
        _base_model,
        adapter_repo,
        token=HF_TOKEN
    )
    # Inference mode (disables dropout, etc.).
    _model_with_adapter.eval()
    _current_persona = persona_id
    print(f"Adapter loaded: {persona_id}")
def build_prompt(persona_id: str, ingredients: str, user_request: str = "") -> str:
    """Assemble a ChatML-formatted prompt for the given persona.

    The user turn lists the inventory, the optional free-text request, and a
    cuisine (plus dietary restriction, when the persona declares one).
    """
    cfg = PERSONAS[persona_id]
    system_msg = (
        "You are a recipe generation AI that creates recipes based on "
        "user inventory and preferences."
    )
    # Build the user turn as space-joined sentences.
    sentences = [f"I have {ingredients}."]
    if user_request:
        sentences.append(user_request)
    diet = cfg.get("dietary_restrictions", "")
    style = f"{diet} {cfg['cuisine']}" if diet else cfg["cuisine"]
    sentences.append(f"I want a {style} recipe.")
    user_msg = " ".join(sentences)
    return (
        f"<|im_start|>system\n{system_msg}<|im_end|>\n"
        f"<|im_start|>user\n{user_msg}<|im_end|>\n"
        "<|im_start|>assistant\n"
    )
def parse_recipe_json(output: str) -> dict:
    """Extract a recipe dict from raw model text.

    Attempts a direct JSON parse first, then the widest ``{...}`` span found
    in the text. If both fail, returns an error payload carrying the first
    500 characters of the raw output.
    """
    candidates = [output]
    span = re.search(r'\{[\s\S]*\}', output)
    if span:
        candidates.append(span.group())
    for candidate in candidates:
        try:
            return json.loads(candidate)
        except json.JSONDecodeError:
            continue
    return {
        "status": "error",
        "error": "Failed to parse recipe",
        "raw_output": output[:500],
    }
def generate_recipe(
    persona: str,
    ingredients: str,
    user_request: str = "",
    max_tokens: int = 512,
    temperature: float = 0.7
) -> dict:
    """Generate a recipe using the specified persona.

    Args:
        persona: Key into PERSONAS selecting which DPO adapter to use.
        ingredients: Comma-separated ingredient inventory from the user.
        user_request: Optional free-text preference (e.g. "make it quick").
        max_tokens: Cap on newly generated tokens.
        temperature: Sampling temperature for generation.

    Returns:
        The parsed recipe dict with "persona"/"persona_name" added, or an
        error payload of the form {"status": "error", "error": ...}.
    """
    if persona not in PERSONAS:
        return {"status": "error", "error": f"Unknown persona: {persona}"}
    if not ingredients.strip():
        return {"status": "error", "error": "Please provide at least one ingredient"}
    try:
        # Load adapter (lazy; no-op when the persona is already active)
        load_adapter(persona)
        # Build prompt
        prompt = build_prompt(persona, ingredients, user_request)
        # Tokenize
        inputs = _tokenizer(
            prompt,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=2048
        )
        # Fix: keep inputs on the same device as the model so generation
        # works whether the model lives on CPU or GPU (no-op on CPU-only
        # Spaces; previously inputs were implicitly assumed to be on CPU).
        model_device = next(_model_with_adapter.parameters()).device
        inputs = {k: v.to(model_device) for k, v in inputs.items()}
        # Generate
        with torch.no_grad():
            outputs = _model_with_adapter.generate(
                **inputs,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_p=0.9,
                do_sample=True,
                pad_token_id=_tokenizer.eos_token_id
            )
        # Decode only the newly generated continuation, not the prompt.
        generated_text = _tokenizer.decode(
            outputs[0][inputs["input_ids"].shape[1]:],
            skip_special_tokens=True
        )
        # Parse and tag with persona metadata for the UI.
        result = parse_recipe_json(generated_text)
        result["persona"] = persona
        result["persona_name"] = PERSONAS[persona]["name"]
        return result
    except Exception as e:
        # Boundary handler: surface model/Hub failures as a JSON error
        # payload instead of crashing the Gradio callback.
        return {
            "status": "error",
            "error": str(e),
            "persona": persona
        }
# Gradio Interface
# Two-column layout: inputs/controls on the left, raw JSON output on the right.
with gr.Blocks(title="DPO Recipe Generator") as demo:
    gr.Markdown("""
    # DPO Recipe Generator
    Generate personalized recipes using DPO-trained persona models.
    **Available Personas:**
    - **Korean Spicy**: Korean cuisine with emphasis on spicy flavors
    - **Mexican Vegan**: Mexican cuisine, plant-based recipes
    """)
    with gr.Row():
        with gr.Column():
            # Left column: persona selection, ingredients, free-text request.
            persona_input = gr.Dropdown(
                choices=list(PERSONAS.keys()),
                value="korean_spicy",
                label="Persona"
            )
            ingredients_input = gr.Textbox(
                label="Ingredients",
                placeholder="e.g., tofu, rice, gochujang, sesame oil",
                lines=2
            )
            request_input = gr.Textbox(
                label="Additional Request (optional)",
                placeholder="e.g., Make something quick and spicy",
                lines=2
            )
            with gr.Row():
                # Sampling controls mapped to generate_recipe's parameters.
                max_tokens = gr.Slider(
                    minimum=128,
                    maximum=1024,
                    value=512,
                    step=64,
                    label="Max Tokens"
                )
                temperature = gr.Slider(
                    minimum=0.1,
                    maximum=1.5,
                    value=0.7,
                    step=0.1,
                    label="Temperature"
                )
            generate_btn = gr.Button("Generate Recipe", variant="primary")
        with gr.Column():
            # Right column: parsed recipe (or error payload) as JSON.
            output = gr.JSON(label="Generated Recipe")
    # Wire the button to the generation function; inputs are passed
    # positionally in generate_recipe's parameter order.
    generate_btn.click(
        fn=generate_recipe,
        inputs=[persona_input, ingredients_input, request_input, max_tokens, temperature],
        outputs=output
    )
    # Clickable example rows (persona, ingredients, request) that prefill the inputs.
    gr.Examples(
        examples=[
            ["korean_spicy", "tofu, rice, gochujang, sesame oil, green onion", "Make something quick and spicy"],
            ["mexican_vegan", "black beans, avocado, lime, cilantro, tortillas", "Make fresh tacos"],
            ["korean_spicy", "chicken, kimchi, cheese, rice", "Make a fusion dish"],
            ["mexican_vegan", "quinoa, bell peppers, corn, black beans", "Make a healthy bowl"],
        ],
        inputs=[persona_input, ingredients_input, request_input]
    )
if __name__ == "__main__":
    demo.launch()
|