File size: 17,986 Bytes
1787e4b
d71c8ab
1787e4b
 
 
d027733
1787e4b
bcc13b5
1787e4b
 
 
bcc13b5
1787e4b
bcc13b5
1787e4b
 
 
 
bcc13b5
d71c8ab
 
 
 
 
1787e4b
d71c8ab
bcc13b5
d71c8ab
 
 
 
bcc13b5
d71c8ab
 
 
 
 
 
 
 
 
bcc13b5
d71c8ab
 
 
 
bcc13b5
d71c8ab
 
 
 
 
bcc13b5
d71c8ab
 
 
 
 
 
 
bcc13b5
d71c8ab
 
 
 
 
 
bcc13b5
d71c8ab
 
 
bcc13b5
d71c8ab
 
bcc13b5
d71c8ab
 
 
 
bcc13b5
d71c8ab
 
 
 
bcc13b5
d71c8ab
 
 
 
 
bcc13b5
d71c8ab
bcc13b5
d71c8ab
 
 
bcc13b5
d71c8ab
 
 
 
bcc13b5
d71c8ab
bcc13b5
d71c8ab
 
 
bcc13b5
d71c8ab
 
 
bcc13b5
d71c8ab
 
 
bcc13b5
d71c8ab
 
bcc13b5
d71c8ab
 
bcc13b5
d71c8ab
bcc13b5
d71c8ab
 
bcc13b5
d71c8ab
 
bcc13b5
d71c8ab
 
 
bcc13b5
d71c8ab
 
bcc13b5
d71c8ab
 
bcc13b5
d71c8ab
 
bcc13b5
d71c8ab
 
 
bcc13b5
d71c8ab
 
 
 
 
bcc13b5
d71c8ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bcc13b5
 
 
d71c8ab
 
 
 
 
bcc13b5
d71c8ab
 
 
 
bcc13b5
d71c8ab
 
 
 
 
 
bcc13b5
d71c8ab
 
 
bcc13b5
d71c8ab
 
 
 
 
 
 
bcc13b5
d71c8ab
bcc13b5
d71c8ab
 
 
 
 
 
 
 
bcc13b5
d71c8ab
 
 
 
 
 
bcc13b5
d71c8ab
 
 
 
 
 
bcc13b5
d71c8ab
 
 
 
 
 
 
 
 
 
 
 
 
bcc13b5
d71c8ab
bcc13b5
d71c8ab
bcc13b5
d71c8ab
 
 
bcc13b5
d71c8ab
bcc13b5
d71c8ab
bcc13b5
d71c8ab
 
bcc13b5
d71c8ab
bcc13b5
d71c8ab
 
 
bcc13b5
d71c8ab
 
bcc13b5
d71c8ab
bcc13b5
d71c8ab
 
bcc13b5
d71c8ab
bcc13b5
d71c8ab
bcc13b5
d71c8ab
 
d027733
 
88f38f2
d71c8ab
88f38f2
 
d71c8ab
88f38f2
 
bcc13b5
88f38f2
bcc13b5
1787e4b
 
bcc13b5
1787e4b
 
bcc13b5
 
1787e4b
bcc13b5
88f38f2
d71c8ab
 
bcc13b5
d71c8ab
 
 
88f38f2
bcc13b5
88f38f2
 
 
d027733
88f38f2
d027733
 
88f38f2
 
d027733
 
88f38f2
 
 
 
 
 
 
 
 
 
d027733
bcc13b5
d71c8ab
d027733
 
 
d71c8ab
bcc13b5
 
d71c8ab
 
 
bcc13b5
88f38f2
 
 
 
bcc13b5
d71c8ab
 
bcc13b5
d71c8ab
 
 
bcc13b5
1787e4b
d71c8ab
bcc13b5
1787e4b
 
 
bcc13b5
1787e4b
d027733
bcc13b5
 
 
 
 
 
 
 
d71c8ab
 
 
bcc13b5
 
d71c8ab
 
 
bcc13b5
 
d71c8ab
 
bcc13b5
 
 
 
 
d71c8ab
bcc13b5
 
 
 
 
 
 
88f38f2
bcc13b5
 
 
 
d027733
 
