HFHAB committed on
Commit
ad724cd
·
verified ·
1 Parent(s): a06108c

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +9 -4
main.py CHANGED
@@ -10,15 +10,20 @@ client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
10
 
11
class Item(BaseModel):
    """Request payload for the text-generation endpoint.

    Carries the user prompt plus the sampling parameters forwarded to
    the inference client. All sampling fields have conservative defaults.
    """

    prompt: str                       # user message to complete
    system_prompt: str                # system instruction prepended to the prompt
    temperature: float = 0.0          # 0.0 is clamped upstream before sampling
    max_new_tokens: int = 1048        # generation budget in tokens
    top_p: float = 0.15               # nucleus-sampling cutoff
    repetition_penalty: float = 1.0   # 1.0 = no penalty
18
 
19
def format_prompt(example):
    """Render a training-style example as an instruction/response string.

    `example` is a mapping with 'input' and 'output' keys; the result uses
    the "### Instruction / ### Response" template.
    """
    return f"### Instruction: {example['input']}\n ### Response: {example['output']}"
 
 
 
 
22
 
23
  def generate(item: Item):
24
  temperature = float(item.temperature)
@@ -35,7 +40,7 @@ def generate(item: Item):
35
  seed=42,
36
  )
37
 
38
- formatted_prompt = format_prompt(f"{item.system_prompt}, {item.prompt}")
39
  stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
40
  output = ""
41
 
 
10
 
11
class Item(BaseModel):
    """Request payload for the text-generation endpoint.

    Carries the user message, the prior chat turns, and the sampling
    parameters forwarded to the inference client.
    """

    prompt: str                       # user message to complete
    # Prior turns as (user, assistant) pairs consumed by format_prompt.
    # Defaulting to [] keeps the API backward-compatible with clients that
    # predate the history field (pydantic copies mutable defaults per
    # instance, so this is safe); an empty history yields a single-turn
    # prompt.
    history: list = []
    system_prompt: str                # system instruction prepended to the prompt
    temperature: float = 0.0          # 0.0 is clamped upstream before sampling
    max_new_tokens: int = 1048        # generation budget in tokens
    top_p: float = 0.15               # nucleus-sampling cutoff
    repetition_penalty: float = 1.0   # 1.0 = no penalty
19
 
20
def format_prompt(message, history):
    """Build a Mixtral-instruct chat prompt from past turns plus a new message.

    `history` is an iterable of (user, assistant) pairs; each pair is wrapped
    in the model's [INST] ... [/INST] template with the assistant reply closed
    by </s>. The new `message` is appended as the final open instruction.
    Returns the full prompt string starting with the <s> BOS marker.
    """
    pieces = ["<s>"]
    for user_turn, assistant_turn in history:
        pieces.append(f"[INST] {user_turn} [/INST] {assistant_turn}</s> ")
    pieces.append(f"[INST] {message} [/INST]")
    return "".join(pieces)
27
 
28
  def generate(item: Item):
29
  temperature = float(item.temperature)
 
40
  seed=42,
41
  )
42
 
43
+ formatted_prompt = format_prompt(f"{item.system_prompt}, {item.prompt}", item.history)
44
  stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
45
  output = ""
46