Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from datasets import load_dataset | |
| from sentence_transformers import SentenceTransformer, util | |
| import faiss | |
| import numpy as np | |
| from transformers import pipeline | |
| import time | |
# --- 1. DATA LOADING AND PREPROCESSING ---
print("===== Application Startup =====")
start_time = time.time()  # reference point for the "loaded in N seconds" log

# Only the first RECIPE_LIMIT rows of the training split are loaded.
# NOTE(review): the prebuilt FAISS index loaded later is presumably built over
# exactly these rows (row i <-> vector i) — confirm.
RECIPE_LIMIT = 20000
dataset = load_dataset("corbt/all-recipes", split=f"train[:{RECIPE_LIMIT}]")
# Preprocessing helpers that split the raw recipe text into structured fields.
_SECTION_MARKERS = ("Ingredients:", "Directions:")


def _extract_section(text, marker):
    """Return the stripped text following the first ``marker``, or None if absent.

    The section stops at the next section marker (``Ingredients:`` or
    ``Directions:``), so text belonging to a later section is not included;
    if no further marker exists it runs to the end of ``text``.
    """
    start = text.find(marker)
    if start == -1:
        return None
    body = text[start + len(marker):]
    stops = [pos for pos in (body.find(m) for m in _SECTION_MARKERS) if pos != -1]
    return body[:min(stops, default=len(body))].strip()


def extract_each_feature(sample):
    """Split one raw recipe entry into title, ingredients and directions.

    Args:
        sample: A dataset row; its ``'input'`` field holds the raw recipe text.

    Returns:
        A dict with ``title`` (the first line of the text), ``ingredients`` and
        ``directions``. A section whose marker does not appear in the text is
        reported as ``"Not available"``.
    """
    full_text = sample['input']
    # partition() yields the whole text when there is no newline, whereas
    # slicing up to find("\n") == -1 would silently drop the last character.
    title = full_text.partition("\n")[0].strip()
    # Each section is located independently, so marker order does not matter
    # and ingredients are kept even when the Directions marker is missing.
    ingredients = _extract_section(full_text, "Ingredients:")
    directions = _extract_section(full_text, "Directions:")
    return {
        "title": title,
        "ingredients": "Not available" if ingredients is None else ingredients,
        "directions": "Not available" if directions is None else directions,
    }
# Apply preprocessing: each row gains "title", "ingredients" and "directions" fields.
dataset = dataset.map(extract_each_feature)

# --- 2. EMBEDDING AND RECOMMENDATION ENGINE ---
print("Loading embedding model...")
model_name = "all-MiniLM-L6-v2"
embedding_model = SentenceTransformer(f"sentence-transformers/{model_name}")

# The FAISS index is read from disk rather than rebuilt at startup.
# NOTE(review): vector i of the index is presumed to correspond to row i of
# `dataset` — confirm against the script that built recipe_index.faiss.
index_file = "recipe_index.faiss"
print(f"Loading FAISS index from {index_file}...")
index = faiss.read_index(index_file)
print(f"Index is ready. Total vectors in index: {index.ntotal}")
if index.ntotal != len(dataset):
    # A size mismatch means search hits can point past the loaded rows or at
    # the wrong recipe; surface it at startup instead of failing per query.
    print(
        f"WARNING: index holds {index.ntotal} vectors but the dataset has "
        f"{len(dataset)} rows; recommendations may be misaligned."
    )

