# NOTE: HuggingFace Spaces page header ("Spaces: Sleeping") captured by the
# scrape — not part of the program source.
import json
import random
from datetime import datetime, timedelta, timezone  # timezone: for aware "now" in date math

import numpy as np
import torch
from fastapi import FastAPI
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from transformers import FlaxAutoModelForSeq2SeqLM, AutoTokenizer, AutoModel
# --- Model loading (module level, runs once at import; downloads weights on first run) ---

# RecipeBERT: provides embeddings for semantic ingredient-combination scoring.
bert_model_name = "alexdseo/RecipeBERT"
bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
bert_model = AutoModel.from_pretrained(bert_model_name)
bert_model.eval()  # inference only — disables dropout etc.

# T5 recipe-generation model (Flax/JAX weights).
MODEL_NAME_OR_PATH = "flax-community/t5-recipe-generation"
t5_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH, use_fast=True)
t5_model = FlaxAutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME_OR_PATH)

# Token mapping for the T5 model's output: structural tokens -> plain separators.
special_tokens = t5_tokenizer.all_special_tokens
tokens_map = {
    "<sep>": "--",
    "<section>": "\n"
}
| # --- RecipeBERT-spezifische Funktionen --- | |
def get_embedding(text):
    """Return a mean-pooled RecipeBERT embedding for *text*.

    The mean is taken over all tokens, weighted by the attention mask so
    padding positions do not contribute; the result is a single 1-D tensor.
    """
    encoded = bert_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        model_out = bert_model(**encoded)

    hidden = model_out.last_hidden_state
    mask = encoded['attention_mask'].unsqueeze(-1).expand(hidden.size()).float()
    summed = (hidden * mask).sum(dim=1)
    # Clamp avoids division by zero for fully-masked (degenerate) inputs.
    counts = mask.sum(dim=1).clamp(min=1e-9)
    return (summed / counts).squeeze(0)
def average_embedding(embedding_list):
    """Return the element-wise mean of embeddings in a list of
    (name, embedding) tuples."""
    stacked = torch.stack([vec for _name, vec in embedding_list])
    return stacked.mean(dim=0)
def get_cosine_similarity(vec1, vec2):
    """Cosine similarity between two vectors (torch tensors or numpy arrays).

    Inputs are flattened first; returns 0 when either vector has zero norm.
    """
    v1 = vec1.detach().numpy() if torch.is_tensor(vec1) else vec1
    v2 = vec2.detach().numpy() if torch.is_tensor(vec2) else vec2
    v1, v2 = v1.flatten(), v2.flatten()

    # Norms are non-negative, so the product is zero iff either norm is zero.
    denom = np.linalg.norm(v1) * np.linalg.norm(v2)
    if denom == 0:
        return 0
    return np.dot(v1, v2) / denom
| # NEUE FUNKTION: Berechnet den Altersbonus für eine Zutat | |
def calculate_age_bonus(date_added_str: str, category: str) -> float:
    """Return a fractional score bonus (0.0–0.10) based on ingredient age.

    - Default: 0.5% per day, capped at 10%.
    - Vegetables: 2.0% per day, capped at 10%.

    Args:
        date_added_str: ISO-8601 timestamp; a trailing 'Z' (UTC) is accepted.
        category: ingredient category; "vegetables" (case-insensitive) uses
            the faster rate.

    Returns:
        Bonus in [0.0, 0.10]; 0.0 for unparsable or future dates.
    """
    try:
        date_added = datetime.fromisoformat(date_added_str.replace('Z', '+00:00'))
    except ValueError:
        print(f"Warning: Could not parse date_added_str: {date_added_str}. Returning 0 bonus.")
        return 0.0

    # BUG FIX: subtracting an aware datetime (any '...Z' / offset timestamp)
    # from a naive datetime.now() raises TypeError. Use a "now" whose
    # awareness matches the parsed value.
    if date_added.tzinfo is not None:
        today = datetime.now(timezone.utc)
    else:
        today = datetime.now()

    days_since_added = (today - date_added).days
    if days_since_added < 0:  # date in the future — treat as invalid
        return 0.0

    daily_bonus = 0.02 if (category and category.lower() == "vegetables") else 0.005
    return min(days_since_added * daily_bonus, 0.10)  # cap at 10% (0.10)
