|
|
import gradio as gr |
|
|
from transformers import FlaxAutoModelForSeq2SeqLM, AutoTokenizer, AutoModel |
|
|
import torch |
|
|
import numpy as np |
|
|
import random |
|
|
import json |
|
|
|
|
|
|
|
|
# RecipeBERT checkpoint used to embed ingredient names (see get_embedding).
bert_model_name = "alexdseo/RecipeBERT"
bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
# nn.Module.eval() returns the module itself, so the model can be loaded and
# switched to inference mode in a single expression.
bert_model = AutoModel.from_pretrained(bert_model_name).eval()
|
|
|
|
|
|
|
|
# Flax T5 checkpoint that turns an ingredient list into a recipe text.
MODEL_NAME_OR_PATH = "flax-community/t5-recipe-generation"
# The model and its tokenizer are loaded independently of each other.
t5_model = FlaxAutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME_OR_PATH)
t5_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH, use_fast=True)
|
|
|
|
|
|
|
|
# Special tokens of the T5 tokenizer; they are stripped from generated text.
special_tokens = t5_tokenizer.all_special_tokens
# Structural markers emitted by the model and the plain-text separators they
# are rewritten to during post-processing.
tokens_map = {"<sep>": "--", "<section>": "\n"}
|
|
|
|
|
def get_embedding(text):
    """Embed *text* with RecipeBERT using attention-masked mean pooling.

    The text is tokenized with the module-level ``bert_tokenizer`` and run
    through ``bert_model`` without gradient tracking. The last hidden states
    are averaged over all positions whose attention mask is set.

    Returns:
        The pooled embedding tensor; a leading dimension of size 1 is
        squeezed away.
    """
    encoded = bert_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        model_out = bert_model(**encoded)

    hidden = model_out.last_hidden_state
    # Broadcast the attention mask over the hidden dimension so that padded
    # positions contribute nothing to the sum.
    mask = encoded['attention_mask'].unsqueeze(-1).expand(hidden.size()).float()
    masked_total = (hidden * mask).sum(1)
    # Clamp the token count to avoid a division by zero for an all-zero mask.
    token_count = mask.sum(1).clamp(min=1e-9)

    return (masked_total / token_count).squeeze(0)
|
|
|
|
|
def average_embedding(embedding_list):
    """Return the element-wise mean of the embeddings in *embedding_list*.

    Args:
        embedding_list: Sequence of ``(name, embedding_tensor)`` pairs; only
            the tensors are used.
    """
    stacked = torch.stack([pair[1] for pair in embedding_list])
    return stacked.mean(dim=0)
|
|
|
|
|
def get_cosine_similarity(vec1, vec2):
    """Return the cosine similarity of two vectors.

    Torch tensors are converted to NumPy arrays first; both inputs are then
    flattened. If either vector has zero norm the result is ``0``.
    """
    flattened = []
    for vec in (vec1, vec2):
        if torch.is_tensor(vec):
            vec = vec.detach().numpy()
        flattened.append(vec.flatten())
    left, right = flattened

    inner = np.dot(left, right)
    left_norm = np.linalg.norm(left)
    right_norm = np.linalg.norm(right)

    # A zero vector has no direction, so the similarity is defined as 0.
    if left_norm == 0 or right_norm == 0:
        return 0

    return inner / (left_norm * right_norm)
|
|
|
|
|
def get_combined_scores(query_vector, embedding_list, all_good_embeddings, avg_weight=0.6):
    """Rank candidates by a weighted mix of two cosine similarities.

    For every ``(name, embedding)`` in *embedding_list* the score is
    ``avg_weight * sim(query_vector, embedding)`` plus ``(1 - avg_weight)``
    times the mean similarity between the embedding and each embedding in
    *all_good_embeddings*.

    Returns:
        A list of ``(name, embedding, score)`` tuples, highest score first.
    """

    def score(candidate):
        to_query = get_cosine_similarity(query_vector, candidate)
        to_each = [get_cosine_similarity(chosen, candidate)
                   for _, chosen in all_good_embeddings]
        mean_to_each = sum(to_each) / len(to_each)
        return avg_weight * to_query + (1 - avg_weight) * mean_to_each

    scored = [(name, emb, score(emb)) for name, emb in embedding_list]
    return sorted(scored, key=lambda entry: entry[2], reverse=True)
