# Spaces:
# Runtime error  (paste artifact — notebook/runner output, not part of the code)
class Conversation:
    """Chat-history accumulator that renders messages into a single model prompt.

    Every message is wrapped in ``message_template`` (``<s>{role}\n{content}</s>``),
    and the rendered prompt ends with the open ``response_template`` tag
    (``<s>bot\n``) so the model continues the conversation as the bot.
    """

    # Original hard-coded system message (Russian: "Suggest a tech-support
    # operator's reply to the user's question from the chat.").
    DEFAULT_SYSTEM_PROMPT = "Предложи ответ оператора технической поддержки на вопрос пользователя из чата."

    def __init__(self, system_prompt=DEFAULT_SYSTEM_PROMPT):
        """Start a conversation seeded with a system message.

        system_prompt: text of the initial system message; the default keeps
        the previously hard-coded instruction, so existing callers that pass
        no arguments get identical behavior.
        """
        self.message_template = "<s>{role}\n{content}</s>"
        self.response_template = "<s>bot\n"
        self.messages = [{
            "role": "system",
            "content": system_prompt,
        }]

    def add_user_message(self, message):
        """Append a user turn to the history."""
        self.messages.append({
            "role": "user",
            "content": message,
        })

    def add_bot_message(self, message):
        """Append a bot turn to the history."""
        self.messages.append({
            "role": "bot",
            "content": message,
        })

    def get_prompt(self):
        """Render all accumulated messages plus the open bot tag into one string."""
        final_text = "".join(
            self.message_template.format(**message) for message in self.messages
        )
        # Fix: use response_template instead of duplicating the "<s>bot\n"
        # literal here — the attribute existed but was never used, so changing
        # it previously had no effect on the generated prompt.
        final_text += self.response_template
        return final_text.strip()
def generate(model, tokenizer, prompt, generation_config):
    """Generate a completion for *prompt* and return only the newly produced text.

    The prompt is tokenized without special tokens (it already carries its own
    ``<s>`` markers), the tensors are moved to the model's device, and the
    echoed prompt tokens are sliced off the model output before decoding.
    """
    encoded = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
    encoded = {name: tensor.to(model.device) for name, tensor in encoded.items()}
    prompt_length = len(encoded["input_ids"][0])
    generated = model.generate(
        **encoded,
        generation_config=generation_config,
    )[0]
    # Drop the prompt echo; keep only tokens produced by the model.
    completion_ids = generated[prompt_length:]
    completion = tokenizer.decode(completion_ids, skip_special_tokens=True)
    return completion.strip()