# Medical_Chat_Autocompletion / conversation.py
# Author: warleagle — commit fe83af5 ("Upload 2 files")
class Conversation:
    """Accumulate chat turns and render them as a single prompt string.

    Every message is rendered with ``message_template``. The rendered text
    ends with ``response_template``, which cues the model to answer as the bot.
    """

    # Default system instruction (Russian). In English: "Suggest a technical
    # support operator's reply to the user's question from the chat."
    DEFAULT_SYSTEM_PROMPT = (
        "Предложи ответ оператора технической поддержки на вопрос пользователя из чата."
    )

    def __init__(
        self,
        system_prompt=DEFAULT_SYSTEM_PROMPT,
        message_template="<s>{role}\n{content}</s>",
        response_template="<s>bot\n",
    ):
        """Create a conversation, optionally seeded with a system message.

        Args:
            system_prompt: Content of the initial ``system`` message. Pass
                ``None`` to start without a system message.
            message_template: ``str.format`` template with ``{role}`` and
                ``{content}`` fields, used to render every message.
            response_template: Text appended after all messages to cue the
                bot's reply.
        """
        self.message_template = message_template
        self.response_template = response_template
        self.messages = []
        if system_prompt is not None:
            self.messages.append({"role": "system", "content": system_prompt})

    def _add_message(self, role, message):
        """Append a message dict with the given role and content."""
        self.messages.append({"role": role, "content": message})

    def add_user_message(self, message):
        """Append a ``user`` turn."""
        self._add_message("user", message)

    def add_bot_message(self, message):
        """Append a ``bot`` turn."""
        self._add_message("bot", message)

    def get_prompt(self):
        """Render all messages followed by the response template.

        Returns:
            The concatenated prompt with leading/trailing whitespace stripped.
            The strip also removes the trailing newline of the default
            response template. This is kept so the prompt format matches
            the original implementation.
        """
        parts = [self.message_template.format(**message) for message in self.messages]
        # Use the configured template instead of a hard-coded "<s>bot\n" so
        # that changing ``response_template`` actually takes effect.
        parts.append(self.response_template)
        return "".join(parts).strip()
def generate(model, tokenizer, prompt, generation_config):
    """Generate a completion for ``prompt`` and return only the newly generated text.

    Args:
        model: Object exposing ``device`` and ``generate(**inputs,
            generation_config=...)``. Presumably a Hugging Face causal LM.
        tokenizer: Callable tokenizer with a ``decode`` method. Presumably a
            Hugging Face tokenizer.
        prompt: Fully rendered prompt text.
        generation_config: Passed unchanged to ``model.generate``.

    Returns:
        Decoded text of the tokens that follow the prompt. Special tokens are
        skipped and surrounding whitespace is stripped.
    """
    # No special tokens are added here. Presumably this is because the
    # prompt text already contains its own markers such as "<s>"; confirm.
    encoded = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
    inputs = {name: tensor.to(model.device) for name, tensor in encoded.items()}
    prompt_length = len(inputs["input_ids"][0])

    # Keep the first sequence of the batch and drop the echoed prompt tokens.
    # NOTE(review): assumes generate() returns prompt + continuation
    # (decoder-only model); confirm for the model in use.
    first_sequence = model.generate(**inputs, generation_config=generation_config)[0]
    completion_ids = first_sequence[prompt_length:]

    completion = tokenizer.decode(completion_ids, skip_special_tokens=True)
    return completion.strip()