import gradio as gr from transformers import FlaxAutoModelForSeq2SeqLM, AutoTokenizer, AutoModel import torch import numpy as np import random import json from fastapi import FastAPI, Request from fastapi.responses import JSONResponse from pydantic import BaseModel # Lade RecipeBERT Modell (für semantische Zutat-Kombination) bert_model_name = "alexdseo/RecipeBERT" bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name) bert_model = AutoModel.from_pretrained(bert_model_name) bert_model.eval() # Setze das Modell in den Evaluationsmodus # Lade T5 Rezeptgenerierungsmodell MODEL_NAME_OR_PATH = "flax-community/t5-recipe-generation" t5_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH, use_fast=True) t5_model = FlaxAutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME_OR_PATH) # Token Mapping für die T5 Modell-Ausgabe special_tokens = t5_tokenizer.all_special_tokens tokens_map = { "": "--", "
": "\n" } def get_embedding(text): """Berechnet das Embedding für einen Text mit Mean Pooling über alle Tokens""" inputs = bert_tokenizer(text, return_tensors="pt", truncation=True, padding=True) with torch.no_grad(): outputs = bert_model(**inputs) # Mean Pooling - Mittelwert aller Token-Embeddings attention_mask = inputs['attention_mask'] token_embeddings = outputs.last_hidden_state input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1) sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9) return (sum_embeddings / sum_mask).squeeze(0) def average_embedding(embedding_list): """Berechnet den Durchschnitt einer Liste von Embeddings""" tensors = torch.stack([emb for _, emb in embedding_list]) return tensors.mean(dim=0) def get_cosine_similarity(vec1, vec2): """Berechnet die Cosinus-Ähnlichkeit zwischen zwei Vektoren""" if torch.is_tensor(vec1): vec1 = vec1.detach().numpy() if torch.is_tensor(vec2): vec2 = vec2.detach().numpy() # Stelle sicher, dass die Vektoren die richtige Form haben (flachen sie bei Bedarf ab) vec1 = vec1.flatten() vec2 = vec2.flatten() dot_product = np.dot(vec1, vec2) norm_a = np.linalg.norm(vec1) norm_b = np.linalg.norm(vec2) # Division durch Null vermeiden if norm_a == 0 or norm_b == 0: return 0 return dot_product / (norm_a * norm_b) def get_combined_scores(query_vector, embedding_list, all_good_embeddings, avg_weight=0.6): """Berechnet einen kombinierten Score unter Berücksichtigung der Ähnlichkeit zum Durchschnitt und zu einzelnen Zutaten""" results = [] for name, emb in embedding_list: # Ähnlichkeit zum Durchschnittsvektor avg_similarity = get_cosine_similarity(query_vector, emb) # Durchschnittliche Ähnlichkeit zu einzelnen Zutaten individual_similarities = [get_cosine_similarity(good_emb, emb) for _, good_emb in all_good_embeddings] avg_individual_similarity = sum(individual_similarities) / len(individual_similarities) # Kombinierter Score (gewichteter Durchschnitt) combined_score = avg_weight * avg_similarity + (1 - avg_weight) * avg_individual_similarity results.append((name, emb, combined_score)) # Sortiere nach kombiniertem Score (absteigend) results.sort(key=lambda x: x[2], reverse=True) return results def find_best_ingredients(required_ingredients, available_ingredients, max_ingredients=6, avg_weight=0.6): """ Findet die besten Zutaten basierend auf RecipeBERT Embeddings. """ # Stelle sicher, dass keine Duplikate in den Listen sind required_ingredients = list(set(required_ingredients)) available_ingredients = list(set([i for i in available_ingredients if i not in required_ingredients])) # Sonderfall: Wenn keine benötigten Zutaten vorhanden sind, wähle zufällig eine aus den verfügbaren Zutaten if not required_ingredients and available_ingredients: random_ingredient = random.choice(available_ingredients) required_ingredients = [random_ingredient] available_ingredients = [i for i in available_ingredients if i != random_ingredient] # print(f"Keine benötigten Zutaten angegeben. Zufällig ausgewählt: {random_ingredient}") # Wenn immer noch keine Zutaten vorhanden oder bereits maximale Kapazität erreicht ist if not required_ingredients or len(required_ingredients) >= max_ingredients: return required_ingredients[:max_ingredients] # Wenn keine zusätzlichen Zutaten verfügbar sind if not available_ingredients: return required_ingredients # Berechne Embeddings für alle Zutaten embed_required = [(e, get_embedding(e)) for e in required_ingredients] embed_available = [(e, get_embedding(e)) for e in available_ingredients] # Anzahl der hinzuzufügenden Zutaten num_to_add = min(max_ingredients - len(required_ingredients), len(available_ingredients)) # Kopiere benötigte Zutaten in die endgültige Liste final_ingredients = embed_required.copy() # Füge die besten Zutaten hinzu for _ in range(num_to_add): # Berechne den Durchschnittsvektor der aktuellen Kombination avg = average_embedding(final_ingredients) # Berechne kombinierte Scores für alle Kandidaten candidates = get_combined_scores(avg, embed_available, final_ingredients, avg_weight) # Wenn keine Kandidaten mehr übrig sind, breche ab if not candidates: break # Wähle die beste Zutat best_name, best_embedding, _ = candidates[0] # Füge die beste Zutat zur endgültigen Liste hinzu final_ingredients.append((best_name, best_embedding)) # Entferne die Zutat aus den verfügbaren Zutaten embed_available = [item for item in embed_available if item[0] != best_name] # Extrahiere nur die Zutatennamen return [name for name, _ in final_ingredients] def skip_special_tokens(text, special_tokens): """Entfernt spezielle Tokens aus dem Text""" for token in special_tokens: text = text.replace(token, "") return text def target_postprocessing(texts, special_tokens): """Post-processed generierten Text""" if not isinstance(texts, list): texts = [texts] new_texts = [] for text in texts: text = skip_special_tokens(text, special_tokens) for k, v in tokens_map.items(): text = text.replace(k, v) new_texts.append(text) return new_texts def validate_recipe_ingredients(recipe_ingredients, expected_ingredients, tolerance=0): """ Validiert, ob das Rezept ungefähr die erwarteten Zutaten enthält. """ recipe_count = len([ing for ing in recipe_ingredients if ing and ing.strip()]) expected_count = len(expected_ingredients) return abs(recipe_count - expected_count) == tolerance def generate_recipe_with_t5(ingredients_list, max_retries=5): """Generiert ein Rezept mit dem T5 Rezeptgenerierungsmodell mit Validierung.""" original_ingredients = ingredients_list.copy() for attempt in range(max_retries): try: # Für Wiederholungsversuche nach dem ersten Versuch, mische die Zutaten if attempt > 0: current_ingredients = original_ingredients.copy() random.shuffle(current_ingredients) else: current_ingredients = ingredients_list # Formatiere Zutaten als kommaseparierten String ingredients_string = ", ".join(current_ingredients) prefix = "items: " # Generationseinstellungen generation_kwargs = { "max_length": 512, "min_length": 64, "do_sample": True, "top_k": 60, "top_p": 0.95 } # print(f"Versuch {attempt + 1}: {prefix + ingredients_string}") # Tokenisiere Eingabe inputs = t5_tokenizer( prefix + ingredients_string, max_length=256, padding="max_length", truncation=True, return_tensors="jax" ) # Generiere Text output_ids = t5_model.generate( input_ids=inputs.input_ids, attention_mask=inputs.attention_mask, **generation_kwargs ) # Dekodieren und Nachbearbeiten generated = output_ids.sequences generated_text = target_postprocessing( t5_tokenizer.batch_decode(generated, skip_special_tokens=False), special_tokens )[0] # Abschnitte parsen recipe = {} sections = generated_text.split("\n") for section in sections: section = section.strip() if section.startswith("title:"): recipe["title"] = section.replace("title:", "").strip().capitalize() elif section.startswith("ingredients:"): ingredients_text = section.replace("ingredients:", "").strip() recipe["ingredients"] = [item.strip().capitalize() for item in ingredients_text.split("--") if item.strip()] elif section.startswith("directions:"): directions_text = section.replace("directions:", "").strip() recipe["directions"] = [step.strip().capitalize() for step in directions_text.split("--") if step.strip()] # Wenn der Titel fehlt, erstelle einen if "title" not in recipe: recipe["title"] = f"Rezept mit {', '.join(current_ingredients[:3])}" # Stelle sicher, dass alle Abschnitte existieren if "ingredients" not in recipe: recipe["ingredients"] = current_ingredients if "directions" not in recipe: recipe["directions"] = ["Keine Anweisungen generiert"] # Validiere das Rezept if validate_recipe_ingredients(recipe["ingredients"], original_ingredients): # print(f"Erfolg bei Versuch {attempt + 1}: Rezept hat die richtige Anzahl von Zutaten") return recipe else: # print(f"Versuch {attempt + 1} fehlgeschlagen: Erwartet {len(original_ingredients)} Zutaten, erhalten {len(recipe['ingredients'])}") if attempt == max_retries - 1: # print("Maximale Wiederholungsversuche erreicht, letztes generiertes Rezept wird zurückgegeben") return recipe except Exception as e: # print(f"Fehler bei der Rezeptgenerierung Versuch {attempt + 1}: {str(e)}") if attempt == max_retries - 1: return { "title": f"Rezept mit {original_ingredients[0] if original_ingredients else 'Zutaten'}", "ingredients": original_ingredients, "directions": ["Fehler beim Generieren der Rezeptanweisungen"] } # Fallback (sollte nicht erreicht werden) return { "title": f"Rezept mit {original_ingredients[0] if original_ingredients else 'Zutaten'}", "ingredients": original_ingredients, "directions": ["Fehler beim Generieren der Rezeptanweisungen"] } # Diese Funktion wird von der Gradio-UI und der neuen FastAPI-Route aufgerufen. def process_recipe_request_logic(required_ingredients, available_ingredients, max_ingredients, max_retries): """ Kernlogik zur Verarbeitung einer Rezeptgenerierungsanfrage. Ausgelagert, um von verschiedenen Endpunkten aufgerufen zu werden. """ if not required_ingredients and not available_ingredients: return {"error": "Keine Zutaten angegeben"} try: # Optimale Zutaten finden optimized_ingredients = find_best_ingredients( required_ingredients, available_ingredients, max_ingredients ) # Rezept mit optimierten Zutaten generieren recipe = generate_recipe_with_t5(optimized_ingredients, max_retries) # Ergebnis formatieren result = { 'title': recipe['title'], 'ingredients': recipe['ingredients'], 'directions': recipe['directions'], 'used_ingredients': optimized_ingredients } return result except Exception as e: return {"error": f"Fehler bei der Rezeptgenerierung: {str(e)}"} def flutter_api_generate_recipe(ingredients_data): """ Flutter-freundliche API-Funktion für den Gradio-API-Test-Tab. Verarbeitet JSON-String-Eingabe und gibt JSON-String-Ausgabe zurück. """ try: data = json.loads(ingredients_data) # Muss ein JSON-String sein required_ingredients = data.get('required_ingredients', []) available_ingredients = data.get('available_ingredients', []) max_ingredients = data.get('max_ingredients', 7) max_retries = data.get('max_retries', 5) # Rufe die Kernlogik auf result_dict = process_recipe_request_logic( required_ingredients, available_ingredients, max_ingredients, max_retries ) return json.dumps(result_dict) # Gibt einen JSON-String zurück except Exception as e: return json.dumps({"error": f"Fehler bei der Rezeptgenerierung: {str(e)}"}) def gradio_ui_generate_recipe(required_ingredients_text, available_ingredients_text, max_ingredients_val, max_retries_val): """Gradio UI Funktion für die Web-Oberfläche""" try: required_ingredients = [ing.strip() for ing in required_ingredients_text.split(',') if ing.strip()] available_ingredients = [ing.strip() for ing in available_ingredients_text.split(',') if ing.strip()] # Rufe die Kernlogik auf result = process_recipe_request_logic( required_ingredients, available_ingredients, max_ingredients_val, max_retries_val ) if 'error' in result: return result['error'], "", "", "" ingredients_list = '\n'.join([f"• {ing}" for ing in result['ingredients']]) directions_list = '\n'.join([f"{i+1}. {dir}" for i, dir in enumerate(result['directions'])]) used_ingredients = ', '.join(result['used_ingredients']) return ( result['title'], ingredients_list, directions_list, used_ingredients ) except Exception as e: return f"Fehler: {str(e)}", "", "", "" # Erstelle die Gradio Oberfläche with gr.Blocks(title="AI Rezept Generator") as demo: gr.Markdown("# 🍳 AI Rezept Generator") gr.Markdown("Generiere Rezepte mit KI und intelligenter Zutat-Kombination!") with gr.Tab("Web-Oberfläche"): with gr.Row(): with gr.Column(): required_ing = gr.Textbox( label="Benötigte Zutaten (kommasepariert)", placeholder="Hähnchen, Reis, Zwiebel", lines=2 ) available_ing = gr.Textbox( label="Verfügbare Zutaten (kommasepariert, optional)", placeholder="Knoblauch, Tomate, Pfeffer, Kräuter", lines=2 ) max_ing = gr.Slider(3, 10, value=7, step=1, label="Maximale Zutaten") max_retries = gr.Slider(1, 10, value=5, step=1, label="Max. Wiederholungsversuche") generate_btn = gr.Button("Rezept generieren", variant="primary") with gr.Column(): title_output = gr.Textbox(label="Rezepttitel", interactive=False) ingredients_output = gr.Textbox(label="Zutaten", lines=8, interactive=False) directions_output = gr.Textbox(label="Anweisungen", lines=10, interactive=False) used_ingredients_output = gr.Textbox(label="Verwendete Zutaten", interactive=False) generate_btn.click( fn=gradio_ui_generate_recipe, inputs=[required_ing, available_ing, max_ing, max_retries], outputs=[title_output, ingredients_output, directions_output, used_ingredients_output] ) with gr.Tab("API-Test"): gr.Markdown("### Teste die Gradio API (für interne Tests)") gr.Markdown("Dieser Tab verwendet die interne Gradio API für Testzwecke.") api_input = gr.Textbox( label="JSON-Eingabe (Flutter API-Format)", placeholder='{"required_ingredients": ["chicken", "rice"], "available_ingredients": ["onion", "garlic"], "max_ingredients": 6}', lines=4 ) api_output = gr.Textbox(label="JSON-Ausgabe", lines=15, interactive=False) api_test_btn = gr.Button("API testen", variant="secondary") # Hier wird die Funktion weiterhin für den Gradio-eigenen API-Test-Tab verwendet. api_test_btn.click( fn=flutter_api_generate_recipe, inputs=[api_input], outputs=[api_output], api_name="generate_recipe_for_flutter" # Gradio-interner API-Name ) gr.Examples( examples=[ ['{"required_ingredients": ["chicken", "rice"], "available_ingredients": ["onion", "garlic", "tomato"], "max_ingredients": 6}'], ['{"ingredients": ["pasta"], "available_ingredients": ["cheese", "mushrooms", "cream"], "max_ingredients": 5}'] ], inputs=[api_input] ) # --- FastAPI-Integration --- app = FastAPI() class RecipeRequest(BaseModel): required_ingredients: list[str] = [] available_ingredients: list[str] = [] max_ingredients: int = 7 max_retries: int = 5 @app.post("/api/generate_recipe_rest") async def generate_recipe_rest_api(request_data: RecipeRequest): """ Standard-REST-API-Endpunkt für die Flutter-App. Nimmt direkt JSON-Daten an und gibt direkt JSON zurück. """ required_ingredients = request_data.required_ingredients available_ingredients = request_data.available_ingredients max_ingredients = request_data.max_ingredients max_retries = request_data.max_retries # Abwärtskompatibilität, falls 'ingredients' statt 'required_ingredients' gesendet wird # Dies ist in der FastAPI-Pydantic-Modelldefinition nicht direkt abbildbar, # aber du könntest es manuell hinzufügen, falls nötig, wenn das Pydantic-Modell flexibler wäre. # Für den Einfachheit halber gehen wir davon aus, dass Flutter die korrekten Felder sendet. result_dict = process_recipe_request_logic( required_ingredients, available_ingredients, max_ingredients, max_retries ) return JSONResponse(content=result_dict) # Gradio-App als Sub-App in die FastAPI-App mounten # Dies ist der Standardweg, um Gradio in eine FastAPI-Anwendung einzubetten. # Der Gradio-Teil wird dann unter dem Wurzelpfad '/'. app = gr.mount_gradio_app(app, demo, path="/") # Gradio unter dem Wurzelpfad mounten # Wenn du deine App lokal ausführst, kannst du FastAPI mit Uvicorn starten: # if __name__ == "__main__": # import uvicorn # uvicorn.run(app, host="0.0.0.0", port=8000) # Für Hugging Face Spaces ist der if __name__ == "__main__": Block nicht nötig, # da Spaces Uvicorn automatisch startet und die "app"-Variable sucht.