vk committed on
Commit
5c5dfcc
·
1 Parent(s): 6e6fe5a

Simplify to intelligent search: remove hardcoded ingredient lists

Browse files

- Keep DialoGPT for query enhancement
- Remove all hardcoded ingredient/cuisine/meal type lists
- Pass full user query directly to TF-IDF search
- Simple but effective boosting for dessert/food word detection
- Much cleaner and more scalable approach
- Should now properly handle 'burger recipes' and 'chocolate dessert'

Files changed (1) hide show
  1. app.py +50 -222
app.py CHANGED
@@ -250,24 +250,27 @@ def load_recipes():
250
 
251
  @torch.inference_mode()
252
  def extract_query_features_with_llm(query_text, preferences="", max_minutes=30):
253
- """Use DialoGPT and enhanced rule-based extraction for intelligent feature parsing"""
254
  global tokenizer, model
255
 
256
- # Always use enhanced rule-based extraction as the foundation
257
- enhanced_features = extract_enhanced_features(query_text, preferences, max_minutes)
258
 
259
- # If model is available, use it to enhance the extraction
 
 
 
 
260
  if model is not None and tokenizer is not None:
261
  try:
262
- # Use DialoGPT conversational understanding to improve extraction
263
- conversation = f"User: I want to cook {query_text} {preferences}".strip()
264
 
265
  inputs = tokenizer.encode(conversation + tokenizer.eos_token, return_tensors="pt").to(device)
266
 
267
  # Generate a response to understand intent
268
  outputs = model.generate(
269
  inputs,
270
- max_new_tokens=50,
271
  temperature=0.7,
272
  top_p=0.9,
273
  do_sample=True,
@@ -277,161 +280,34 @@ def extract_query_features_with_llm(query_text, preferences="", max_minutes=30):
277
 
278
  response = tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True)
279
 
280
- # Extract additional insights from DialoGPT response
281
- llm_insights = extract_insights_from_response(response)
 
 
 
 
 
 
 
 
282
 
283
- # Merge enhanced features with LLM insights
284
- merged_features = merge_feature_sets(enhanced_features, llm_insights)
285
 
286
- print(f"🤖 DialoGPT-enhanced extraction: {merged_features['search_terms'][:8]}")
287
- return merged_features
288
 
289
  except Exception as e:
290
- print(f"⚠️ DialoGPT enhancement failed, using rule-based: {e}")
291
-
292
- print(f"📋 Enhanced rule-based extraction: {enhanced_features['search_terms'][:8]}")
293
- return enhanced_features
294
-
295
- def extract_enhanced_features(query_text, preferences="", max_minutes=30):
296
- """Enhanced rule-based feature extraction optimized for recipe queries"""
297
- query_lower = (query_text + " " + (preferences or "")).lower()
298
-
299
- # Comprehensive ingredient detection
300
- ingredients = detect_ingredients(query_lower)
301
 
302
- # Meal type detection with better patterns
303
- meal_types = detect_meal_types(query_lower)
304
-
305
- # Cuisine detection
306
- cuisines = detect_cuisines(query_lower)
307
-
308
- # Dietary restrictions and preferences
309
- dietary_restrictions = detect_dietary_preferences(query_lower)
310
-
311
- # Cooking styles and methods
312
- cooking_styles = detect_cooking_styles(query_lower)
313
- cooking_methods = detect_cooking_methods(query_lower)
314
-
315
- # Flavor profiles
316
- flavors = detect_flavors(query_lower)
317
 