|
|
|
|
|
def find_best_ingredients(required_ingredients, available_ingredients, max_ingredients=6, avg_weight=0.6):
    """Pick the ingredients that fit best together using RecipeBERT embeddings.

    Starting from the required ingredients, available ingredients are added
    greedily: in each round the candidate with the highest combined score
    (see ``get_combined_scores``) against the current selection is chosen.

    Args:
        required_ingredients: Ingredients that must be used. Duplicates are
            dropped while keeping first-occurrence order.
        available_ingredients: Optional extra ingredients to choose from.
            Entries already present in *required_ingredients* are ignored.
        max_ingredients: Upper bound on the number of returned ingredients.
        avg_weight: Weight of the similarity to the averaged selection in
            the combined score.

    Returns:
        A list of ingredient names: the required ones first (truncated to
        *max_ingredients*), followed by the selected available ones.
    """
    # Callers may hand over a float (e.g. from a UI slider or JSON number);
    # slicing and range() below need an int.
    max_ingredients = int(max_ingredients)

    # dict.fromkeys de-duplicates while preserving input order. A set would
    # make both the order and the outcome of the truncation below depend on
    # hash ordering.
    required_ingredients = list(dict.fromkeys(required_ingredients))
    required_lookup = set(required_ingredients)
    available_ingredients = [
        item for item in dict.fromkeys(available_ingredients)
        if item not in required_lookup
    ]

    # Without any required ingredient, seed the selection with a random
    # available one so there is something to compare against.
    if not required_ingredients and available_ingredients:
        random_ingredient = random.choice(available_ingredients)
        required_ingredients = [random_ingredient]
        available_ingredients = [i for i in available_ingredients if i != random_ingredient]

    # Nothing to work with, or the required ingredients already fill the quota.
    if not required_ingredients or len(required_ingredients) >= max_ingredients:
        return required_ingredients[:max_ingredients]

    if not available_ingredients:
        return required_ingredients

    embed_required = [(name, get_embedding(name)) for name in required_ingredients]
    embed_available = [(name, get_embedding(name)) for name in available_ingredients]

    num_to_add = min(max_ingredients - len(required_ingredients), len(available_ingredients))
    final_ingredients = embed_required.copy()

    for _ in range(num_to_add):
        # Re-average after every addition so each pick reflects the current
        # selection.
        avg = average_embedding(final_ingredients)
        candidates = get_combined_scores(avg, embed_available, final_ingredients, avg_weight)
        if not candidates:
            break

        best_name, best_embedding, _ = candidates[0]
        final_ingredients.append((best_name, best_embedding))
        embed_available = [item for item in embed_available if item[0] != best_name]

    return [name for name, _ in final_ingredients]
|
|
|
|
|
def skip_special_tokens(text, special_tokens):
    """Return *text* with every occurrence of each special token removed."""
    cleaned = text
    for marker in special_tokens:
        cleaned = cleaned.replace(marker, "")
    return cleaned
|
|
|
|
|
def target_postprocessing(texts, special_tokens):
    """Clean up generated text(s) and return them as a list.

    A non-list input is wrapped into a one-element list. For each text the
    special tokens are stripped first, then every key of the module-level
    ``tokens_map`` is replaced by its mapped value.
    """
    batch = texts if isinstance(texts, list) else [texts]

    cleaned_batch = []
    for raw in batch:
        cleaned = skip_special_tokens(raw, special_tokens)
        for placeholder, replacement in tokens_map.items():
            cleaned = cleaned.replace(placeholder, replacement)
        cleaned_batch.append(cleaned)

    return cleaned_batch