bcc13b5
d71c8ab
d027733
d71c8ab
 
 
bcc13b5
 
 
d71c8ab
 
 
beaa316
d027733
d71c8ab
bcc13b5
d71c8ab
 
 
 
 
 
 
1787e4b
d027733
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
import gradio as gr
from transformers import FlaxAutoModelForSeq2SeqLM, AutoTokenizer, AutoModel
import torch
import numpy as np
import random
import json # Beibehalten, da es in flutter_api_generate_recipe verwendet wird

# Lade RecipeBERT Modell (für semantische Zutat-Kombination)
bert_model_name = "alexdseo/RecipeBERT"
bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
bert_model = AutoModel.from_pretrained(bert_model_name)
bert_model.eval() # Setze das Modell in den Evaluationsmodus

# Lade T5 Rezeptgenerierungsmodell
MODEL_NAME_OR_PATH = "flax-community/t5-recipe-generation"
t5_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH, use_fast=True)
t5_model = FlaxAutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME_OR_PATH)

# Token Mapping für die T5 Modell-Ausgabe
special_tokens = t5_tokenizer.all_special_tokens
tokens_map = {
    "<sep>": "--",
    "<section>": "\n"
}

def get_embedding(text):
    """Berechnet das Embedding für einen Text mit Mean Pooling über alle Tokens"""
    inputs = bert_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        outputs = bert_model(**inputs)

    # Mean Pooling - Mittelwert aller Token-Embeddings
    attention_mask = inputs['attention_mask']
    token_embeddings = outputs.last_hidden_state
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
    sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)

    return (sum_embeddings / sum_mask).squeeze(0)

def average_embedding(embedding_list):
    """Berechnet den Durchschnitt einer Liste von Embeddings"""
    tensors = torch.stack([emb for _, emb in embedding_list])
    return tensors.mean(dim=0)

def get_cosine_similarity(vec1, vec2):
    """Berechnet die Cosinus-Ähnlichkeit zwischen zwei Vektoren"""
    if torch.is_tensor(vec1):
        vec1 = vec1.detach().numpy()
    if torch.is_tensor(vec2):
        vec2 = vec2.detach().numpy()

    # Stelle sicher, dass die Vektoren die richtige Form haben (flachen sie bei Bedarf ab)
    vec1 = vec1.flatten()
    vec2 = vec2.flatten()

    dot_product = np.dot(vec1, vec2)
    norm_a = np.linalg.norm(vec1)
    norm_b = np.linalg.norm(vec2)

    # Division durch Null vermeiden
    if norm_a == 0 or norm_b == 0:
        return 0

    return dot_product / (norm_a * norm_b)

def get_combined_scores(query_vector, embedding_list, all_good_embeddings, avg_weight=0.6):
    """Berechnet einen kombinierten Score unter Berücksichtigung der Ähnlichkeit zum Durchschnitt und zu einzelnen Zutaten"""
    results = []

    for name, emb in embedding_list:
        # Ähnlichkeit zum Durchschnittsvektor
        avg_similarity = get_cosine_similarity(query_vector, emb)

        # Durchschnittliche Ähnlichkeit zu einzelnen Zutaten
        individual_similarities = [get_cosine_similarity(good_emb, emb)
                                   for _, good_emb in all_good_embeddings]
        avg_individual_similarity = sum(individual_similarities) / len(individual_similarities)

        # Kombinierter Score (gewichteter Durchschnitt)
        combined_score = avg_weight * avg_similarity + (1 - avg_weight) * avg_individual_similarity

        results.append((name, emb, combined_score))

    # Sortiere nach kombiniertem Score (absteigend)
    results.sort(key=lambda x: x[2], reverse=True)
    return results