318
  return {
319
- 'ingredients': ingredients,
320
- 'meal_types': meal_types,
321
- 'cuisines': cuisines,
322
- 'dietary_restrictions': dietary_restrictions,
323
- 'cooking_styles': cooking_styles,
324
- 'cooking_methods': cooking_methods,
325
- 'flavors': flavors,
326
  'max_minutes': max_minutes,
327
- 'search_terms': ingredients + meal_types + cuisines + dietary_restrictions + cooking_styles + cooking_methods + flavors
328
- }
329
-
330
- def detect_ingredients(query_lower):
331
- """Detect ingredients with comprehensive patterns"""
332
- ingredients = []
333
-
334
- # Comprehensive ingredient list including dessert ingredients
335
- ingredient_patterns = {
336
- 'proteins': ['chicken', 'beef', 'pork', 'fish', 'salmon', 'shrimp', 'tofu', 'eggs', 'turkey', 'lamb'],
337
- 'starches': ['rice', 'pasta', 'quinoa', 'bread', 'potatoes', 'noodles', 'flour', 'oats'],
338
- 'vegetables': ['tomatoes', 'onion', 'garlic', 'ginger', 'peppers', 'broccoli', 'spinach', 'carrots', 'mushrooms', 'avocado'],
339
- 'dessert_key': ['chocolate', 'cocoa', 'sugar', 'vanilla', 'caramel', 'honey', 'maple syrup', 'cream', 'butter'],
340
- 'fruits': ['apple', 'banana', 'berries', 'strawberry', 'blueberry', 'lemon', 'lime', 'orange'],
341
- 'dairy': ['cheese', 'milk', 'yogurt', 'cream'],
342
- 'nuts_spices': ['nuts', 'almonds', 'walnuts', 'cinnamon', 'nutmeg', 'herbs', 'basil']
343
- }
344
-
345
- for category, items in ingredient_patterns.items():
346
- for item in items:
347
- if item in query_lower:
348
- ingredients.append(item)
349
- # Special boost for dessert ingredients
350
- if category == 'dessert_key':
351
- ingredients.append(f"sweet_{item}") # Add emphasis for dessert context
352
-
353
- return list(set(ingredients))
354
-
355
- def detect_meal_types(query_lower):
356
- """Enhanced meal type detection with better patterns"""
357
- meal_patterns = {
358
- 'dessert': ['dessert', 'sweet', 'cake', 'cookie', 'pie', 'ice cream', 'pudding', 'tart', 'chocolate', 'candy'],
359
- 'breakfast': ['breakfast', 'morning', 'brunch', 'cereal', 'pancake', 'waffle'],
360
- 'lunch': ['lunch', 'midday', 'sandwich'],
361
- 'dinner': ['dinner', 'supper', 'evening'],
362
- 'snack': ['snack', 'appetizer', 'finger food'],
363
- 'drink': ['drink', 'beverage', 'smoothie', 'juice']
364
- }
365
-
366
- detected = []
367
- for meal_type, keywords in meal_patterns.items():
368
- if any(keyword in query_lower for keyword in keywords):
369
- detected.append(meal_type)
370
-
371
- return detected
372
-
373
- def detect_cuisines(query_lower):
374
- """Detect cuisine types"""
375
- cuisines = ['italian', 'mexican', 'asian', 'chinese', 'thai', 'indian', 'greek', 'french', 'mediterranean', 'american', 'japanese']
376
- return [cuisine for cuisine in cuisines if cuisine in query_lower]
377
-
378
- def detect_dietary_preferences(query_lower):
379
- """Detect dietary restrictions and preferences"""
380
- diets = ['vegetarian', 'vegan', 'healthy', 'low-carb', 'keto', 'gluten-free', 'dairy-free']
381
- return [diet for diet in diets if diet in query_lower]
382
-
383
- def detect_cooking_styles(query_lower):
384
- """Detect cooking styles and preferences"""
385
- styles = ['quick', 'easy', 'fast', 'slow', 'comfort', 'light', 'hearty', 'simple']
386
- return [style for style in styles if style in query_lower]
387
-
388
- def detect_cooking_methods(query_lower):
389
- """Detect cooking methods"""
390
- methods = ['baked', 'fried', 'grilled', 'roasted', 'steamed', 'boiled', 'sauteed']
391
- return [method for method in methods if method in query_lower]
392
-
393
- def detect_flavors(query_lower):
394
- """Detect flavor preferences"""
395
- flavors = ['sweet', 'spicy', 'savory', 'sour', 'creamy', 'crispy']
396
- return [flavor for flavor in flavors if flavor in query_lower]
397
-
398
- def extract_insights_from_response(response_text):
399
- """Extract insights from DialoGPT response"""
400
- response_lower = response_text.lower()
401
-
402
- # Look for food-related words in the response
403
- food_words = []
404
- cooking_words = []
405
-
406
- # Simple extraction from response
407
- food_indicators = ['recipe', 'cook', 'make', 'prepare', 'dish', 'meal', 'food']
408
- for indicator in food_indicators:
409
- if indicator in response_lower:
410
- cooking_words.append(indicator)
411
-
412
- return {
413
- 'ingredients': food_words,
414
- 'cooking_context': cooking_words
415
  }
