Ctaake committed on
Commit
88db26a
·
verified ·
1 Parent(s): 2a8b052

Switch nous research

Browse files
Files changed (1) hide show
  1. app.py +3 -2
app.py CHANGED
@@ -10,12 +10,13 @@ from datetime import datetime
10
  checkpoint = "CohereForAI/c4ai-command-r-v01"
11
  checkpoint = "mistralai/Mistral-7B-Instruct-v0.1"
12
  checkpoint = "google/gemma-1.1-7b-it"
 
13
  path_to_log = "FlaggedFalse.txt"
14
 
15
  # Inference client with the model (And HF-token if needed)
16
  client = InferenceClient(checkpoint)
17
  tokenizer = AutoTokenizer.from_pretrained(checkpoint)
18
- if checkpoint == "mistralai/Mistral-7B-Instruct-v0.1":
19
  # Tokenizer chat template correction(Only works for mistral models)
20
  chat_template = open("mistral-instruct.jinja").read()
21
  chat_template = chat_template.replace(' ', '').replace('\n', '')
@@ -58,7 +59,7 @@ def inference(message, history, systemPrompt=SYSTEM_PROMPT+SYSTEM_PROMPT_PLUS, t
58
  seed=random.randint(0, 999999999),
59
  )
60
  # Generating the response by passing the prompt in right format plus the client settings
61
- stream = client.text_generation(format_prompt_gemma(message, history, systemPrompt),
62
  **client_settings)
63
  # Reading the stream
64
  partial_response = ""
 
10
  checkpoint = "CohereForAI/c4ai-command-r-v01"
11
  checkpoint = "mistralai/Mistral-7B-Instruct-v0.1"
12
  checkpoint = "google/gemma-1.1-7b-it"
13
+ checkpoint = "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO"
14
  path_to_log = "FlaggedFalse.txt"
15
 
16
  # Inference client with the model (And HF-token if needed)
17
  client = InferenceClient(checkpoint)
18
  tokenizer = AutoTokenizer.from_pretrained(checkpoint)
19
+ if checkpoint == "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO":
20
  # Tokenizer chat template correction(Only works for mistral models)
21
  chat_template = open("mistral-instruct.jinja").read()
22
  chat_template = chat_template.replace(' ', '').replace('\n', '')
 
59
  seed=random.randint(0, 999999999),
60
  )
61
  # Generating the response by passing the prompt in right format plus the client settings
62
+ stream = client.text_generation(format_prompt(message, history, systemPrompt),
63
  **client_settings)
64
  # Reading the stream
65
  partial_response = ""