Bsbell21 commited on
Commit
0f2df3f
·
verified ·
1 Parent(s): ac79e2c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -3
app.py CHANGED
@@ -9,12 +9,28 @@ model = AutoModelForCausalLM.from_pretrained(
9
  return_dict=True,
10
  device_map="auto"
11
  )
12
- tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
 
 
13
 
14
  # Load the Lora model
15
  model = PeftModel.from_pretrained(model, peft_model_id)
16
 
17
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  def make_inference(product_name, product_description):
19
  batch = tokenizer(
20
  f"### Product and Description:\n{product_name}: {product_description}\n\n### Ad:",
@@ -28,7 +44,7 @@ def make_inference(product_name, product_description):
28
 
29
  return tokenizer.decode(output_tokens[0], skip_special_tokens=True)
30
 
31
-
32
  if __name__ == "__main__":
33
  # make a gradio interface
34
  import gradio as gr
 
9
  return_dict=True,
10
  device_map="auto"
11
  )
12
+ #tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
13
+
14
+ mixtral_tokenizer = AutoTokenizer.from_pretrained(peft_model_id)
15
 
16
  # Load the Lora model
17
  model = PeftModel.from_pretrained(model, peft_model_id)
18
 
19
+ def input_from_text(product, description):
20
+ return f"<s>[INST]Below is a product and description, please write a marketing email for this product.\n\n### Product:\n{product}\n### Description:\n{description}\n\n### Marketing Email:[/INST]"
21
+
22
+ def make_inference(product, description):
23
+ inputs = mixtral_tokenizer(input_from_text(product, description), return_tensors="pt")
24
+
25
+ outputs = merged_model.generate(
26
+ **inputs,
27
+ max_new_tokens=150,
28
+ generation_kwargs={"repetition_penalty" : 1.7}
29
+ )
30
+ # print(mixtral_tokenizer.decode(outputs[0], skip_special_tokens=True))
31
+ result = mixtral_tokenizer.decode(outputs[0], skip_special_tokens=True).split("[/INST]")[1]
32
+ return result
33
+ '''
34
  def make_inference(product_name, product_description):
35
  batch = tokenizer(
36
  f"### Product and Description:\n{product_name}: {product_description}\n\n### Ad:",
 
44
 
45
  return tokenizer.decode(output_tokens[0], skip_special_tokens=True)
46
 
47
+ '''
48
  if __name__ == "__main__":
49
  # make a gradio interface
50
  import gradio as gr