Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
|
@@ -30,14 +30,10 @@ class OrcaChatBot:
|
|
| 30 |
self.system_message = system_message
|
| 31 |
|
| 32 |
def predict(self, user_message, temperature=0.4, max_new_tokens=70, top_p=0.99, repetition_penalty=1.9):
|
| 33 |
-
|
| 34 |
-
prompt = f"system\n{self.system_message}\nuser\n{user_message}\nassistant"
|
| 35 |
-
|
| 36 |
-
# Encode the prompt
|
| 37 |
inputs = self.tokenizer(prompt, return_tensors='pt', add_special_tokens=False)
|
| 38 |
input_ids = inputs["input_ids"].to(self.model.device)
|
| 39 |
|
| 40 |
-
# Generate a response
|
| 41 |
output_ids = self.model.generate(
|
| 42 |
input_ids,
|
| 43 |
max_length=input_ids.shape[1] + max_new_tokens,
|
|
@@ -45,10 +41,9 @@ class OrcaChatBot:
|
|
| 45 |
top_p=top_p,
|
| 46 |
repetition_penalty=repetition_penalty,
|
| 47 |
pad_token_id=self.tokenizer.eos_token_id,
|
| 48 |
-
do_sample=True
|
| 49 |
)
|
| 50 |
|
| 51 |
-
# Decode the generated response
|
| 52 |
response = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
|
| 53 |
|
| 54 |
return response
|
|
@@ -75,5 +70,4 @@ iface = gr.Interface(
|
|
| 75 |
theme="ParityError/Anime"
|
| 76 |
)
|
| 77 |
|
| 78 |
-
# Launch the Gradio interface
|
| 79 |
iface.launch()
|
|
|
|
| 30 |
self.system_message = system_message
|
| 31 |
|
| 32 |
def predict(self, user_message, temperature=0.4, max_new_tokens=70, top_p=0.99, repetition_penalty=1.9):
|
| 33 |
+
prompt = f"<|im_start|>system\n{self.system_message}<|im_end|>\n<|im_start|>user\n{user_message}<|im_end|>\n<|im_start|>assistant" if self.conversation_history is None else self.conversation_history + f"<|im_end|>\n<|im_start|>user\n{user_message}<|im_end|>\n<|im_start|>assistant"
|
|
|
|
|
|
|
|
|
|
| 34 |
inputs = self.tokenizer(prompt, return_tensors='pt', add_special_tokens=False)
|
| 35 |
input_ids = inputs["input_ids"].to(self.model.device)
|
| 36 |
|
|
|
|
| 37 |
output_ids = self.model.generate(
|
| 38 |
input_ids,
|
| 39 |
max_length=input_ids.shape[1] + max_new_tokens,
|
|
|
|
| 41 |
top_p=top_p,
|
| 42 |
repetition_penalty=repetition_penalty,
|
| 43 |
pad_token_id=self.tokenizer.eos_token_id,
|
| 44 |
+
do_sample=True
|
| 45 |
)
|
| 46 |
|
|
|
|
| 47 |
response = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
|
| 48 |
|
| 49 |
return response
|
|
|
|
| 70 |
theme="ParityError/Anime"
|
| 71 |
)
|
| 72 |
|
|
|
|
| 73 |
iface.launch()
|