def find_best_ingredients(required_ingredients, available_ingredients, max_ingredients=6, avg_weight=0.6):
    """
    Findet die besten Zutaten basierend auf RecipeBERT Embeddings.
    """
    # Stelle sicher, dass keine Duplikate in den Listen sind
    required_ingredients = list(set(required_ingredients))
    available_ingredients = list(set([i for i in available_ingredients if i not in required_ingredients]))

    # Sonderfall: Wenn keine benötigten Zutaten vorhanden sind, wähle zufällig eine aus den verfügbaren Zutaten
    if not required_ingredients and available_ingredients:
        random_ingredient = random.choice(available_ingredients)
        required_ingredients = [random_ingredient]
        available_ingredients = [i for i in available_ingredients if i != random_ingredient]
        # print(f"Keine benötigten Zutaten angegeben. Zufällig ausgewählt: {random_ingredient}")

    # Wenn immer noch keine Zutaten vorhanden oder bereits maximale Kapazität erreicht ist
    if not required_ingredients or len(required_ingredients) >= max_ingredients:
        return required_ingredients[:max_ingredients]

    # Wenn keine zusätzlichen Zutaten verfügbar sind
    if not available_ingredients:
        return required_ingredients

    # Berechne Embeddings für alle Zutaten
    embed_required = [(e, get_embedding(e)) for e in required_ingredients]
    embed_available = [(e, get_embedding(e)) for e in available_ingredients]

    # Anzahl der hinzuzufügenden Zutaten
    num_to_add = min(max_ingredients - len(required_ingredients), len(available_ingredients))

    # Kopiere benötigte Zutaten in die endgültige Liste
    final_ingredients = embed_required.copy()

    # Füge die besten Zutaten hinzu
    for _ in range(num_to_add):
        # Berechne den Durchschnittsvektor der aktuellen Kombination
        avg = average_embedding(final_ingredients)

        # Berechne kombinierte Scores für alle Kandidaten
        candidates = get_combined_scores(avg, embed_available, final_ingredients, avg_weight)

        # Wenn keine Kandidaten mehr übrig sind, breche ab
        if not candidates:
            break

        # Wähle die beste Zutat
        best_name, best_embedding, _ = candidates[0]

        # Füge die beste Zutat zur endgültigen Liste hinzu
        final_ingredients.append((best_name, best_embedding))

        # Entferne die Zutat aus den verfügbaren Zutaten
        embed_available = [item for item in embed_available if item[0] != best_name]

    # Extrahiere nur die Zutatennamen
    return [name for name, _ in final_ingredients]

def skip_special_tokens(text, special_tokens):
    """Entfernt spezielle Tokens aus dem Text"""
    for token in special_tokens:
        text = text.replace(token, "")
    return text

def target_postprocessing(texts, special_tokens):
    """Post-processed generierten Text"""
    if not isinstance(texts, list):
        texts = [texts]

    new_texts = []
    for text in texts:
        text = skip_special_tokens(text, special_tokens)

        for k, v in tokens_map.items():
            text = text.replace(k, v)

        new_texts.append(text)

    return new_texts

def validate_recipe_ingredients(recipe_ingredients, expected_ingredients, tolerance=0):
    """
    Validiert, ob das Rezept ungefähr die erwarteten Zutaten enthält.
    """
    recipe_count = len([ing for ing in recipe_ingredients if ing and ing.strip()])
    expected_count = len(expected_ingredients)
    return abs(recipe_count - expected_count) == tolerance