416
 
417
- def merge_feature_sets(base_features, llm_insights):
418
- """Merge rule-based features with LLM insights"""
419
- # Start with base features
420
- merged = base_features.copy()
421
-
422
- # Add LLM insights if they provide new information
423
- if llm_insights.get('ingredients'):
424
- merged['ingredients'].extend(llm_insights['ingredients'])
425
- merged['ingredients'] = list(set(merged['ingredients'])) # Remove duplicates
426
-
427
- # Rebuild search terms
428
- merged['search_terms'] = (
429
- merged['ingredients'] + merged['meal_types'] + merged['cuisines'] +
430
- merged['dietary_restrictions'] + merged['cooking_styles'] +
431
- merged['cooking_methods'] + merged['flavors']
432
- )
433
-
434
- return merged
435
 
436
  def parse_llm_json_response(response_text):
437
  """Parse LLM's JSON response into structured features"""
@@ -493,7 +369,7 @@ def extract_terms_from_text(text, terms_list):
493
 
494
 
495
  def search_recipes(query_features, top_k=10):
496
- """Enhanced search for recipes matching the LLM-extracted features"""
497
  global recipes_df, vectorizer, recipe_vectors
498
 
499
  if recipes_df is None:
@@ -505,11 +381,11 @@ def search_recipes(query_features, top_k=10):
505
  if len(filtered_df) == 0:
506
  filtered_df = recipes_df.copy() # Fall back to all recipes
507
 
508
- # Create search query from all LLM-extracted terms
509
  search_query = ' '.join(query_features['search_terms'])
510
 
511
  if search_query and vectorizer is not None:
512
- # Semantic search using TF-IDF
513
  query_vector = vectorizer.transform([search_query])
514
 
515
  # Get vectors for the filtered subset by re-indexing
@@ -538,73 +414,25 @@ def search_recipes(query_features, top_k=10):
538
  print(f"⚠️ Similarity length mismatch: {len(similarities)} vs {len(filtered_df)}")
539
  filtered_df['similarity'] = 0.5
540
 
