File size: 3,053 Bytes
20474bb
9629738
20474bb
a581271
20474bb
9629738
20474bb
a581271
20474bb
 
 
 
 
 
800a3a3
a581271
8f251c3
 
 
 
 
 
a581271
8f251c3
a581271
 
8f251c3
 
 
 
 
 
4b53dca
9629738
 
 
 
 
8f251c3
9629738
 
 
 
 
 
 
 
2952ffd
 
9629738
 
 
8f251c3
800a3a3
 
 
 
 
a581271
 
 
 
 
 
 
 
800a3a3
0017612
 
a581271
e476843
 
800a3a3
a581271
20474bb
 
d93ffc0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import json

import gradio as gr
import spaces
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

from format.format_output import format_output
from validate.validate_ingredients import validate_ingredients
from device.get_device_id import get_device_id

# Model identifier passed to both ``from_pretrained`` calls below.
_MODEL_ID = "Ashikan/dut-recipe-generator"

tokenizer = AutoTokenizer.from_pretrained(_MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(_MODEL_ID)

# Text-generation pipeline, built once when the module is loaded and reused
# by every inference call; the target device comes from get_device_id().
pipe = pipeline(
    task="text-generation",
    model=model,
    tokenizer=tokenizer,
    device=get_device_id(),
)


@spaces.GPU
def perform_model_inference(ingredients_list):
    """Generate recipe text for the given ingredients and format it.

    The ingredients are whitespace-stripped, serialised as a JSON array and
    appended to the prefix ``{"prompt": `` to form the model input.

    Args:
        ingredients_list: Iterable of ingredient strings. Surrounding
            whitespace on each entry is ignored. The argument itself is
            not modified.

    Returns:
        The result of ``format_output`` applied to the pipeline's
        ``generated_text`` for the first returned sequence.
    """
    # Build a stripped copy rather than rewriting the caller's list in place.
    cleaned_ingredients = [ingredient.strip() for ingredient in ingredients_list]

    # The prompt has no closing brace; presumably the model is expected to
    # continue the JSON object from here -- confirm against the training format.
    input_text = '{"prompt": ' + json.dumps(cleaned_ingredients)

    output = pipe(input_text, max_length=1024, temperature=0.1, do_sample=True, truncation=True)[0]["generated_text"]

    return format_output(output)


def generate_recipe(ingredients):
    """Handle a "Generate" click: validate the input and produce the output update.

    Args:
        ingredients: Raw text from the ingredients textbox. It is lower-cased
            and split on commas before validation.

    Returns:
        A dict keyed by the ``generated_recipe`` component whose value is a
        ``gr.Markdown`` holding either the generated recipe or an error
        message when ``validate_ingredients`` rejects the input.
    """
    parsed_ingredients = ingredients.lower().split(',')

    # Guard clause: report invalid input without invoking the model.
    if not validate_ingredients(parsed_ingredients):
        error_text = "## Invalid ingredients. Please include at least 2 ingredients in a comma separated list. e.g. brown rice, onions, garlic"
        return {
            generated_recipe: gr.Markdown(value=error_text, elem_id="recipe-container", visible=True)
        }

    recipe_markdown = perform_model_inference(parsed_ingredients)
    return {
        generated_recipe: gr.Markdown(
            value=recipe_markdown,
            label="Generated Recipe",
            elem_id="recipe-container",
            visible=True,
        )
    }


# Ingredient lists offered as ready-made examples for the input textbox.
_EXAMPLE_INGREDIENTS = [
    "sweet potato, mushrooms, cheese, garlic",
    "chicken breast, chili, onion, tomato, parmesan cheese",
    "strawberries, vanilla, honey, rolled oats, almonds, butter",
    "hake, spring onion, lemon",
]

with gr.Blocks(css="./css/styles.css") as recipegen:
    # NOTE: the logo image (./assets/dut.png, elem_id "dut-logo") is currently disabled.

    # Page header and usage instructions.
    gr.Markdown(value="#  Recipe Generator", elem_id="header")
    gr.Markdown(
        value="### An AI Model Attempting To Produce Healthier, Diabetic-Friendly Recipes",
        elem_id="header-sub-heading",
    )
    gr.Markdown(
        value="Start by entering a comma-separated list of ingredients below.",
        elem_id="header-instructions",
    )

    # Input column: ingredients textbox plus the trigger button.
    with gr.Column() as column:
        user_ingredients = gr.Textbox(
            label="Ingredients",
            autofocus=True,
            max_lines=1,
            elem_id="ingredients-input",
        )
        generate_button = gr.Button(value="Generate")

    # Output column: Markdown area that generate_recipe updates.
    with gr.Column():
        generated_recipe = gr.Markdown(visible=True)

    examples = gr.Examples(
        examples=_EXAMPLE_INGREDIENTS,
        inputs=[user_ingredients],
        elem_id="examples",
    )

    # Clicking the button sends the textbox value to generate_recipe and
    # routes its result to the Markdown output.
    generate_button.click(fn=generate_recipe, inputs=[user_ingredients], outputs=[generated_recipe])

recipegen.launch(share=True)