Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,8 +1,6 @@
|
|
| 1 |
import json
|
| 2 |
-
|
| 3 |
import gradio as gr
|
| 4 |
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
|
| 5 |
-
|
| 6 |
from format.format_output import format_output
|
| 7 |
from validate.validate_ingredients import validate_ingredients
|
| 8 |
from device.get_device_id import get_device_id
|
|
@@ -11,34 +9,56 @@ tokenizer = AutoTokenizer.from_pretrained("Ashikan/dut-recipe-generator")
|
|
| 11 |
model = AutoModelForCausalLM.from_pretrained("Ashikan/dut-recipe-generator")
|
| 12 |
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device=get_device_id())
|
| 13 |
|
| 14 |
-
def perform_model_inference(ingredients_list):
|
| 15 |
-
|
| 16 |
-
|
|
|
|
| 17 |
|
| 18 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
|
| 20 |
output = pipe(input_text, max_length=1024, temperature=0.1, do_sample=True, truncation=True)[0]["generated_text"]
|
| 21 |
|
| 22 |
return format_output(output)
|
| 23 |
|
| 24 |
-
def chat_function(history, user_input):
|
| 25 |
-
#
|
| 26 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
history
|
| 32 |
return history, ""
|
| 33 |
|
| 34 |
-
#
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
|
| 40 |
# Define the Gradio interface
|
| 41 |
-
with gr.Blocks(css="""
|
| 42 |
#chatbot-container {
|
| 43 |
background-color: #f7f7f8;
|
| 44 |
border-radius: 10px;
|
|
@@ -75,13 +95,21 @@ with gr.Blocks(css="""
|
|
| 75 |
}
|
| 76 |
""") as recipegen:
|
| 77 |
gr.Markdown("# Recipe Generator")
|
| 78 |
-
gr.Markdown("An AI model attempting to produce healthier, diabetic-friendly recipes. Start by entering ingredients.")
|
| 79 |
|
| 80 |
chatbot = gr.Chatbot(elem_id="chatbot-container")
|
| 81 |
-
user_input = gr.Textbox(placeholder="Enter ingredients
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
submit_button = gr.Button("Generate")
|
| 83 |
|
| 84 |
# Link the chatbot and input to the chat function
|
| 85 |
-
submit_button.click(chat_function, inputs=[chatbot, user_input], outputs=[chatbot, user_input])
|
| 86 |
|
| 87 |
recipegen.launch(share=True)
|
|
|
|
| 1 |
import json
|
|
|
|
| 2 |
import gradio as gr
|
| 3 |
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
|
|
|
|
| 4 |
from format.format_output import format_output
|
| 5 |
from validate.validate_ingredients import validate_ingredients
|
| 6 |
from device.get_device_id import get_device_id
|
|
|
|
| 9 |
model = AutoModelForCausalLM.from_pretrained("Ashikan/dut-recipe-generator")
|
| 10 |
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device=get_device_id())
|
| 11 |
|
| 12 |
+
def perform_model_inference(ingredients_list=None, recipe_name=None):
|
| 13 |
+
if ingredients_list:
|
| 14 |
+
for ingredient_index in range(len(ingredients_list)):
|
| 15 |
+
ingredients_list[ingredient_index] = ingredients_list[ingredient_index].strip()
|
| 16 |
|
| 17 |
+
input_text = '{"prompt": ' + json.dumps(ingredients_list)
|
| 18 |
+
elif recipe_name:
|
| 19 |
+
input_text = '{"prompt": "Generate ingredients and method for the recipe: ' + recipe_name + '"}'
|
| 20 |
+
else:
|
| 21 |
+
return "Invalid input"
|
| 22 |
|
| 23 |
output = pipe(input_text, max_length=1024, temperature=0.1, do_sample=True, truncation=True)[0]["generated_text"]
|
| 24 |
|
| 25 |
return format_output(output)
|
| 26 |
|
| 27 |
+
def chat_function(history, user_input, mode):
|
| 28 |
+
# If mode is "ingredients", process as ingredient list
|
| 29 |
+
if mode == "ingredients":
|
| 30 |
+
ingredients_list = user_input.lower().split(',')
|
| 31 |
+
|
| 32 |
+
# Validate the ingredients
|
| 33 |
+
if len(ingredients_list) < 2:
|
| 34 |
+
error_text = "Please provide at least 2 ingredients, separated by commas."
|
| 35 |
+
history.append((user_input, error_text))
|
| 36 |
+
return history, ""
|
| 37 |
|
| 38 |
+
# Generate the recipe
|
| 39 |
+
history.append((user_input, "Generating recipe..."))
|
| 40 |
+
recipe = perform_model_inference(ingredients_list=ingredients_list)
|
| 41 |
+
history[-1] = (user_input, recipe) # Replace the "Generating recipe..." message with the result
|
| 42 |
return history, ""
|
| 43 |
|
| 44 |
+
# If mode is "recipe", process as recipe name
|
| 45 |
+
elif mode == "recipe":
|
| 46 |
+
recipe_name = user_input.strip()
|
| 47 |
+
|
| 48 |
+
# Validate the recipe name
|
| 49 |
+
if not recipe_name:
|
| 50 |
+
error_text = "Please provide a valid recipe name."
|
| 51 |
+
history.append((user_input, error_text))
|
| 52 |
+
return history, ""
|
| 53 |
+
|
| 54 |
+
# Generate ingredients and method
|
| 55 |
+
history.append((user_input, "Generating ingredients and method..."))
|
| 56 |
+
recipe_details = perform_model_inference(recipe_name=recipe_name)
|
| 57 |
+
history[-1] = (user_input, recipe_details) # Replace the "Generating ingredients..." message with the result
|
| 58 |
+
return history, ""
|
| 59 |
|
| 60 |
# Define the Gradio interface
|
| 61 |
+
with gr.Blocks(css="""
|
| 62 |
#chatbot-container {
|
| 63 |
background-color: #f7f7f8;
|
| 64 |
border-radius: 10px;
|
|
|
|
| 95 |
}
|
| 96 |
""") as recipegen:
|
| 97 |
gr.Markdown("# Recipe Generator")
|
| 98 |
+
gr.Markdown("An AI model attempting to produce healthier, diabetic-friendly recipes. Start by entering ingredients or a recipe name.")
|
| 99 |
|
| 100 |
chatbot = gr.Chatbot(elem_id="chatbot-container")
|
| 101 |
+
user_input = gr.Textbox(placeholder="Enter ingredients or recipe name", label="Your Input")
|
| 102 |
+
|
| 103 |
+
# Dropdown for selecting input mode (ingredients or recipe name)
|
| 104 |
+
mode_selector = gr.Dropdown(
|
| 105 |
+
choices=["ingredients", "recipe"],
|
| 106 |
+
label="Select Input Mode",
|
| 107 |
+
value="ingredients"
|
| 108 |
+
)
|
| 109 |
+
|
| 110 |
submit_button = gr.Button("Generate")
|
| 111 |
|
| 112 |
# Link the chatbot and input to the chat function
|
| 113 |
+
submit_button.click(chat_function, inputs=[chatbot, user_input, mode_selector], outputs=[chatbot, user_input])
|
| 114 |
|
| 115 |
recipegen.launch(share=True)
|