541
- # Apply intelligent boosting based on enhanced features
542
-
543
- # HIGHEST PRIORITY: Meal type matches (especially dessert)
544
- if query_features.get('meal_types'):
545
- for meal_type in query_features['meal_types']:
546
- # Check name, tags, and search text for meal type
547
- mask = (filtered_df['name'].str.lower().str.contains(meal_type, na=False) |
548
- filtered_df['tags_text'].str.contains(meal_type, na=False) |
549
- filtered_df['search_text'].str.contains(meal_type, na=False))
550
- filtered_df.loc[mask, 'similarity'] *= 3.0 # Very high boost
551
-
552
- # Special handling for desserts - comprehensive dessert detection
553
- if meal_type == 'dessert':
554
- dessert_patterns = [
555
- 'chocolate', 'cocoa', 'sugar', 'vanilla', 'cake', 'cookie', 'pie',
556
- 'sweet', 'candy', 'cream', 'frosting', 'icing', 'dessert', 'pudding',
557
- 'brownie', 'tart', 'mousse', 'custard', 'fudge', 'caramel', 'honey'
558
- ]
559
- for pattern in dessert_patterns:
560
- mask = filtered_df['search_text'].str.contains(pattern, na=False)
561
- filtered_df.loc[mask, 'similarity'] *= 2.5 # Strong dessert boost
562
-
563
- # Also check recipe names for dessert indicators
564
- dessert_name_patterns = ['cake', 'cookie', 'brownie', 'pie', 'tart', 'sweet', 'chocolate']
565
- for pattern in dessert_name_patterns:
566
- mask = filtered_df['name'].str.lower().str.contains(pattern, na=False)
567
- filtered_df.loc[mask, 'similarity'] *= 2.8
568
-
569
- # HIGH PRIORITY: Exact ingredient matches
570
- if query_features.get('ingredients'):
571
- for ingredient in query_features['ingredients']:
572
- # Regular ingredient matching
573
- mask = filtered_df['ingredients_text'].str.contains(ingredient.replace('sweet_', ''), na=False)
574
- filtered_df.loc[mask, 'similarity'] *= 2.2
575
-
576
- # Special handling for dessert ingredients with sweet_ prefix
577
- if ingredient.startswith('sweet_'):
578
- base_ingredient = ingredient.replace('sweet_', '')
579
- mask = filtered_df['ingredients_text'].str.contains(base_ingredient, na=False)
580
- # Check if recipe also has dessert context
581
- dessert_context_mask = (
582
- filtered_df['search_text'].str.contains('sweet|dessert|cake|cookie', na=False) |
583
- filtered_df['tags_text'].str.contains('dessert|sweet', na=False)
584
- )
585
- combined_mask = mask & dessert_context_mask
586
- filtered_df.loc[combined_mask, 'similarity'] *= 3.5 # Highest boost for dessert ingredients in dessert context
587
-
588
- # MEDIUM PRIORITY: Flavor matches (sweet, spicy, etc.)
589
- if query_features.get('flavors'):
590
- for flavor in query_features['flavors']:
591
- mask = filtered_df['search_text'].str.contains(flavor, na=False)
592
- multiplier = 2.0 if flavor == 'sweet' else 1.5 # Higher boost for sweet
593
- filtered_df.loc[mask, 'similarity'] *= multiplier
594
-
595
- # LOWER PRIORITY: Cuisine matches
596
- if query_features.get('cuisines'):
597
- for cuisine in query_features['cuisines']:
598
- mask = (filtered_df['tags_text'].str.contains(cuisine, na=False) |
599
- filtered_df['name'].str.lower().str.contains(cuisine, na=False))
600
- filtered_df.loc[mask, 'similarity'] *= 1.4
601
-
602
- # LOWER PRIORITY: Cooking method matches
603
- if query_features.get('cooking_methods'):
604
- for method in query_features['cooking_methods']:
605
- mask = (filtered_df['name'].str.lower().str.contains(method, na=False) |
606
- filtered_df['steps_text'].str.contains(method, na=False))
607
- filtered_df.loc[mask, 'similarity'] *= 1.3
608
 
609
  # Sort by similarity (descending)
610
  filtered_df = filtered_df.sort_values('similarity', ascending=False)
 
250
 
251
  @torch.inference_mode()
252
  def extract_query_features_with_llm(query_text, preferences="", max_minutes=30):
253
+ """Use DialoGPT to enhance query understanding, then pass full query to search"""
254
  global tokenizer, model
255
 
256
+ full_query = f"{query_text} {preferences}".strip()
 
257
 
258
+ # Start with the original query as our search terms
259
+ base_search_terms = [full_query]
260
+
261
+ # If DialoGPT is available, use it to enhance understanding
262
+ enhanced_terms = []
263
  if model is not None and tokenizer is not None:
264
  try:
265
+ # Use DialoGPT to understand context and intent
266
+ conversation = f"User: I want to cook {full_query}".strip()
267
 
268
  inputs = tokenizer.encode(conversation + tokenizer.eos_token, return_tensors="pt").to(device)
269
 
270
  # Generate a response to understand intent