def generate_recipe_with_t5(ingredients_list, max_retries=5):
    """Generiert ein Rezept mit dem T5 Rezeptgenerierungsmodell mit Validierung."""
    original_ingredients = ingredients_list.copy()

    for attempt in range(max_retries):
        try:
            # Für Wiederholungsversuche nach dem ersten Versuch, mische die Zutaten
            if attempt > 0:
                current_ingredients = original_ingredients.copy()
                random.shuffle(current_ingredients)
            else:
                current_ingredients = ingredients_list

            # Formatiere Zutaten als kommaseparierten String
            ingredients_string = ", ".join(current_ingredients)
            prefix = "items: "

            # Generationseinstellungen
            generation_kwargs = {
                "max_length": 512,
                "min_length": 64,
                "do_sample": True,
                "top_k": 60,
                "top_p": 0.95
            }
            # print(f"Versuch {attempt + 1}: {prefix + ingredients_string}")

            # Tokenisiere Eingabe
            inputs = t5_tokenizer(
                prefix + ingredients_string,
                max_length=256,
                padding="max_length",
                truncation=True,
                return_tensors="jax"
            )

            # Generiere Text
            output_ids = t5_model.generate(
                input_ids=inputs.input_ids,
                attention_mask=inputs.attention_mask,
                **generation_kwargs
            )

            # Dekodieren und Nachbearbeiten
            generated = output_ids.sequences
            generated_text = target_postprocessing(
                t5_tokenizer.batch_decode(generated, skip_special_tokens=False),
                special_tokens
            )[0]

            # Abschnitte parsen
            recipe = {}
            sections = generated_text.split("\n")
            for section in sections:
                section = section.strip()
                if section.startswith("title:"):
                    recipe["title"] = section.replace("title:", "").strip().capitalize()
                elif section.startswith("ingredients:"):
                    ingredients_text = section.replace("ingredients:", "").strip()
                    recipe["ingredients"] = [item.strip().capitalize() for item in ingredients_text.split("--") if item.strip()]
                elif section.startswith("directions:"):
                    directions_text = section.replace("directions:", "").strip()
                    recipe["directions"] = [step.strip().capitalize() for step in directions_text.split("--") if step.strip()]

            # Wenn der Titel fehlt, erstelle einen
            if "title" not in recipe:
                recipe["title"] = f"Rezept mit {', '.join(current_ingredients[:3])}"

            # Stelle sicher, dass alle Abschnitte existieren
            if "ingredients" not in recipe:
                recipe["ingredients"] = current_ingredients
            if "directions" not in recipe:
                recipe["directions"] = ["Keine Anweisungen generiert"]

            # Validiere das Rezept
            if validate_recipe_ingredients(recipe["ingredients"], original_ingredients):
                # print(f"Erfolg bei Versuch {attempt + 1}: Rezept hat die richtige Anzahl von Zutaten")
                return recipe
            else:
                # print(f"Versuch {attempt + 1} fehlgeschlagen: Erwartet {len(original_ingredients)} Zutaten, erhalten {len(recipe['ingredients'])}")
                if attempt == max_retries - 1:
                    # print("Maximale Wiederholungsversuche erreicht, letztes generiertes Rezept wird zurückgegeben")
                    return recipe

        except Exception as e:
            # print(f"Fehler bei der Rezeptgenerierung Versuch {attempt + 1}: {str(e)}")
            if attempt == max_retries - 1:
                return {
                    "title": f"Rezept mit {original_ingredients[0] if original_ingredients else 'Zutaten'}",
                    "ingredients": original_ingredients,
                    "directions": ["Fehler beim Generieren der Rezeptanweisungen"]
                }

    # Fallback (sollte nicht erreicht werden)
    return {
        "title": f"Rezept mit {original_ingredients[0] if original_ingredients else 'Zutaten'}",
        "ingredients": original_ingredients,
        "directions": ["Fehler beim Generieren der Rezeptanweisungen"]
    }

# Diese Funktion wird von der Gradio-UI und der FastAPI-Route aufgerufen.
# Sie ist für die Kernlogik zuständig.
def process_recipe_request_logic(required_ingredients, available_ingredients, max_ingredients, max_retries):
    """
    Kernlogik zur Verarbeitung einer Rezeptgenerierungsanfrage.
    Ausgelagert, um von verschiedenen Endpunkten aufgerufen zu werden.
    """
    if not required_ingredients and not available_ingredients:
        return {"error": "Keine Zutaten angegeben"}

    try:
        # Optimale Zutaten finden
        optimized_ingredients = find_best_ingredients(
            required_ingredients,
            available_ingredients,
            max_ingredients
        )

        # Rezept mit optimierten Zutaten generieren
        recipe = generate_recipe_with_t5(optimized_ingredients, max_retries)

        # Ergebnis formatieren
        result = {
            'title': recipe['title'],
            'ingredients': recipe['ingredients'],
            'directions': recipe['directions'],
            'used_ingredients': optimized_ingredients
        }
        return result

    except Exception as e:
        return {"error": f"Fehler bei der Rezeptgenerierung: {str(e)}"}

def flutter_api_generate_recipe(ingredients_data: str): # Typ-Hint für Klarheit
    """
    Diese Funktion wird vom 'hugging_face_chat_gradio'-Paket über die API aufgerufen.
    Sie erwartet einen JSON-STRING als Eingabe.
    """
    try:
        # Der 'hugging_face_chat_gradio'-Client sendet das Payload als String.
        data = json.loads(ingredients_data)

        required_ingredients = data.get('required_ingredients', [])
        available_ingredients = data.get('available_ingredients', [])
        max_ingredients = data.get('max_ingredients', 7)
        max_retries = data.get('max_retries', 5)

        # Rufe die Kernlogik auf
        result_dict = process_recipe_request_logic(
            required_ingredients, available_ingredients, max_ingredients, max_retries
        )
        return json.dumps(result_dict) # Gibt einen JSON-STRING zurück

    except Exception as e:
        # Logge den Fehler für Debugging im Space-Log
        print(f"Error in flutter_api_generate_recipe: {str(e)}")
        return json.dumps({"error": f"Internal API Error: {str(e)}"})