def get_combined_scores(query_vector, embedding_list_with_details, all_good_embeddings, avg_weight=0.6):
    """Score each candidate ingredient and return them sorted best-first.

    The score blends (a) similarity to the average embedding of the already
    selected ingredients and (b) the mean similarity to each selected
    ingredient individually, weighted by *avg_weight*, plus an age bonus.

    Args:
        query_vector: average embedding of the current selection.
        embedding_list_with_details: tuples (name, embedding, date_added_str, category).
        all_good_embeddings: tuples (name, embedding) for selected ingredients.
        avg_weight: weight of the average-similarity term.

    Returns:
        List of (name, embedding, score, date_added_str, category) sorted by
        score, descending.
    """
    scored = []
    for name, emb, date_added_str, category in embedding_list_with_details:
        sim_to_avg = get_cosine_similarity(query_vector, emb)

        per_item = [get_cosine_similarity(chosen_emb, emb)
                    for _, chosen_emb in all_good_embeddings]
        mean_per_item = sum(per_item) / len(per_item) if per_item else 0

        score = avg_weight * sim_to_avg + (1 - avg_weight) * mean_per_item
        # Older ingredients get a small additive boost so they are used up.
        score += calculate_age_bonus(date_added_str, category)

        # Keep the details for debugging or future use.
        scored.append((name, emb, score, date_added_str, category))

    return sorted(scored, key=lambda entry: entry[2], reverse=True)
def find_best_ingredients(required_ingredients_names, available_ingredients_details, max_ingredients=6, avg_weight=0.6):
    """
    Greedily selects the best ingredient set using RecipeBERT embeddings,
    including the age/category bonus.

    required_ingredients_names: list of strings (names only)
    available_ingredients_details: list of dicts with keys 'name',
        'dateAdded', 'category'
    max_ingredients: upper bound on the size of the returned list
    avg_weight: forwarded to get_combined_scores

    Returns a list of ingredient names (required ones first).
    """
    required_ingredients_names = list(set(required_ingredients_names))
    # Keep only available ingredients that are not already required,
    # preserving their details.
    available_ingredients_filtered_details = [
        item for item in available_ingredients_details
        if item['name'] not in required_ingredients_names
    ]
    # If no required ingredients were given but some are available,
    # promote one at random to "required".
    if not required_ingredients_names and available_ingredients_filtered_details:
        random_item = random.choice(available_ingredients_filtered_details)
        required_ingredients_names = [random_item['name']]
        # Remove the randomly chosen ingredient from the available pool.
        available_ingredients_filtered_details = [
            item for item in available_ingredients_filtered_details
            if item['name'] != random_item['name']
        ]
        print(f"No required ingredients provided. Randomly selected: {required_ingredients_names[0]}")
    if not required_ingredients_names or len(required_ingredients_names) >= max_ingredients:
        return required_ingredients_names[:max_ingredients]
    if not available_ingredients_filtered_details:
        return required_ingredients_names
    # Embeddings for required ingredients: (name, embedding) tuples.
    embed_required = [(name, get_embedding(name)) for name in required_ingredients_names]
    # Embeddings for available ingredients, keeping their details.
    embed_available_with_details = [
        (item['name'], get_embedding(item['name']), item['dateAdded'], item['category'])
        for item in available_ingredients_filtered_details
    ]
    num_to_add = min(max_ingredients - len(required_ingredients_names), len(embed_available_with_details))
    final_ingredients_with_embeddings = embed_required.copy()  # (name, embedding)
    final_ingredients_names = required_ingredients_names.copy()  # names selected so far
    for _ in range(num_to_add):
        # Re-rank the remaining candidates against the current selection.
        avg = average_embedding(final_ingredients_with_embeddings)
        candidates = get_combined_scores(avg, embed_available_with_details, final_ingredients_with_embeddings, avg_weight)
        if not candidates:
            break
        best_name, best_embedding, best_score, _, _ = candidates[0]  # top-ranked candidate
        # Record the winner by (name, embedding) and by name.
        final_ingredients_with_embeddings.append((best_name, best_embedding))
        final_ingredients_names.append(best_name)
        # Remove the winner from the candidate pool.
        embed_available_with_details = [item for item in embed_available_with_details if item[0] != best_name]
    return final_ingredients_names
def skip_special_tokens(text, special_tokens):
    """Strip every token in *special_tokens* from *text* and return the result."""
    cleaned = text
    for tok in special_tokens:
        cleaned = cleaned.replace(tok, "")
    return cleaned
