BhavaishKumar112 committed on
Commit 8f8b1b5 · verified · 1 Parent(s): a64f1cb

Update app.py

Files changed (1)
  1. app.py +5 -4
app.py CHANGED
@@ -1,4 +1,5 @@
 import json
+
 import gradio as gr
 from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
 
@@ -6,18 +7,18 @@ from format.format_output import format_output
 from validate.validate_ingredients import validate_ingredients
 from device.get_device_id import get_device_id
 
-# Initialize the model and pipeline
 tokenizer = AutoTokenizer.from_pretrained("Ashikan/dut-recipe-generator")
 model = AutoModelForCausalLM.from_pretrained("Ashikan/dut-recipe-generator")
 pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device=get_device_id())
 
 def perform_model_inference(ingredients_list):
-    # Clean up the ingredients
-    ingredients_list = [ingredient.strip() for ingredient in ingredients_list]
+    for ingredient_index in range(len(ingredients_list)):
+        ingredients_list[ingredient_index] = ingredients_list[ingredient_index].strip()
+
     input_text = '{"prompt": ' + json.dumps(ingredients_list)
 
-    # Perform model inference
     output = pipe(input_text, max_length=1024, temperature=0.1, do_sample=True, truncation=True)[0]["generated_text"]
+
     return format_output(output)
 
 def chat_function(history, user_input):
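
For readers following the change: below is a minimal, self-contained sketch of what the reworked clean-up loop and the prompt construction in perform_model_inference produce. The sample ingredients and the print call are illustrative assumptions and are not part of this commit.

    import json

    # Hypothetical raw user input with stray whitespace (assumed for illustration).
    ingredients_list = [" chicken breast", "rice ", " soy sauce "]

    # In-place strip of each ingredient, as introduced by this commit.
    for ingredient_index in range(len(ingredients_list)):
        ingredients_list[ingredient_index] = ingredients_list[ingredient_index].strip()

    # Prompt prefix built the same way perform_model_inference builds it.
    input_text = '{"prompt": ' + json.dumps(ingredients_list)
    print(input_text)  # {"prompt": ["chicken breast", "rice", "soy sauce"]}

Functionally, the new in-place loop strips each ingredient just as the removed list comprehension did; the commit changes how the list is rewritten, not what ends up in the prompt.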