271
  outputs = model.generate(
272
  inputs,
273
+ max_new_tokens=30,
274
  temperature=0.7,
275
  top_p=0.9,
276
  do_sample=True,
 
280
 
281
  response = tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True)
282
 
283
+ # Extract key food-related words from DialoGPT response
284
+ response_lower = response.lower()
285
+ food_keywords = []
286
+
287
+ # Look for food-related words in the response
288
+ food_indicators = ['recipe', 'cook', 'make', 'dish', 'meal', 'food', 'ingredient', 'cuisine']
289
+ for word in response.split():
290
+ word_clean = word.lower().strip('.,!?')
291
+ if word_clean in food_indicators or len(word_clean) > 3: # Capture potential food words
292
+ food_keywords.append(word_clean)
293
 
294
+ enhanced_terms = food_keywords[:5] # Limit to top 5 terms
 
295
 
296
+ print(f"🤖 DialoGPT enhanced with: {enhanced_terms}")
 
297
 
298
  except Exception as e:
299
+ print(f"⚠️ DialoGPT enhancement failed: {e}")
 
 
 
 
 
 
 
 
 
 
300
 
301
+ # Combine original query with enhanced terms
302
+ all_search_terms = base_search_terms + enhanced_terms
 
 
 
 
 
 
 
 
 
 
 
 
 
303
 
304
  return {
305
+ 'original_query': full_query,
306
+ 'search_terms': all_search_terms,
 
 
 
 
 
307
  'max_minutes': max_minutes,
308
+ 'enhanced_by_llm': len(enhanced_terms) > 0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
309
  }
310
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
311
 
312
  def parse_llm_json_response(response_text):
313
  """Parse LLM's JSON response into structured features"""
 
369
 
370
 
371
  def search_recipes(query_features, top_k=10):
372
+ """Simplified intelligent search using full query + DialoGPT enhancement"""
373
  global recipes_df, vectorizer, recipe_vectors
374
 
375
  if recipes_df is None:
 
381
  if len(filtered_df) == 0:
382
  filtered_df = recipes_df.copy() # Fall back to all recipes
383
 
384
+ # Create search query from all terms (original query + DialoGPT enhancements)
385
  search_query = ' '.join(query_features['search_terms'])
386
 
387
  if search_query and vectorizer is not None:
388
+ # Semantic search using TF-IDF on the full query
389
  query_vector = vectorizer.transform([search_query])
390
 
391
  # Get vectors for the filtered subset by re-indexing
 
414
  print(f"⚠️ Similarity length mismatch: {len(similarities)} vs {len(filtered_df)}")
415
  filtered_df['similarity'] = 0.5
416
 
417
+ # Simple boosting based on query content detection
418
+ original_query = query_features.get('original_query', '').lower()
419
+
420
+ # Boost for dessert-related queries
421
+ if any(word in original_query for word in ['dessert', 'sweet', 'chocolate', 'cake', 'cookie']):
422
+ dessert_patterns = ['chocolate', 'cake', 'cookie', 'dessert', 'sweet', 'brownie', 'pie']
423
+ for pattern in dessert_patterns:
424
+ mask = (filtered_df['name'].str.lower().str.contains(pattern, na=False) |
425
+ filtered_df['search_text'].str.contains(pattern, na=False))
426
+ filtered_df.loc[mask, 'similarity'] *= 2.0
427
+
428
+ # Boost for specific food mentions (burger, pasta, etc.)
429
+ food_words = [word for word in original_query.split() if len(word) > 3]
430
+ for word in food_words:
431
+ if word not in ['want', 'like', 'something', 'recipes', 'recipe']:
432
+ mask = (filtered_df['name'].str.lower().str.contains(word, na=False) |
433
+ filtered_df['ingredients_text'].str.contains(word, na=False) |
434
+ filtered_df['search_text'].str.contains(word, na=False))
435
+ filtered_df.loc[mask, 'similarity'] *= 1.5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
436
 
437
  # Sort by similarity (descending)
438
  filtered_df = filtered_df.sort_values('similarity', ascending=False)