|
|
|
|
|
def validate_recipe_ingredients(recipe_ingredients, expected_ingredients, tolerance=0):
    """Check that the recipe has roughly the expected number of ingredients.

    Only counts are compared, not the ingredient names themselves.

    Args:
        recipe_ingredients: Ingredient strings parsed from the generated
            recipe; ``None``, empty and whitespace-only entries are ignored.
        expected_ingredients: The ingredients that were requested.
        tolerance: Maximum allowed absolute difference between the two
            counts. ``0`` demands an exact match.

    Returns:
        ``True`` if the counts differ by at most *tolerance*.
    """
    recipe_count = sum(1 for ing in recipe_ingredients if ing and ing.strip())
    expected_count = len(expected_ingredients)
    # "<=" rather than "==": a tolerance is an upper bound, so an exact match
    # must also pass when tolerance > 0.
    return abs(recipe_count - expected_count) <= tolerance
|
|
|
|
|
def generate_recipe_with_t5(ingredients_list, max_retries=5): |
|
|
"""Generiert ein Rezept mit dem T5 Rezeptgenerierungsmodell mit Validierung.""" |
|
|
original_ingredients = ingredients_list.copy() |
|
|
|
|
|
for attempt in range(max_retries): |
|
|
try: |
|
|
|
|
|
if attempt > 0: |
|
|
current_ingredients = original_ingredients.copy() |
|
|
random.shuffle(current_ingredients) |
|
|
else: |
|
|
current_ingredients = ingredients_list |
|
|
|
|
|
|
|
|
ingredients_string = ", ".join(current_ingredients) |
|
|
prefix = "items: " |
|
|
|
|
|
|
|
|
generation_kwargs = { |
|
|
"max_length": 512, |
|
|
"min_length": 64, |
|
|
"do_sample": True, |
|
|
"top_k": 60, |
|
|
"top_p": 0.95 |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
inputs = t5_tokenizer( |
|
|
prefix + ingredients_string, |
|
|
max_length=256, |
|
|
padding="max_length", |
|
|
truncation=True, |
|
|
return_tensors="jax" |
|
|
) |
|
|
|
|
|
|
|
|
output_ids = t5_model.generate( |
|
|
input_ids=inputs.input_ids, |
|
|
attention_mask=inputs.attention_mask, |
|
|
**generation_kwargs |
|
|
) |
|
|
|
|
|
|
|
|
generated = output_ids.sequences |
|
|
generated_text = target_postprocessing( |
|
|
t5_tokenizer.batch_decode(generated, skip_special_tokens=False), |
|
|
special_tokens |
|
|
)[0] |
|
|
|
|
|
|
|
|
recipe = {} |
|
|
sections = generated_text.split("\n") |
|
|
for section in sections: |
|
|
section = section.strip() |
|
|
if section.startswith("title:"): |
|
|
recipe["title"] = section.replace("title:", "").strip().capitalize() |
|
|
elif section.startswith("ingredients:"): |
|
|
ingredients_text = section.replace("ingredients:", "").strip() |
|
|
recipe["ingredients"] = [item.strip().capitalize() for item in ingredients_text.split("--") if item.strip()] |
|
|
elif section.startswith("directions:"): |
|
|
directions_text = section.replace("directions:", "").strip() |
|
|
recipe["directions"] = [step.strip().capitalize() for step in directions_text.split("--") if step.strip()] |
|
|
|
|
|
|
|
|
if "title" not in recipe: |
|
|
recipe["title"] = f"Rezept mit {', '.join(current_ingredients[:3])}" |
|
|
|
|
|
|
|
|
if "ingredients" not in recipe: |
|
|
recipe["ingredients"] = current_ingredients |
|
|
if "directions" not in recipe: |
|
|
recipe["directions"] = ["Keine Anweisungen generiert"] |
|
|
|
|
|
|
|
|
if validate_recipe_ingredients(recipe["ingredients"], original_ingredients): |
|
|
|
|
|
return recipe |
|
|
else: |
|
|
|
|
|
if attempt == max_retries - 1: |
|
|
|
|
|
return recipe |
|
|
|
|
|
except Exception as e: |
|
|
|
|
|
if attempt == max_retries - 1: |
|
|
return { |
|
|
"title": f"Rezept mit {original_ingredients[0] if original_ingredients else 'Zutaten'}", |
|
|
"ingredients": original_ingredients, |
|
|
"directions": ["Fehler beim Generieren der Rezeptanweisungen"] |
|
|
} |
|
|
|
|
|
|
|
|
return { |
|
|
"title": f"Rezept mit {original_ingredients[0] if original_ingredients else 'Zutaten'}", |
|
|
"ingredients": original_ingredients, |
|
|
"directions": ["Fehler beim Generieren der Rezeptanweisungen"] |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
def process_recipe_request_logic(required_ingredients, available_ingredients, max_ingredients, max_retries):
    """Core logic for handling one recipe generation request.

    Kept separate so that several entry points can share it. It first
    selects the ingredients via ``find_best_ingredients`` and then generates
    the recipe via ``generate_recipe_with_t5``.

    Returns:
        A dict with ``title``, ``ingredients``, ``directions`` and
        ``used_ingredients`` on success, or a dict with a single ``error``
        key if no ingredients were given or an exception occurred.
    """
    nothing_supplied = not (required_ingredients or available_ingredients)
    if nothing_supplied:
        return {"error": "Keine Zutaten angegeben"}

    try:
        chosen = find_best_ingredients(
            required_ingredients,
            available_ingredients,
            max_ingredients
        )
        recipe = generate_recipe_with_t5(chosen, max_retries)

        # Copy the recipe fields in a fixed order, then add the selection.
        response = {field: recipe[field] for field in ('title', 'ingredients', 'directions')}
        response['used_ingredients'] = chosen
        return response
    except Exception as exc:
        return {"error": f"Fehler bei der Rezeptgenerierung: {str(exc)}"}
|
|
|
|
|
def flutter_api_generate_recipe(ingredients_data: str):
    """API entry point that takes and returns a JSON string.

    The input is parsed as a JSON object with the optional keys
    ``required_ingredients`` and ``available_ingredients`` (default ``[]``),
    ``max_ingredients`` (default 7) and ``max_retries`` (default 5).

    Returns:
        The result of ``process_recipe_request_logic`` serialized as JSON,
        or a JSON object with an ``error`` key if anything raised.
    """
    try:
        payload = json.loads(ingredients_data)

        # Collect the arguments first so lookups happen before the call.
        call_args = (
            payload.get('required_ingredients', []),
            payload.get('available_ingredients', []),
            payload.get('max_ingredients', 7),
            payload.get('max_retries', 5),
        )
        response = process_recipe_request_logic(*call_args)
        return json.dumps(response)
    except Exception as exc:
        print(f"Error in flutter_api_generate_recipe: {str(exc)}")
        return json.dumps({"error": f"Internal API Error: {str(exc)}"})
|
|
|
|
|
def gradio_ui_generate_recipe(required_ingredients_text, available_ingredients_text, max_ingredients_val, max_retries_val):
    """Gradio callback for the web interface.

    Both text inputs are split on commas into trimmed, non-empty ingredient
    names and passed to ``process_recipe_request_logic``.

    Returns:
        A 4-tuple of strings: title, bulleted ingredients, numbered
        directions and the comma-joined used ingredients. On failure the
        first element carries the error message and the rest are empty.
    """

    def split_csv(raw):
        return [piece.strip() for piece in raw.split(',') if piece.strip()]

    try:
        required = split_csv(required_ingredients_text)
        available = split_csv(available_ingredients_text)

        result = process_recipe_request_logic(
            required, available, max_ingredients_val, max_retries_val
        )
        if 'error' in result:
            return result['error'], "", "", ""

        bullet_text = '\n'.join(f"• {ing}" for ing in result['ingredients'])
        steps_text = '\n'.join(
            f"{position}. {step}"
            for position, step in enumerate(result['directions'], start=1)
        )
        used_text = ', '.join(result['used_ingredients'])

        return result['title'], bullet_text, steps_text, used_text
    except Exception as exc:
        return f"Fehler: {str(exc)}", "", "", ""
|
|
|
|
|
|
|
|
# Gradio application with two tabs: an interactive web form and a tester for
# the JSON endpoint exposed as "generate_recipe_for_flutter".
# NOTE(review): nesting was reconstructed from syntax; the generate button is
# assumed to sit in the left-hand column below the sliders - confirm.
with gr.Blocks(title="AI Rezept Generator") as demo:
    gr.Markdown("# 🍳 AI Rezept Generator")
    gr.Markdown("Generiere Rezepte mit KI und intelligenter Zutat-Kombination!")

    with gr.Tab("Web-Oberfläche"):
        with gr.Row():
            # Left column: user inputs.
            with gr.Column():
                required_ing = gr.Textbox(
                    label="Benötigte Zutaten (kommasepariert)",
                    placeholder="Hähnchen, Reis, Zwiebel",
                    lines=2
                )
                available_ing = gr.Textbox(
                    label="Verfügbare Zutaten (kommasepariert, optional)",
                    placeholder="Knoblauch, Tomate, Pfeffer, Kräuter",
                    lines=2
                )
                max_ing = gr.Slider(3, 10, value=7, step=1, label="Maximale Zutaten")
                max_retries = gr.Slider(1, 10, value=5, step=1, label="Max. Wiederholungsversuche")

                generate_btn = gr.Button("Rezept generieren", variant="primary")

            # Right column: read-only result fields.
            with gr.Column():
                title_output = gr.Textbox(label="Rezepttitel", interactive=False)
                ingredients_output = gr.Textbox(label="Zutaten", lines=8, interactive=False)
                directions_output = gr.Textbox(label="Anweisungen", lines=10, interactive=False)
                used_ingredients_output = gr.Textbox(label="Verwendete Zutaten", interactive=False)

        generate_btn.click(
            fn=gradio_ui_generate_recipe,
            inputs=[required_ing, available_ing, max_ing, max_retries],
            outputs=[title_output, ingredients_output, directions_output, used_ingredients_output]
        )

    with gr.Tab("API-Test"):
        gr.Markdown("### Teste die Flutter API (via 'hugging_face_chat_gradio' Client)")
        gr.Markdown("Dieser Tab zeigt, wie die Eingabe für die 'generate_recipe_for_flutter'-API aussehen sollte.")

        api_input = gr.Textbox(
            label="JSON-Eingabe (für API-Aufruf)",
            placeholder='{"required_ingredients": ["chicken", "rice"], "available_ingredients": ["onion", "garlic"], "max_ingredients": 6}',
            lines=4
        )
        api_output = gr.Textbox(label="JSON-Ausgabe", lines=15, interactive=False)
        api_test_btn = gr.Button("API testen", variant="secondary")

        # api_name publishes this handler under a stable endpoint name.
        api_test_btn.click(
            fn=flutter_api_generate_recipe,
            inputs=[api_input],
            outputs=[api_output],
            api_name="generate_recipe_for_flutter"
        )

        # Both examples use "required_ingredients", the key that
        # flutter_api_generate_recipe reads; an "ingredients" key would be
        # ignored silently.
        gr.Examples(
            examples=[
                ['{"required_ingredients": ["chicken", "rice"], "available_ingredients": ["onion", "garlic", "tomato"], "max_ingredients": 6}'],
                ['{"required_ingredients": ["pasta"], "available_ingredients": ["cheese", "mushrooms", "cream"], "max_ingredients": 5}']
            ],
            inputs=[api_input]
        )


# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()
|
|
|