def target_postprocessing(texts, special_tokens):
    """Post-process generated model text.

    Removes special tokens, then maps the model's structural tokens (see
    module-level *tokens_map*) to plain separators. Accepts a single string
    or a list of strings; always returns a list.
    """
    batch = texts if isinstance(texts, list) else [texts]

    processed = []
    for raw in batch:
        cleaned = skip_special_tokens(raw, special_tokens)
        for token, replacement in tokens_map.items():
            cleaned = cleaned.replace(token, replacement)
        processed.append(cleaned)
    return processed
def validate_recipe_ingredients(recipe_ingredients, expected_ingredients, tolerance=0):
    """Check whether the recipe contains roughly the expected number of ingredients.

    Args:
        recipe_ingredients: ingredient strings from the generated recipe;
            empty/whitespace-only entries are ignored.
        expected_ingredients: the list the recipe was generated from.
        tolerance: maximum allowed count difference (default 0 = exact match).

    Returns:
        True when the counts differ by at most *tolerance*.
    """
    recipe_count = sum(1 for ing in recipe_ingredients if ing and ing.strip())
    expected_count = len(expected_ingredients)
    # BUG FIX: the original used `== tolerance`, which rejected recipes that
    # were CLOSER than the tolerance (e.g. an exact match failed when
    # tolerance=1). "Within tolerance" means `<=`; identical for the
    # default tolerance=0.
    return abs(recipe_count - expected_count) <= tolerance
def generate_recipe_with_t5(ingredients_list, max_retries=5):
    """Generate a recipe with the T5 recipe-generation model, with validation.

    Retries up to *max_retries* times (reshuffling the ingredient order on
    retries) until the generated recipe contains the expected number of
    ingredients; otherwise returns the last attempt, or a fallback recipe
    if every attempt raised.
    """
    original_ingredients = ingredients_list.copy()
    for attempt in range(max_retries):
        try:
            if attempt > 0:
                # Reshuffle on retries to nudge sampling toward a different output.
                current_ingredients = original_ingredients.copy()
                random.shuffle(current_ingredients)
            else:
                current_ingredients = ingredients_list
            ingredients_string = ", ".join(current_ingredients)
            prefix = "items: "  # prompt format the model was trained with
            generation_kwargs = {
                "max_length": 512,
                "min_length": 64,
                "do_sample": True,
                "top_k": 60,
                "top_p": 0.95
            }
            print(f"Attempt {attempt + 1}: {prefix + ingredients_string}")  # debug
            inputs = t5_tokenizer(
                prefix + ingredients_string,
                max_length=256,
                padding="max_length",
                truncation=True,
                return_tensors="jax"
            )
            output_ids = t5_model.generate(
                input_ids=inputs.input_ids,
                attention_mask=inputs.attention_mask,
                **generation_kwargs
            )
            generated = output_ids.sequences
            generated_text = target_postprocessing(
                t5_tokenizer.batch_decode(generated, skip_special_tokens=False),
                special_tokens
            )[0]
            # Parse the "title: ...\ningredients: ...\ndirections: ..." sections
            # the model emits ("--" separates list items after postprocessing).
            recipe = {}
            sections = generated_text.split("\n")
            for section in sections:
                section = section.strip()
                if section.startswith("title:"):
                    recipe["title"] = section.replace("title:", "").strip().capitalize()
                elif section.startswith("ingredients:"):
                    ingredients_text = section.replace("ingredients:", "").strip()
                    recipe["ingredients"] = [item.strip().capitalize() for item in ingredients_text.split("--") if item.strip()]
                elif section.startswith("directions:"):
                    directions_text = section.replace("directions:", "").strip()
                    recipe["directions"] = [step.strip().capitalize() for step in directions_text.split("--") if step.strip()]
            # Fallbacks for any section the model failed to produce.
            if "title" not in recipe:
                recipe["title"] = f"Rezept mit {', '.join(current_ingredients[:3])}"
            if "ingredients" not in recipe:
                recipe["ingredients"] = current_ingredients
            if "directions" not in recipe:
                recipe["directions"] = ["Keine Anweisungen generiert"]
            if validate_recipe_ingredients(recipe["ingredients"], original_ingredients):
                print(f"Success on attempt {attempt + 1}: Recipe has correct number of ingredients")  # debug
                return recipe
            else:
                print(f"Attempt {attempt + 1} failed: Expected {len(original_ingredients)} ingredients, got {len(recipe['ingredients'])}")  # debug
                if attempt == max_retries - 1:
                    print("Max retries reached, returning last generated recipe")  # debug
                    return recipe
        except Exception as e:
            print(f"Error in recipe generation attempt {attempt + 1}: {str(e)}")  # debug
            if attempt == max_retries - 1:
                # Fallback recipe after the final failed attempt.
                return {
                    "title": f"Rezept mit {original_ingredients[0] if original_ingredients else 'Zutaten'}",
                    "ingredients": original_ingredients,
                    "directions": ["Fehler beim Generieren der Rezeptanweisungen"]
                }
    # Safety net (e.g. max_retries <= 0 never enters the loop).
    return {
        "title": f"Rezept mit {original_ingredients[0] if original_ingredients else 'Zutaten'}",
        "ingredients": original_ingredients,
        "directions": ["Fehler beim Generieren der Rezeptanweisungen"]
    }