# --- 3. SYNTHETIC GENERATION ---
print("Loading generative model...")
generator = pipeline('text-generation', model='gpt2')
def get_recommendations_and_generate(query_ingredients, k=3):
    """Recommend dataset recipes for the query and generate a new recipe idea.

    Args:
        query_ingredients: Free-text ingredient list typed by the user.
        k: Number of nearest neighbours requested from the FAISS index.
            Exactly three recommendation slots are returned regardless.

    Returns:
        A 4-tuple ``(rec1, rec2, rec3, generated_text)``. Each ``rec`` is a dict
        with ``title``/``ingredients``/``directions`` keys; slots without a
        valid match have an empty ``title``. ``generated_text`` is the GPT-2
        candidate whose embedding is most cosine-similar to the query.
    """
    empty_recipe = {"title": "", "ingredients": "", "directions": ""}

    # Nothing meaningful to search for or generate from a blank query.
    if not (query_ingredients or "").strip():
        return (dict(empty_recipe), dict(empty_recipe), dict(empty_recipe),
                "Please enter at least one ingredient.")

    # 1. Retrieve the nearest recipes from the prebuilt FAISS index.
    query_vector = np.asarray(embedding_model.encode([query_ingredients]), dtype=np.float32)
    _, indices = index.search(query_vector, k)
    n_rows = len(dataset)
    results = []
    for raw_idx in indices[0]:
        idx = int(raw_idx)
        # FAISS pads with -1 when fewer than k neighbours exist, and an index
        # larger than the loaded dataset slice could return out-of-range ids;
        # -1 would otherwise silently select the last row.
        if not 0 <= idx < n_rows:
            continue
        row = dataset[idx]  # fetch the row once instead of once per field
        results.append({
            "title": row['title'],
            "ingredients": row['ingredients'],
            "directions": row['directions'],
        })
    while len(results) < 3:
        results.append(dict(empty_recipe))

    # 2. Sample several candidate recipes from GPT-2.
    prompt = f"Write a complete recipe that includes a title, a list of ingredients, and step-by-step directions. The recipe must use the following ingredients: {query_ingredients}."
    generated_outputs = generator(
        prompt,
        max_new_tokens=180,  # kept small to bound generation time
        num_return_sequences=10,
        do_sample=True,  # multiple distinct sequences require sampling
        return_full_text=False,  # return only the continuation, not the prompt
        pad_token_id=generator.tokenizer.eos_token_id,  # GPT-2 has no pad token
    )
    generated_texts = [output['generated_text'].strip() for output in generated_outputs]

    # 3. Keep the candidate that is most semantically similar to the query.
    generated_embeddings = embedding_model.encode(generated_texts)
    similarities = util.cos_sim(query_vector, generated_embeddings)[0]
    best_generated_recipe = generated_texts[int(similarities.argmax())]

    return results[0], results[1], results[2], best_generated_recipe
# --- 4. GRADIO USER INTERFACE ---
def format_recipe(recipe):
    """Render a recommended-recipe dict as a Markdown card.

    A falsy ``recipe`` or a falsy ``title`` produces a "not found" heading;
    otherwise the card shows the title, ingredients and directions.
    """
    if recipe and recipe['title']:
        heading = f"### {recipe['title']}"
        ingredients_part = f"**Ingredients:**\n{recipe['ingredients']}"
        directions_part = f"**Directions:**\n{recipe['directions']}"
        return heading + "\n" + ingredients_part + "\n\n" + directions_part
    return "### No recipe found."
def format_generated_recipe(recipe_text):
    """Prepare the AI-generated recipe for the plain-text output box.

    The text is shown without Markdown, with surrounding whitespace trimmed.
    An empty or missing generation is replaced by a short notice so the user
    does not see a blank box.
    """
    text = (recipe_text or "").strip()
    if not text:
        return "The model did not produce a recipe this time. Please try again."
    return text
def recipe_wizard(ingredients):
    """Gradio callback: turn the user's ingredient list into four display strings.

    Returns three Markdown cards for the recommended recipes followed by the
    formatted AI-generated recipe, in the order of the UI's output components.
    """
    first, second, third, generated = get_recommendations_and_generate(ingredients)
    cards = tuple(format_recipe(candidate) for candidate in (first, second, third))
    return cards + (format_generated_recipe(generated),)
end_time = time.time()
print(f"Models and data loaded in {end_time - start_time:.2f} seconds.")

# Gradio Interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🍳 RecipeWizard AI")
    gr.Markdown("Enter the ingredients you have, and get recipe recommendations plus a new AI-generated idea!")
    with gr.Row():
        ingredient_input = gr.Textbox(label="Your Ingredients", placeholder="e.g., chicken, rice, tomatoes, garlic")
        submit_btn = gr.Button("Get Recipes")
    with gr.Row():
        with gr.Column(scale=2):
            gr.Markdown("### Recommended Recipes")
            output_rec1 = gr.Markdown()
            output_rec2 = gr.Markdown()
            output_rec3 = gr.Markdown()
        with gr.Column(scale=1):
            gr.Markdown("### ✨ AI-Generated Idea")
            # A plain Textbox rather than Markdown: the generated text is shown
            # verbatim, without Markdown rendering.
            output_gen = gr.Textbox(label="AI Generated Recipe", lines=15)

    # Order must match the 4-tuple returned by recipe_wizard.
    output_components = [output_rec1, output_rec2, output_rec3, output_gen]
    submit_btn.click(
        fn=recipe_wizard,
        inputs=ingredient_input,
        outputs=output_components,
    )
    # Pressing Enter in the ingredient box runs the same handler as the button.
    ingredient_input.submit(
        fn=recipe_wizard,
        inputs=ingredient_input,
        outputs=output_components,
    )
    gr.Examples(
        examples=[
            ["chicken, broccoli, cheese"],
            ["ground beef, potatoes, onions"],
            ["flour, sugar, eggs, butter"],
        ],
        inputs=ingredient_input,
    )

demo.launch()