Spaces:
Sleeping
Sleeping
File size: 3,053 Bytes
20474bb 9629738 20474bb a581271 20474bb 9629738 20474bb a581271 20474bb 800a3a3 a581271 8f251c3 a581271 8f251c3 a581271 8f251c3 4b53dca 9629738 8f251c3 9629738 2952ffd 9629738 8f251c3 800a3a3 a581271 800a3a3 0017612 a581271 e476843 800a3a3 a581271 20474bb d93ffc0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 |
import json
import gradio as gr
import spaces
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from format.format_output import format_output
from validate.validate_ingredients import validate_ingredients
from device.get_device_id import get_device_id
# Hugging Face Hub id of the fine-tuned recipe model; the tokenizer and weights
# come from the same repository.
_MODEL_ID = "Ashikan/dut-recipe-generator"

# Loaded once at import time and shared by every request.
tokenizer = AutoTokenizer.from_pretrained(_MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(_MODEL_ID)
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    # Device selection is delegated to the project helper get_device_id().
    device=get_device_id(),
)
@spaces.GPU
def perform_model_inference(ingredients_list):
    """Run the recipe model on a list of ingredients and format its output.

    Args:
        ingredients_list: ingredient strings. Surrounding whitespace is
            stripped before they are placed in the prompt. The caller's list
            is not modified.

    Returns:
        The result of ``format_output`` applied to the pipeline's generated
        text.
    """
    # Build a fresh list rather than stripping in place, so the argument is
    # left untouched for the caller.
    cleaned_ingredients = [ingredient.strip() for ingredient in ingredients_list]
    # The prompt is deliberately an unterminated JSON object ({"prompt": [...]).
    # NOTE(review): presumably it mirrors the fine-tuning data format, with the
    # model continuing the object; confirm against the training script.
    input_text = '{"prompt": ' + json.dumps(cleaned_ingredients)
    # max_length bounds prompt + completion tokens; truncation=True trims an
    # over-long prompt. A low temperature with sampling keeps output near-greedy.
    output = pipe(
        input_text,
        max_length=1024,
        temperature=0.1,
        do_sample=True,
        truncation=True,
    )[0]["generated_text"]
    return format_output(output)
def generate_recipe(ingredients):
    """Validate the user's ingredients, then show a recipe or an error.

    Args:
        ingredients: raw textbox contents, expected to be a comma-separated
            list of ingredients. ``None`` is treated as empty input.

    Returns:
        A dict mapping the ``generated_recipe`` output component to a
        ``gr.Markdown`` containing either the generated recipe or an error
        message.
    """
    # Normalise once, so the validator sees exactly what the model will get:
    # lower-cased, trimmed, and without empty entries such as the one a
    # trailing comma ("rice, onions,") would otherwise produce.
    raw_parts = (ingredients or "").lower().split(",")
    ingredients_list = [part.strip() for part in raw_parts if part.strip()]

    if not validate_ingredients(ingredients_list):
        error_text = "## Invalid ingredients. Please include at least 2 ingredients in a comma separated list. e.g. brown rice, onions, garlic"
        return {
            generated_recipe: gr.Markdown(value=error_text, elem_id="recipe-container", visible=True)
        }

    generated_text = perform_model_inference(ingredients_list)
    return {
        generated_recipe: gr.Markdown(value=generated_text, label="Generated Recipe",
                                      elem_id="recipe-container", visible=True)
    }
with gr.Blocks(css="./css/styles.css") as recipegen:
    # Logo is currently disabled; kept commented out for easy re-enabling.
    # gr.Image("./assets/dut.png", interactive=False, show_share_button=False, show_download_button=False,
    #          show_fullscreen_button=False, show_label=False, elem_id="dut-logo", height=256)
    gr.Markdown("# Recipe Generator", elem_id="header")
    gr.Markdown("### An AI Model Attempting To Produce Healthier, Diabetic-Friendly Recipes",
                elem_id="header-sub-heading")
    gr.Markdown("Start by entering a comma-separated list of ingredients below.", elem_id="header-instructions")

    # Input area: a single-line ingredient box and the generate button.
    with gr.Column():
        user_ingredients = gr.Textbox(label="Ingredients", autofocus=True, max_lines=1, elem_id="ingredients-input")
        generate_button = gr.Button(value="Generate")

    # Output area; generate_recipe returns an updated version of this Markdown.
    with gr.Column():
        generated_recipe = gr.Markdown(visible=True)

    # No fn is given, so selecting an example only fills the textbox.
    gr.Examples(
        elem_id="examples",
        examples=[
            "sweet potato, mushrooms, cheese, garlic",
            "chicken breast, chili, onion, tomato, parmesan cheese",
            "strawberries, vanilla, honey, rolled oats, almonds, butter",
            "hake, spring onion, lemon",
        ],
        inputs=[user_ingredients],
    )

    # The textbox is single-line, so pressing Enter should generate a recipe
    # just like clicking the button does.
    for trigger in (generate_button.click, user_ingredients.submit):
        trigger(
            fn=generate_recipe,
            inputs=[user_ingredients],
            outputs=[generated_recipe],
        )

if __name__ == "__main__":
    # Start the server only when run as a script, not as an import side effect.
    # NOTE(review): share=True is expected to have no effect on Hugging Face
    # Spaces; it only matters for local runs.
    recipegen.launch(share=True)
|