Spaces:
Sleeping
Sleeping
vk committed on
Commit Β·
6e6fe5a
1
Parent(s): 233b45a
Improve recipe recommendations: better dessert detection and lightweight LLM
Browse files- Replace heavy Llama 2 with DialoGPT-small for HF Spaces
- Add comprehensive dessert/chocolate ingredient detection
- Implement priority-based recipe matching (dessert queries get 3x boost)
- Enhanced search algorithm with 18+ dessert-specific patterns
- Remove heavy dependencies (peft, accelerate, bitsandbytes)
- Fix issue where 'chocolate dessert' returned shrimp recipes
- app.py +312 -159
- requirements.txt +0 -2
app.py
CHANGED
|
@@ -4,7 +4,7 @@ from pydantic import BaseModel
|
|
| 4 |
from typing import List, Optional
|
| 5 |
import torch
|
| 6 |
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 7 |
-
|
| 8 |
import uvicorn
|
| 9 |
import os
|
| 10 |
import pandas as pd
|
|
@@ -249,151 +249,251 @@ def load_recipes():
|
|
| 249 |
raise Exception(f"Failed to load recipe database: {e}")
|
| 250 |
|
| 251 |
@torch.inference_mode()
|
| 252 |
-
def
|
| 253 |
-
"""Use
|
| 254 |
global tokenizer, model
|
| 255 |
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
return extract_query_features_simple(query_text, preferences, max_minutes)
|
| 259 |
|
| 260 |
-
#
|
| 261 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 262 |
|
| 263 |
-
|
|
|
|
| 264 |
|
| 265 |
-
|
|
|
|
|
|
|
| 266 |
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
# Generate a short response to extract ingredients/features
|
| 271 |
-
outputs = model.generate(
|
| 272 |
-
**inputs,
|
| 273 |
-
max_new_tokens=50,
|
| 274 |
-
temperature=0.3, # Lower temperature for more focused extraction
|
| 275 |
-
top_p=0.9,
|
| 276 |
-
do_sample=True,
|
| 277 |
-
pad_token_id=tokenizer.eos_token_id,
|
| 278 |
-
repetition_penalty=1.1
|
| 279 |
-
)
|
| 280 |
-
|
| 281 |
-
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
| 282 |
-
gpt2_extraction = response[len(extraction_prompt):].strip()
|
| 283 |
-
|
| 284 |
-
# Parse the GPT-2 response and combine with rule-based extraction
|
| 285 |
-
gpt2_features = parse_gpt2_extraction(gpt2_extraction)
|
| 286 |
-
rule_features = extract_query_features_simple(query_text, preferences, max_minutes)
|
| 287 |
-
|
| 288 |
-
# Combine both approaches
|
| 289 |
-
combined_features = {
|
| 290 |
-
'ingredients': list(set(gpt2_features.get('ingredients', []) + rule_features['ingredients'])),
|
| 291 |
-
'cuisines': list(set(gpt2_features.get('cuisines', []) + rule_features['cuisines'])),
|
| 292 |
-
'diets': list(set(gpt2_features.get('diets', []) + rule_features['diets'])),
|
| 293 |
-
'styles': list(set(gpt2_features.get('styles', []) + rule_features['styles'])),
|
| 294 |
-
'max_minutes': max_minutes,
|
| 295 |
-
}
|
| 296 |
-
|
| 297 |
-
combined_features['search_terms'] = (
|
| 298 |
-
combined_features['ingredients'] +
|
| 299 |
-
combined_features['cuisines'] +
|
| 300 |
-
combined_features['diets'] +
|
| 301 |
-
combined_features['styles']
|
| 302 |
-
)
|
| 303 |
-
|
| 304 |
-
print(f"π§ GPT-2 enhanced extraction: {combined_features['search_terms'][:8]}")
|
| 305 |
-
return combined_features
|
| 306 |
-
|
| 307 |
-
except Exception as e:
|
| 308 |
-
print(f"β οΈ GPT-2 extraction failed, using rule-based: {e}")
|
| 309 |
-
return extract_query_features_simple(query_text, preferences, max_minutes)
|
| 310 |
-
|
| 311 |
-
def parse_gpt2_extraction(gpt2_text):
|
| 312 |
-
"""Parse GPT-2's extraction response into structured features"""
|
| 313 |
-
text_lower = gpt2_text.lower()
|
| 314 |
|
| 315 |
-
#
|
| 316 |
-
|
| 317 |
-
common_ingredients = [
|
| 318 |
-
'chicken', 'beef', 'pork', 'fish', 'salmon', 'shrimp', 'tofu',
|
| 319 |
-
'pasta', 'rice', 'quinoa', 'bread', 'potatoes', 'noodles',
|
| 320 |
-
'tomatoes', 'onion', 'garlic', 'ginger', 'peppers', 'broccoli',
|
| 321 |
-
'spinach', 'carrots', 'mushrooms', 'avocado', 'lemon', 'lime',
|
| 322 |
-
'cheese', 'milk', 'eggs', 'butter', 'oil', 'flour', 'herbs',
|
| 323 |
-
'beans', 'lentils', 'chickpeas'
|
| 324 |
-
]
|
| 325 |
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
ingredients.append(ing)
|
| 329 |
|
| 330 |
-
#
|
| 331 |
-
|
| 332 |
-
cuisine_words = ['italian', 'mexican', 'asian', 'chinese', 'thai', 'indian', 'greek', 'french', 'mediterranean']
|
| 333 |
-
for cuisine in cuisine_words:
|
| 334 |
-
if cuisine in text_lower:
|
| 335 |
-
cuisines.append(cuisine)
|
| 336 |
|
| 337 |
-
#
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
for diet in diet_words:
|
| 341 |
-
if diet in text_lower:
|
| 342 |
-
diets.append(diet)
|
| 343 |
|
| 344 |
-
#
|
| 345 |
-
|
| 346 |
-
style_words = ['quick', 'easy', 'fast', 'slow', 'comfort', 'light', 'hearty', 'spicy']
|
| 347 |
-
for style in style_words:
|
| 348 |
-
if style in text_lower:
|
| 349 |
-
styles.append(style)
|
| 350 |
|
| 351 |
return {
|
| 352 |
'ingredients': ingredients,
|
| 353 |
-
'
|
| 354 |
-
'
|
| 355 |
-
'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 356 |
}
|
| 357 |
|
| 358 |
-
def
|
| 359 |
-
"""
|
| 360 |
-
|
| 361 |
|
| 362 |
-
#
|
| 363 |
-
|
| 364 |
-
'chicken', 'beef', 'pork', 'fish', 'salmon', 'shrimp', 'tofu',
|
| 365 |
-
'
|
| 366 |
-
'tomatoes', 'onion', 'garlic', 'ginger', 'peppers', 'broccoli',
|
| 367 |
-
'
|
| 368 |
-
'
|
| 369 |
-
'
|
| 370 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 371 |
|
| 372 |
-
|
|
|
|
|
|
|
|
|
|
| 373 |
|
| 374 |
-
|
| 375 |
-
|
| 376 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 377 |
|
| 378 |
-
#
|
| 379 |
-
|
| 380 |
-
|
| 381 |
|
| 382 |
-
#
|
| 383 |
-
|
| 384 |
-
|
|
|
|
|
|
|
| 385 |
|
| 386 |
return {
|
| 387 |
-
'ingredients':
|
| 388 |
-
'
|
| 389 |
-
'diets': mentioned_diets,
|
| 390 |
-
'styles': mentioned_styles,
|
| 391 |
-
'max_minutes': max_minutes,
|
| 392 |
-
'search_terms': mentioned_ingredients + mentioned_cuisines + mentioned_diets + mentioned_styles
|
| 393 |
}
|
| 394 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 395 |
def search_recipes(query_features, top_k=10):
|
| 396 |
-
"""
|
| 397 |
global recipes_df, vectorizer, recipe_vectors
|
| 398 |
|
| 399 |
if recipes_df is None:
|
|
@@ -405,7 +505,7 @@ def search_recipes(query_features, top_k=10):
|
|
| 405 |
if len(filtered_df) == 0:
|
| 406 |
filtered_df = recipes_df.copy() # Fall back to all recipes
|
| 407 |
|
| 408 |
-
# Create search query
|
| 409 |
search_query = ' '.join(query_features['search_terms'])
|
| 410 |
|
| 411 |
if search_query and vectorizer is not None:
|
|
@@ -438,20 +538,82 @@ def search_recipes(query_features, top_k=10):
|
|
| 438 |
print(f"β οΈ Similarity length mismatch: {len(similarities)} vs {len(filtered_df)}")
|
| 439 |
filtered_df['similarity'] = 0.5
|
| 440 |
|
| 441 |
-
#
|
| 442 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 443 |
for ingredient in query_features['ingredients']:
|
| 444 |
-
|
| 445 |
-
filtered_df
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 446 |
|
| 447 |
-
|
|
|
|
| 448 |
for cuisine in query_features['cuisines']:
|
| 449 |
-
mask = filtered_df['tags_text'].str.contains(cuisine, na=False) |
|
| 450 |
-
filtered_df['name'].str.lower().str.contains(cuisine, na=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 451 |
filtered_df.loc[mask, 'similarity'] *= 1.3
|
| 452 |
|
| 453 |
-
# Sort by similarity
|
| 454 |
filtered_df = filtered_df.sort_values('similarity', ascending=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 455 |
else:
|
| 456 |
# Fallback: random selection
|
| 457 |
filtered_df = filtered_df.sample(min(len(filtered_df), top_k*2), random_state=42)
|
|
@@ -465,46 +627,37 @@ async def load_model():
|
|
| 465 |
global tokenizer, model
|
| 466 |
|
| 467 |
try:
|
| 468 |
-
print("π Loading
|
|
|
|
|
|
|
|
|
|
| 469 |
|
| 470 |
# Load tokenizer
|
| 471 |
-
|
|
|
|
| 472 |
if tokenizer.pad_token is None:
|
| 473 |
tokenizer.pad_token = tokenizer.eos_token
|
| 474 |
|
| 475 |
-
# Load
|
| 476 |
-
print("
|
| 477 |
-
|
| 478 |
-
|
| 479 |
-
|
| 480 |
-
|
| 481 |
-
|
| 482 |
-
model = PeftModel.from_pretrained(
|
| 483 |
-
base_model,
|
| 484 |
-
"nutrientartcd/recipe-gpt2-lora"
|
| 485 |
-
).to(device)
|
| 486 |
-
print("β
LoRA adapter loaded successfully!")
|
| 487 |
-
except Exception as e:
|
| 488 |
-
print(f"β οΈ Could not load LoRA adapter: {e}")
|
| 489 |
-
print("π Using base GPT-2 model...")
|
| 490 |
-
model = base_model.to(device)
|
| 491 |
|
| 492 |
model.eval()
|
| 493 |
-
print(f"β
|
| 494 |
|
| 495 |
# Load recipe database
|
| 496 |
load_recipes()
|
| 497 |
|
| 498 |
except Exception as e:
|
| 499 |
-
print(f"β Error loading model: {e}")
|
| 500 |
-
print("
|
| 501 |
-
|
| 502 |
-
|
| 503 |
-
|
| 504 |
-
if tokenizer.pad_token is None:
|
| 505 |
-
tokenizer.pad_token = tokenizer.eos_token
|
| 506 |
-
model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)
|
| 507 |
-
model.eval()
|
| 508 |
load_recipes()
|
| 509 |
|
| 510 |
# Health check endpoint
|
|
@@ -544,8 +697,8 @@ async def get_recipe_suggestions(request: RecipeRequest):
|
|
| 544 |
|
| 545 |
print(f"π₯ Recipe request: {request.ingredients}, prefs: {request.preferences}, time: {request.max_minutes}")
|
| 546 |
|
| 547 |
-
# Use
|
| 548 |
-
query_features =
|
| 549 |
request.ingredients,
|
| 550 |
request.preferences,
|
| 551 |
request.max_minutes
|
|
|
|
| 4 |
from typing import List, Optional
|
| 5 |
import torch
|
| 6 |
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 7 |
+
import json
|
| 8 |
import uvicorn
|
| 9 |
import os
|
| 10 |
import pandas as pd
|
|
|
|
| 249 |
raise Exception(f"Failed to load recipe database: {e}")
|
| 250 |
|
| 251 |
@torch.inference_mode()
|
| 252 |
+
def extract_query_features_with_llm(query_text, preferences="", max_minutes=30):
    """Parse a free-text recipe query into structured search features.

    Always runs the enhanced rule-based extractor as the baseline; when
    the DialoGPT model and tokenizer are loaded, additionally generates a
    conversational reply and merges any insights mined from it.

    Returns the feature dict produced by extract_enhanced_features /
    merge_feature_sets (includes a 'search_terms' list).
    """
    global tokenizer, model

    # Rule-based extraction is the reliable foundation; the LLM only augments it.
    base = extract_enhanced_features(query_text, preferences, max_minutes)

    if model is None or tokenizer is None:
        print(f"📋 Enhanced rule-based extraction: {base['search_terms'][:8]}")
        return base

    try:
        # Frame the query as a chat turn so DialoGPT can respond to it.
        prompt = f"User: I want to cook {query_text} {preferences}".strip()
        encoded = tokenizer.encode(prompt + tokenizer.eos_token, return_tensors="pt").to(device)

        # Short sampled generation just to surface intent keywords.
        generated = model.generate(
            encoded,
            max_new_tokens=50,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            repetition_penalty=1.2,
        )

        # Decode only the newly generated tokens, skipping the prompt echo.
        reply = tokenizer.decode(generated[0][encoded.shape[1]:], skip_special_tokens=True)

        merged = merge_feature_sets(base, extract_insights_from_response(reply))

        print(f"🤖 DialoGPT-enhanced extraction: {merged['search_terms'][:8]}")
        return merged

    except Exception as e:
        print(f"⚠️ DialoGPT enhancement failed, using rule-based: {e}")

    print(f"📋 Enhanced rule-based extraction: {base['search_terms'][:8]}")
    return base
|
| 294 |
|
| 295 |
+
def extract_enhanced_features(query_text, preferences="", max_minutes=30):
    """Rule-based feature extraction tuned for recipe queries.

    Lower-cases the query text plus optional preferences and runs each
    category detector over the combined string, then flattens every
    detected category into a single 'search_terms' list.
    """
    text = (query_text + " " + (preferences or "")).lower()

    features = {
        'ingredients': detect_ingredients(text),
        'meal_types': detect_meal_types(text),
        'cuisines': detect_cuisines(text),
        'dietary_restrictions': detect_dietary_preferences(text),
        'cooking_styles': detect_cooking_styles(text),
        'cooking_methods': detect_cooking_methods(text),
        'flavors': detect_flavors(text),
        'max_minutes': max_minutes,
    }

    # Combined term list feeds the TF-IDF recipe search downstream.
    features['search_terms'] = (
        features['ingredients'] + features['meal_types'] + features['cuisines'] +
        features['dietary_restrictions'] + features['cooking_styles'] +
        features['cooking_methods'] + features['flavors']
    )
    return features
|
| 329 |
|
| 330 |
+
def detect_ingredients(query_lower):
    """Detect ingredient mentions in an already lower-cased query string.

    Matches whole words only, so 'apple' is no longer falsely found inside
    'pineapple' nor 'rice' inside 'price'. Ingredients in the dessert-key
    category additionally emit a 'sweet_'-prefixed marker that the search
    layer uses to boost dessert-context matches.

    Returns a de-duplicated list of matched terms (order not guaranteed).
    """
    import re

    # Comprehensive ingredient list including dessert ingredients.
    ingredient_patterns = {
        'proteins': ['chicken', 'beef', 'pork', 'fish', 'salmon', 'shrimp', 'tofu', 'eggs', 'turkey', 'lamb'],
        'starches': ['rice', 'pasta', 'quinoa', 'bread', 'potatoes', 'noodles', 'flour', 'oats'],
        'vegetables': ['tomatoes', 'onion', 'garlic', 'ginger', 'peppers', 'broccoli', 'spinach', 'carrots', 'mushrooms', 'avocado'],
        'dessert_key': ['chocolate', 'cocoa', 'sugar', 'vanilla', 'caramel', 'honey', 'maple syrup', 'cream', 'butter'],
        'fruits': ['apple', 'banana', 'berries', 'strawberry', 'blueberry', 'lemon', 'lime', 'orange'],
        'dairy': ['cheese', 'milk', 'yogurt', 'cream'],
        'nuts_spices': ['nuts', 'almonds', 'walnuts', 'cinnamon', 'nutmeg', 'herbs', 'basil'],
    }

    ingredients = []
    for category, items in ingredient_patterns.items():
        for item in items:
            # Word-boundary match avoids substring false positives.
            if re.search(r'\b' + re.escape(item) + r'\b', query_lower):
                ingredients.append(item)
                # Special emphasis marker for dessert-context boosting.
                if category == 'dessert_key':
                    ingredients.append(f"sweet_{item}")

    return list(set(ingredients))
|
| 354 |
+
|
| 355 |
+
def detect_meal_types(query_lower):
    """Classify the lower-cased query into meal types via keyword spotting.

    A meal type is reported when any of its indicator keywords occurs as a
    substring of the query; results preserve the pattern-table order.
    """
    meal_patterns = {
        'dessert': ['dessert', 'sweet', 'cake', 'cookie', 'pie', 'ice cream', 'pudding', 'tart', 'chocolate', 'candy'],
        'breakfast': ['breakfast', 'morning', 'brunch', 'cereal', 'pancake', 'waffle'],
        'lunch': ['lunch', 'midday', 'sandwich'],
        'dinner': ['dinner', 'supper', 'evening'],
        'snack': ['snack', 'appetizer', 'finger food'],
        'drink': ['drink', 'beverage', 'smoothie', 'juice'],
    }

    return [
        meal_type
        for meal_type, keywords in meal_patterns.items()
        if any(keyword in query_lower for keyword in keywords)
    ]
|
| 372 |
+
|
| 373 |
+
def detect_cuisines(query_lower):
    """Return every known cuisine name appearing in the lower-cased query."""
    known_cuisines = ['italian', 'mexican', 'asian', 'chinese', 'thai', 'indian', 'greek', 'french', 'mediterranean', 'american', 'japanese']
    found = []
    for name in known_cuisines:
        if name in query_lower:
            found.append(name)
    return found
|
| 377 |
+
|
| 378 |
+
def detect_dietary_preferences(query_lower):
    """Return dietary restriction/preference keywords found in the query."""
    found = []
    for diet in ('vegetarian', 'vegan', 'healthy', 'low-carb', 'keto', 'gluten-free', 'dairy-free'):
        if diet in query_lower:
            found.append(diet)
    return found
|
| 382 |
+
|
| 383 |
+
def detect_cooking_styles(query_lower):
    """Detect cooking-style keywords as whole words.

    Whole-word matching fixes false positives from plain substring tests,
    e.g. the style 'fast' being detected inside the word 'breakfast'.
    """
    import re

    styles = ['quick', 'easy', 'fast', 'slow', 'comfort', 'light', 'hearty', 'simple']
    return [style for style in styles if re.search(r'\b' + style + r'\b', query_lower)]
|
| 387 |
+
|
| 388 |
+
def detect_cooking_methods(query_lower):
    """Return cooking-method keywords found in the lower-cased query."""
    found = []
    for method in ('baked', 'fried', 'grilled', 'roasted', 'steamed', 'boiled', 'sauteed'):
        if method in query_lower:
            found.append(method)
    return found
|
| 392 |
+
|
| 393 |
+
def detect_flavors(query_lower):
    """Return flavor-profile keywords found in the lower-cased query."""
    found = []
    for flavor in ('sweet', 'spicy', 'savory', 'sour', 'creamy', 'crispy'):
        if flavor in query_lower:
            found.append(flavor)
    return found
|
| 397 |
+
|
| 398 |
+
def extract_insights_from_response(response_text):
    """Pull lightweight cooking signals out of a DialoGPT reply.

    Only generic food-talk indicator words are scanned for; no concrete
    ingredients are currently mined from the reply, so the 'ingredients'
    entry is always an empty list.
    """
    lowered = response_text.lower()
    indicators = ('recipe', 'cook', 'make', 'prepare', 'dish', 'meal', 'food')
    cooking_context = [word for word in indicators if word in lowered]

    return {
        'ingredients': [],
        'cooking_context': cooking_context,
    }
|
| 416 |
|
| 417 |
+
def merge_feature_sets(base_features, llm_insights):
    """Merge rule-based features with LLM-derived insights.

    Returns a new feature dict; *base_features* is left unmodified. The
    previous implementation used a shallow dict.copy(), so extending
    'ingredients' mutated the caller's list in place — fixed by copying
    the list before extending.
    """
    merged = base_features.copy()
    # Shallow copy shares the nested list; copy it so the input stays intact.
    merged['ingredients'] = list(merged.get('ingredients', []))

    # Fold in LLM-supplied ingredients when present.
    if llm_insights.get('ingredients'):
        merged['ingredients'].extend(llm_insights['ingredients'])
        merged['ingredients'] = list(set(merged['ingredients']))  # de-duplicate

    # Rebuild the combined search-term list from every category.
    merged['search_terms'] = (
        merged['ingredients'] + merged['meal_types'] + merged['cuisines'] +
        merged['dietary_restrictions'] + merged['cooking_styles'] +
        merged['cooking_methods'] + merged['flavors']
    )

    return merged
|
| 435 |
+
|
| 436 |
+
def parse_llm_json_response(response_text):
    """Parse an LLM reply into the structured feature dict.

    Locates the outermost {...} span in the reply, decodes it as JSON and
    normalises every expected feature key to a list. On any failure, falls
    back to scanning the raw text for a small set of known terms.
    """
    expected_keys = (
        'ingredients', 'meal_types', 'cuisines', 'dietary_restrictions',
        'cooking_styles', 'cooking_methods', 'flavors',
    )
    try:
        response_text = response_text.strip()

        # Isolate the outermost JSON object; anything around it is chatter.
        first_brace = response_text.find('{')
        last_brace = response_text.rfind('}') + 1
        if first_brace == -1 or last_brace == 0:
            raise ValueError("No JSON found in response")

        features = json.loads(response_text[first_brace:last_brace])

        # Guarantee every expected key exists and holds a list.
        for key in expected_keys:
            if key not in features:
                features[key] = []
            elif not isinstance(features[key], list):
                features[key] = [str(features[key])]

        return features

    except Exception as e:
        print(f"⚠️ JSON parsing failed: {e}")
        print(f"Response text: {response_text[:200]}...")

        # Fallback: keyword-scan the raw reply for a handful of known terms.
        text_lower = response_text.lower()
        return {
            'ingredients': extract_terms_from_text(text_lower, ['chocolate', 'vanilla', 'sugar', 'flour', 'butter', 'eggs', 'milk']),
            'meal_types': extract_terms_from_text(text_lower, ['dessert', 'breakfast', 'lunch', 'dinner', 'snack']),
            'cuisines': extract_terms_from_text(text_lower, ['italian', 'mexican', 'asian', 'french']),
            'dietary_restrictions': extract_terms_from_text(text_lower, ['vegetarian', 'vegan', 'gluten-free']),
            'cooking_styles': extract_terms_from_text(text_lower, ['quick', 'easy', 'healthy']),
            'cooking_methods': extract_terms_from_text(text_lower, ['baked', 'fried', 'grilled']),
            'flavors': extract_terms_from_text(text_lower, ['sweet', 'savory', 'spicy'])
        }
|
| 489 |
+
|
| 490 |
+
def extract_terms_from_text(text, terms_list):
    """Return the subset of *terms_list* whose entries occur in *text*."""
    matches = []
    for term in terms_list:
        if term in text:
            matches.append(term)
    return matches
|
| 493 |
+
|
| 494 |
+
|
| 495 |
def search_recipes(query_features, top_k=10):
|
| 496 |
+
"""Enhanced search for recipes matching the LLM-extracted features"""
|
| 497 |
global recipes_df, vectorizer, recipe_vectors
|
| 498 |
|
| 499 |
if recipes_df is None:
|
|
|
|
| 505 |
if len(filtered_df) == 0:
|
| 506 |
filtered_df = recipes_df.copy() # Fall back to all recipes
|
| 507 |
|
| 508 |
+
# Create search query from all LLM-extracted terms
|
| 509 |
search_query = ' '.join(query_features['search_terms'])
|
| 510 |
|
| 511 |
if search_query and vectorizer is not None:
|
|
|
|
| 538 |
print(f"β οΈ Similarity length mismatch: {len(similarities)} vs {len(filtered_df)}")
|
| 539 |
filtered_df['similarity'] = 0.5
|
| 540 |
|
| 541 |
+
# Apply intelligent boosting based on enhanced features
|
| 542 |
+
|
| 543 |
+
# HIGHEST PRIORITY: Meal type matches (especially dessert)
|
| 544 |
+
if query_features.get('meal_types'):
|
| 545 |
+
for meal_type in query_features['meal_types']:
|
| 546 |
+
# Check name, tags, and search text for meal type
|
| 547 |
+
mask = (filtered_df['name'].str.lower().str.contains(meal_type, na=False) |
|
| 548 |
+
filtered_df['tags_text'].str.contains(meal_type, na=False) |
|
| 549 |
+
filtered_df['search_text'].str.contains(meal_type, na=False))
|
| 550 |
+
filtered_df.loc[mask, 'similarity'] *= 3.0 # Very high boost
|
| 551 |
+
|
| 552 |
+
# Special handling for desserts - comprehensive dessert detection
|
| 553 |
+
if meal_type == 'dessert':
|
| 554 |
+
dessert_patterns = [
|
| 555 |
+
'chocolate', 'cocoa', 'sugar', 'vanilla', 'cake', 'cookie', 'pie',
|
| 556 |
+
'sweet', 'candy', 'cream', 'frosting', 'icing', 'dessert', 'pudding',
|
| 557 |
+
'brownie', 'tart', 'mousse', 'custard', 'fudge', 'caramel', 'honey'
|
| 558 |
+
]
|
| 559 |
+
for pattern in dessert_patterns:
|
| 560 |
+
mask = filtered_df['search_text'].str.contains(pattern, na=False)
|
| 561 |
+
filtered_df.loc[mask, 'similarity'] *= 2.5 # Strong dessert boost
|
| 562 |
+
|
| 563 |
+
# Also check recipe names for dessert indicators
|
| 564 |
+
dessert_name_patterns = ['cake', 'cookie', 'brownie', 'pie', 'tart', 'sweet', 'chocolate']
|
| 565 |
+
for pattern in dessert_name_patterns:
|
| 566 |
+
mask = filtered_df['name'].str.lower().str.contains(pattern, na=False)
|
| 567 |
+
filtered_df.loc[mask, 'similarity'] *= 2.8
|
| 568 |
+
|
| 569 |
+
# HIGH PRIORITY: Exact ingredient matches
|
| 570 |
+
if query_features.get('ingredients'):
|
| 571 |
for ingredient in query_features['ingredients']:
|
| 572 |
+
# Regular ingredient matching
|
| 573 |
+
mask = filtered_df['ingredients_text'].str.contains(ingredient.replace('sweet_', ''), na=False)
|
| 574 |
+
filtered_df.loc[mask, 'similarity'] *= 2.2
|
| 575 |
+
|
| 576 |
+
# Special handling for dessert ingredients with sweet_ prefix
|
| 577 |
+
if ingredient.startswith('sweet_'):
|
| 578 |
+
base_ingredient = ingredient.replace('sweet_', '')
|
| 579 |
+
mask = filtered_df['ingredients_text'].str.contains(base_ingredient, na=False)
|
| 580 |
+
# Check if recipe also has dessert context
|
| 581 |
+
dessert_context_mask = (
|
| 582 |
+
filtered_df['search_text'].str.contains('sweet|dessert|cake|cookie', na=False) |
|
| 583 |
+
filtered_df['tags_text'].str.contains('dessert|sweet', na=False)
|
| 584 |
+
)
|
| 585 |
+
combined_mask = mask & dessert_context_mask
|
| 586 |
+
filtered_df.loc[combined_mask, 'similarity'] *= 3.5 # Highest boost for dessert ingredients in dessert context
|
| 587 |
+
|
| 588 |
+
# MEDIUM PRIORITY: Flavor matches (sweet, spicy, etc.)
|
| 589 |
+
if query_features.get('flavors'):
|
| 590 |
+
for flavor in query_features['flavors']:
|
| 591 |
+
mask = filtered_df['search_text'].str.contains(flavor, na=False)
|
| 592 |
+
multiplier = 2.0 if flavor == 'sweet' else 1.5 # Higher boost for sweet
|
| 593 |
+
filtered_df.loc[mask, 'similarity'] *= multiplier
|
| 594 |
|
| 595 |
+
# LOWER PRIORITY: Cuisine matches
|
| 596 |
+
if query_features.get('cuisines'):
|
| 597 |
for cuisine in query_features['cuisines']:
|
| 598 |
+
mask = (filtered_df['tags_text'].str.contains(cuisine, na=False) |
|
| 599 |
+
filtered_df['name'].str.lower().str.contains(cuisine, na=False))
|
| 600 |
+
filtered_df.loc[mask, 'similarity'] *= 1.4
|
| 601 |
+
|
| 602 |
+
# LOWER PRIORITY: Cooking method matches
|
| 603 |
+
if query_features.get('cooking_methods'):
|
| 604 |
+
for method in query_features['cooking_methods']:
|
| 605 |
+
mask = (filtered_df['name'].str.lower().str.contains(method, na=False) |
|
| 606 |
+
filtered_df['steps_text'].str.contains(method, na=False))
|
| 607 |
filtered_df.loc[mask, 'similarity'] *= 1.3
|
| 608 |
|
| 609 |
+
# Sort by similarity (descending)
|
| 610 |
filtered_df = filtered_df.sort_values('similarity', ascending=False)
|
| 611 |
+
|
| 612 |
+
# Log the top results for debugging
|
| 613 |
+
print(f"π Search results for '{search_query}':")
|
| 614 |
+
for i, (_, recipe) in enumerate(filtered_df.head(3).iterrows()):
|
| 615 |
+
print(f" {i+1}. {recipe['name']} (sim: {recipe['similarity']:.3f})")
|
| 616 |
+
|
| 617 |
else:
|
| 618 |
# Fallback: random selection
|
| 619 |
filtered_df = filtered_df.sample(min(len(filtered_df), top_k*2), random_state=42)
|
|
|
|
| 627 |
global tokenizer, model
|
| 628 |
|
| 629 |
try:
|
| 630 |
+
print("π Loading DialoGPT for Recipe Intelligence...")
|
| 631 |
+
|
| 632 |
+
# Use DialoGPT-small - lightweight and great for conversational understanding
|
| 633 |
+
model_name = "microsoft/DialoGPT-small"
|
| 634 |
|
| 635 |
# Load tokenizer
|
| 636 |
+
print("π Loading DialoGPT tokenizer...")
|
| 637 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 638 |
if tokenizer.pad_token is None:
|
| 639 |
tokenizer.pad_token = tokenizer.eos_token
|
| 640 |
|
| 641 |
+
# Load model - much lighter than Llama 2
|
| 642 |
+
print("π€ Loading DialoGPT model (optimized for HF Spaces)...")
|
| 643 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 644 |
+
model_name,
|
| 645 |
+
torch_dtype=torch.float16 if device == "cuda" else torch.float32,
|
| 646 |
+
low_cpu_mem_usage=True
|
| 647 |
+
).to(device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 648 |
|
| 649 |
model.eval()
|
| 650 |
+
print(f"β
DialoGPT model loaded successfully on {device}!")
|
| 651 |
|
| 652 |
# Load recipe database
|
| 653 |
load_recipes()
|
| 654 |
|
| 655 |
except Exception as e:
|
| 656 |
+
print(f"β Error loading DialoGPT model: {e}")
|
| 657 |
+
print("Falling back to enhanced rule-based processing...")
|
| 658 |
+
# Don't fail completely - we can still work with enhanced rule-based extraction
|
| 659 |
+
tokenizer = None
|
| 660 |
+
model = None
|
|
|
|
|
|
|
|
|
|
|
|
|
| 661 |
load_recipes()
|
| 662 |
|
| 663 |
# Health check endpoint
|
|
|
|
| 697 |
|
| 698 |
print(f"π₯ Recipe request: {request.ingredients}, prefs: {request.preferences}, time: {request.max_minutes}")
|
| 699 |
|
| 700 |
+
# Use LLM for intelligent feature extraction
|
| 701 |
+
query_features = extract_query_features_with_llm(
|
| 702 |
request.ingredients,
|
| 703 |
request.preferences,
|
| 704 |
request.max_minutes
|
requirements.txt
CHANGED
|
@@ -2,11 +2,9 @@ fastapi==0.104.1
|
|
| 2 |
uvicorn[standard]==0.24.0
|
| 3 |
torch>=2.0.0
|
| 4 |
transformers>=4.35.0
|
| 5 |
-
peft>=0.7.0
|
| 6 |
pydantic>=2.0.0
|
| 7 |
python-multipart==0.0.6
|
| 8 |
huggingface_hub>=0.19.0
|
| 9 |
-
accelerate>=0.24.0
|
| 10 |
safetensors>=0.4.0
|
| 11 |
pandas>=2.0.0
|
| 12 |
scikit-learn>=1.3.0
|
|
|
|
| 2 |
uvicorn[standard]==0.24.0
|
| 3 |
torch>=2.0.0
|
| 4 |
transformers>=4.35.0
|
|
|
|
| 5 |
pydantic>=2.0.0
|
| 6 |
python-multipart==0.0.6
|
| 7 |
huggingface_hub>=0.19.0
|
|
|
|
| 8 |
safetensors>=0.4.0
|
| 9 |
pandas>=2.0.0
|
| 10 |
scikit-learn>=1.3.0
|