def gradio_ui_generate_recipe(required_ingredients_text, available_ingredients_text, max_ingredients_val, max_retries_val):
    """Gradio UI Funktion für die Web-Oberfläche"""
    try:
        required_ingredients = [ing.strip() for ing in required_ingredients_text.split(',') if ing.strip()]
        available_ingredients = [ing.strip() for ing in available_ingredients_text.split(',') if ing.strip()]

        # Rufe die Kernlogik auf
        result = process_recipe_request_logic(
            required_ingredients, available_ingredients, max_ingredients_val, max_retries_val
        )

        if 'error' in result:
            return result['error'], "", "", ""

        ingredients_list = '\n'.join([f"• {ing}" for ing in result['ingredients']])
        directions_list = '\n'.join([f"{i+1}. {dir}" for i, dir in enumerate(result['directions'])])
        used_ingredients = ', '.join(result['used_ingredients'])

        return (
            result['title'],
            ingredients_list,
            directions_list,
            used_ingredients
        )

    except Exception as e:
        # Fehlermeldung für die Gradio UI
        return f"Fehler: {str(e)}", "", "", ""

# Erstelle die Gradio Oberfläche
with gr.Blocks(title="AI Rezept Generator") as demo:
    gr.Markdown("# 🍳 AI Rezept Generator")
    gr.Markdown("Generiere Rezepte mit KI und intelligenter Zutat-Kombination!")

    with gr.Tab("Web-Oberfläche"):
        with gr.Row():
            with gr.Column():
                required_ing = gr.Textbox(
                    label="Benötigte Zutaten (kommasepariert)",
                    placeholder="Hähnchen, Reis, Zwiebel",
                    lines=2
                )
                available_ing = gr.Textbox(
                    label="Verfügbare Zutaten (kommasepariert, optional)",
                    placeholder="Knoblauch, Tomate, Pfeffer, Kräuter",
                    lines=2
                )
                max_ing = gr.Slider(3, 10, value=7, step=1, label="Maximale Zutaten")
                max_retries = gr.Slider(1, 10, value=5, step=1, label="Max. Wiederholungsversuche")

                generate_btn = gr.Button("Rezept generieren", variant="primary")

            with gr.Column():
                title_output = gr.Textbox(label="Rezepttitel", interactive=False)
                ingredients_output = gr.Textbox(label="Zutaten", lines=8, interactive=False)
                directions_output = gr.Textbox(label="Anweisungen", lines=10, interactive=False)
                used_ingredients_output = gr.Textbox(label="Verwendete Zutaten", interactive=False)

                generate_btn.click(
                    fn=gradio_ui_generate_recipe,
                    inputs=[required_ing, available_ing, max_ing, max_retries],
                    outputs=[title_output, ingredients_output, directions_output, used_ingredients_output]
                )

    with gr.Tab("API-Test"):
        gr.Markdown("### Teste die Flutter API (via 'hugging_face_chat_gradio' Client)")
        gr.Markdown("Dieser Tab zeigt, wie die Eingabe für die 'generate_recipe_for_flutter'-API aussehen sollte.")

        api_input = gr.Textbox(
            label="JSON-Eingabe (für API-Aufruf)",
            placeholder='{"required_ingredients": ["chicken", "rice"], "available_ingredients": ["onion", "garlic"], "max_ingredients": 6}',
            lines=4
        )
        api_output = gr.Textbox(label="JSON-Ausgabe", lines=15, interactive=False)
        api_test_btn = gr.Button("API testen", variant="secondary")

        api_test_btn.click(
            fn=flutter_api_generate_recipe,
            inputs=[api_input],
            outputs=[api_output],
            api_name="generate_recipe_for_flutter" # Dies ist der api_name, den das Flutter-Paket verwendet
        )

        gr.Examples(
            examples=[
                ['{"required_ingredients": ["chicken", "rice"], "available_ingredients": ["onion", "garlic", "tomato"], "max_ingredients": 6}'],
                ['{"ingredients": ["pasta"], "available_ingredients": ["cheese", "mushrooms", "cream"], "max_ingredients": 5}']
            ],
            inputs=[api_input]
        )

# Gradio-App starten
if __name__ == "__main__":
    demo.launch()