def process_recipe_request_logic(required_ingredients, available_ingredients_details, max_ingredients, max_retries):
    """Core logic for one recipe-generation request.

    available_ingredients_details: list of dicts with keys 'name',
    'dateAdded' and 'category'.

    Returns a dict with title/ingredients/directions/used_ingredients,
    or {"error": ...} on failure.
    """
    if not required_ingredients and not available_ingredients_details:
        return {"error": "Keine Zutaten angegeben"}
    try:
        # find_best_ingredients expects the detailed availability list.
        optimized_ingredients = find_best_ingredients(
            required_ingredients, available_ingredients_details, max_ingredients
        )
        recipe = generate_recipe_with_t5(optimized_ingredients, max_retries)
        return {
            'title': recipe['title'],
            'ingredients': recipe['ingredients'],
            'directions': recipe['directions'],
            'used_ingredients': optimized_ingredients,
        }
    except Exception as e:
        import traceback
        traceback.print_exc()  # aids debugging in the server log
        return {"error": f"Fehler bei der Rezeptgenerierung: {str(e)}"}
# --- FastAPI application setup ---
app = FastAPI(title="AI Recipe Generator API")


class IngredientDetail(BaseModel):
    """One available ingredient with its storage metadata, as sent by the client."""
    name: str
    dateAdded: str  # ISO-8601 timestamp, received as a plain string
    category: str
class RecipeRequest(BaseModel):
    """Request payload for recipe generation."""
    required_ingredients: list[str] = []
    # Available ingredients carry per-item details (name/dateAdded/category).
    available_ingredients: list[IngredientDetail] = []
    max_ingredients: int = 7
    max_retries: int = 5
    # Legacy field kept for backward compatibility with older clients.
    ingredients: list[str] = []
async def generate_recipe_api(request_data: RecipeRequest):
    """
    REST endpoint logic for the Flutter app: accepts JSON, returns JSON.

    NOTE(review): no @app.post(...) decorator is visible here, so this
    coroutine is never registered as a route — confirm the decorator was
    not lost during extraction and restore it if needed.
    """
    final_required_ingredients = request_data.required_ingredients
    if not final_required_ingredients and request_data.ingredients:
        # Backward compatibility: fall back to the legacy 'ingredients' field.
        final_required_ingredients = request_data.ingredients

    # BUG FIX: downstream code (find_best_ingredients) accesses items as
    # item['name'] / item['dateAdded'] / item['category'], but pydantic
    # models do not support subscripting — convert each IngredientDetail
    # to a plain dict before handing it to the core logic.
    available_as_dicts = [
        {'name': ing.name, 'dateAdded': ing.dateAdded, 'category': ing.category}
        for ing in request_data.available_ingredients
    ]

    result_dict = process_recipe_request_logic(
        final_required_ingredients,
        available_as_dicts,
        request_data.max_ingredients,
        request_data.max_retries
    )
    return JSONResponse(content=result_dict)
async def read_root():
    """Health-check payload for the root endpoint.

    NOTE(review): no @app.get("/") decorator is visible — confirm it was
    not lost during extraction.
    """
    return {"message": "AI Recipe Generator API is running (FastAPI only)!"}

print("INFO: Pure FastAPI application script finished execution and defined 'app' variable.")