BhavaishKumar112 commited on
Commit
0017612
·
verified ·
1 Parent(s): 8f8b1b5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -22
app.py CHANGED
@@ -1,8 +1,6 @@
1
  import json
2
-
3
  import gradio as gr
4
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
5
-
6
  from format.format_output import format_output
7
  from validate.validate_ingredients import validate_ingredients
8
  from device.get_device_id import get_device_id
@@ -11,34 +9,56 @@ tokenizer = AutoTokenizer.from_pretrained("Ashikan/dut-recipe-generator")
11
  model = AutoModelForCausalLM.from_pretrained("Ashikan/dut-recipe-generator")
12
  pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device=get_device_id())
13
 
14
- def perform_model_inference(ingredients_list):
15
- for ingredient_index in range(len(ingredients_list)):
16
- ingredients_list[ingredient_index] = ingredients_list[ingredient_index].strip()
 
17
 
18
- input_text = '{"prompt": ' + json.dumps(ingredients_list)
 
 
 
 
19
 
20
  output = pipe(input_text, max_length=1024, temperature=0.1, do_sample=True, truncation=True)[0]["generated_text"]
21
 
22
  return format_output(output)
23
 
24
- def chat_function(history, user_input):
25
- # Extract user ingredients
26
- ingredients_list = user_input.lower().split(',')
 
 
 
 
 
 
 
27
 
28
- # Validate the ingredients
29
- if len(ingredients_list) < 2:
30
- error_text = "Please provide at least 2 ingredients, separated by commas."
31
- history.append((user_input, error_text))
32
  return history, ""
33
 
34
- # Generate the recipe
35
- history.append((user_input, "Generating recipe..."))
36
- recipe = perform_model_inference(ingredients_list)
37
- history[-1] = (user_input, recipe) # Replace the "Generating recipe..." message with the result
38
- return history, ""
 
 
 
 
 
 
 
 
 
 
39
 
40
  # Define the Gradio interface
41
- with gr.Blocks(css="""
42
  #chatbot-container {
43
  background-color: #f7f7f8;
44
  border-radius: 10px;
@@ -75,13 +95,21 @@ with gr.Blocks(css="""
75
  }
76
  """) as recipegen:
77
  gr.Markdown("# Recipe Generator")
78
- gr.Markdown("An AI model attempting to produce healthier, diabetic-friendly recipes. Start by entering ingredients.")
79
 
80
  chatbot = gr.Chatbot(elem_id="chatbot-container")
81
- user_input = gr.Textbox(placeholder="Enter ingredients (e.g., chicken, onions, garlic)", label="Your Input")
 
 
 
 
 
 
 
 
82
  submit_button = gr.Button("Generate")
83
 
84
  # Link the chatbot and input to the chat function
85
- submit_button.click(chat_function, inputs=[chatbot, user_input], outputs=[chatbot, user_input])
86
 
87
  recipegen.launch(share=True)
 
1
  import json
 
2
  import gradio as gr
3
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
 
4
  from format.format_output import format_output
5
  from validate.validate_ingredients import validate_ingredients
6
  from device.get_device_id import get_device_id
 
9
  model = AutoModelForCausalLM.from_pretrained("Ashikan/dut-recipe-generator")
10
  pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device=get_device_id())
11
 
12
+ def perform_model_inference(ingredients_list=None, recipe_name=None):
13
+ if ingredients_list:
14
+ for ingredient_index in range(len(ingredients_list)):
15
+ ingredients_list[ingredient_index] = ingredients_list[ingredient_index].strip()
16
 
17
+ input_text = '{"prompt": ' + json.dumps(ingredients_list)
18
+ elif recipe_name:
19
+ input_text = '{"prompt": "Generate ingredients and method for the recipe: ' + recipe_name + '"}'
20
+ else:
21
+ return "Invalid input"
22
 
23
  output = pipe(input_text, max_length=1024, temperature=0.1, do_sample=True, truncation=True)[0]["generated_text"]
24
 
25
  return format_output(output)
26
 
27
+ def chat_function(history, user_input, mode):
28
+ # If mode is "ingredients", process as ingredient list
29
+ if mode == "ingredients":
30
+ ingredients_list = user_input.lower().split(',')
31
+
32
+ # Validate the ingredients
33
+ if len(ingredients_list) < 2:
34
+ error_text = "Please provide at least 2 ingredients, separated by commas."
35
+ history.append((user_input, error_text))
36
+ return history, ""
37
 
38
+ # Generate the recipe
39
+ history.append((user_input, "Generating recipe..."))
40
+ recipe = perform_model_inference(ingredients_list=ingredients_list)
41
+ history[-1] = (user_input, recipe) # Replace the "Generating recipe..." message with the result
42
  return history, ""
43
 
44
+ # If mode is "recipe", process as recipe name
45
+ elif mode == "recipe":
46
+ recipe_name = user_input.strip()
47
+
48
+ # Validate the recipe name
49
+ if not recipe_name:
50
+ error_text = "Please provide a valid recipe name."
51
+ history.append((user_input, error_text))
52
+ return history, ""
53
+
54
+ # Generate ingredients and method
55
+ history.append((user_input, "Generating ingredients and method..."))
56
+ recipe_details = perform_model_inference(recipe_name=recipe_name)
57
+ history[-1] = (user_input, recipe_details) # Replace the "Generating ingredients..." message with the result
58
+ return history, ""
59
 
60
  # Define the Gradio interface
61
+ with gr.Blocks(css="""
62
  #chatbot-container {
63
  background-color: #f7f7f8;
64
  border-radius: 10px;
 
95
  }
96
  """) as recipegen:
97
  gr.Markdown("# Recipe Generator")
98
+ gr.Markdown("An AI model attempting to produce healthier, diabetic-friendly recipes. Start by entering ingredients or a recipe name.")
99
 
100
  chatbot = gr.Chatbot(elem_id="chatbot-container")
101
+ user_input = gr.Textbox(placeholder="Enter ingredients or recipe name", label="Your Input")
102
+
103
+ # Dropdown for selecting input mode (ingredients or recipe name)
104
+ mode_selector = gr.Dropdown(
105
+ choices=["ingredients", "recipe"],
106
+ label="Select Input Mode",
107
+ value="ingredients"
108
+ )
109
+
110
  submit_button = gr.Button("Generate")
111
 
112
  # Link the chatbot and input to the chat function
113
+ submit_button.click(chat_function, inputs=[chatbot, user_input, mode_selector], outputs=[chatbot, user_input])
114
 
115
  recipegen.launch(share=True)