# Spaces:
# Sleeping
# Sleeping
import json
import re
import time
from pprint import pprint

from langchain.prompts import ChatPromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate
from pydantic import ValidationError

from llm.basemodel import EHRModel
from llm.prompt import field_descriptions, TASK_INSTRUCTIONS, JSON_EXAMPLE
from llm.models import get_model


class VirtualNurseLLM:
    """LLM-driven virtual nurse that interviews a patient and fills an EHR record.

    Each call to :meth:`invoke` takes the patient's latest reply, asks the
    model to extract EHR fields from it, then asks the model for the next
    question targeting the first field that is still empty. When no field is
    missing (or the model produces the closing sentence), the collected record
    is sent through a final "refactor" pass and the closing text is returned.
    """

    # Matches either a complete JSON string literal (captured in group 1) or a
    # bare ``None`` token, so ``None`` can be rewritten without touching strings.
    _STRING_OR_NONE = re.compile(r'("(?:\\.|[^"\\])*")|\bNone\b')

    # Task names understood by create_prompt.
    _TASK_TYPES = ("extract_ehr", "question", "refactor")

    def __init__(self, base_url=None, model_name=None, api_key=None, model_type=None):
        """Set up the model client and the per-conversation state.

        Args:
            base_url: Accepted for interface compatibility; not used here.
            model_name: Name passed to ``get_model``. When falsy, no client is
                created and ``self.client`` stays ``None``.
            api_key: Accepted for interface compatibility; not used here.
            model_type: Accepted for interface compatibility; not used here.
        """
        self.client = get_model(model_name=model_name) if model_name else None
        self.model_name = model_name

        # Prompt building material imported from llm.prompt.
        self.TASK_INSTRUCTIONS = TASK_INSTRUCTIONS
        self.field_descriptions = field_descriptions
        self.JSON_EXAMPLE = JSON_EXAMPLE

        # Conversation state. The history starts with the nurse's greeting.
        self.ehr_data = {}
        self.chat_history = [
            {"role": "assistant", "content": "สวัสดีค่ะ ดิฉัน มะลิ เป็นพยาบาลเสมือนที่จะมาดูแลการซักประวัตินะคะ"}
        ]
        self.current_patient_response = None
        self.current_context = None
        self.debug = False
        self.current_prompt = None
        self.current_prompt_ehr = None
        self.current_question = None
        self.ending_text = "ขอบคุณที่ให้ข้อมูลค่ะ ฉันได้ข้อมูลที่ต้องการครบแล้วค่ะ ดิฉันจะบันทึกข้อมูลทั้งหมดนี้เพื่อส่งต่อให้แพทย์ดูแลคุณอย่างเหมาะสมค่ะ"

    @staticmethod
    def _timestamp_utc7():
        """Return the current wall-clock time shifted to UTC+7 as ``YYYY-MM-DD HH:MM:SS``."""
        return time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime(time.time() + 7 * 3600))

    def _ehr_dict(self):
        """Return ``self.ehr_data`` as a plain dict, converting it in place if needed.

        ``refactor_ehr`` stores a validated model object in ``self.ehr_data``,
        while the extraction and questioning steps need a mutable mapping.
        """
        if not isinstance(self.ehr_data, dict):
            self.ehr_data = self.ehr_data.model_dump()
        return self.ehr_data

    def create_prompt(self, task_type):
        """Build the system + user chat prompt template for ``task_type``.

        Args:
            task_type: One of ``"extract_ehr"``, ``"question"`` or ``"refactor"``.

        Returns:
            A ``ChatPromptTemplate`` made of the task's system instruction and
            a human message ``"response: {patient_response}"``.

        Raises:
            ValueError: If ``task_type`` is not one of the supported names.
        """
        if task_type not in self._TASK_TYPES:
            raise ValueError("Invalid task type.")
        system_instruction = self.TASK_INSTRUCTIONS.get(task_type)
        # system + user
        system_template = SystemMessagePromptTemplate.from_template(system_instruction)
        user_template = HumanMessagePromptTemplate.from_template("response: {patient_response}")
        return ChatPromptTemplate.from_messages([system_template, user_template])

    def gather_ehr(self, patient_response, max_retries=2):
        """Extract EHR fields from ``patient_response`` and merge them into ``self.ehr_data``.

        The model's reply is parsed as JSON and validated with ``EHRModel``.
        On a validation failure the model is asked again with a corrective
        prompt that includes the faulty JSON, up to ``max_retries`` parse
        attempts in total.

        Args:
            patient_response: The patient's latest message.
            max_retries: Maximum number of parse attempts.

        Returns:
            The updated ``self.ehr_data`` dict on success; otherwise a dict
            with ``"result"`` (the last model response) and ``"error"`` keys.
        """
        known = self._ehr_dict()
        prompt = self.create_prompt("extract_ehr")
        messages = prompt.format_messages(
            ehr_data=known,
            patient_response=patient_response,
            example=self.JSON_EXAMPLE,
        )
        self.current_prompt_ehr = messages[0].content
        response = self.client(messages=messages)
        if self.debug:
            pprint(f"gather ehr llm response: \n{response.content}\n")

        retry_count = 0
        while retry_count < max_retries:
            try:
                json_content = self.extract_json_content(response.content)
                if self.debug:
                    pprint(f"JSON after dumps:\n{json_content}\n")
                parsed = EHRModel.model_validate_json(json_content)
                # Overwrite a field only when the model returned something for
                # it; empty values must not erase what is already known.
                for key, value in parsed.model_dump().items():
                    if value in (None, [], {}, ""):
                        continue
                    print(f"Updating {key} with value {value}")
                    known[key] = value
                return known
            except (ValidationError, json.JSONDecodeError) as e:
                print(f"Error parsing EHR data: {e} Retrying {retry_count}...")
                retry_count += 1
                if retry_count < max_retries:
                    retry_prompt = (
                        "กรุณาตรวจสอบให้แน่ใจว่าข้อมูลที่ให้มาอยู่ในรูปแบบ JSON ที่ถูกต้องตามโครงสร้างตัวอย่าง "
                        "และแก้ไขปัญหาทางไวยากรณ์หรือรูปแบบที่ไม่ถูกต้อง รวมถึงให้ข้อมูลในรูปแบบที่สอดคล้องกัน "
                        "ห้ามมีการ hallucination หากไม่เจอข้อมูลให้ใส่ค่า null "
                        f"Attempt {retry_count + 1} of {max_retries}."
                    )
                    # Adding a string to the chat prompt appends it as an extra message template.
                    retry_template = self.create_prompt("extract_ehr") + "\n\n# ลองใหม่: \n\n{retry_prompt} \n ## JSON เก่าที่มีปัญหา: \n{json_problem}"
                    messages = retry_template.format_messages(
                        ehr_data=known,
                        patient_response=patient_response,
                        example=self.JSON_EXAMPLE,
                        retry_prompt=retry_prompt,
                        json_problem=json_content,
                    )
                    self.current_prompt_ehr = messages[0].content
                    print(f"กำลังลองใหม่ด้วย prompt ที่ปรับแล้ว: {retry_prompt}")
                    response = self.client(messages=messages)

        # Final error message if retries are exhausted
        print("Failed to extract valid EHR data after multiple attempts. Generating new question.")
        return {"result": response, "error": "Failed to extract valid EHR data. Please try again."}

    def fetching_chat(self, patient_response, question_prompt):
        """Ask the model for a question about the first EHR field that is still empty.

        Args:
            patient_response: The patient's latest message.
            question_prompt: Prompt template produced by ``create_prompt("question")``.

        Returns:
            The generated question (stripped), or ``None`` when every field in
            ``self.field_descriptions`` already has a truthy value.
        """
        known = self._ehr_dict()
        for field, description in self.field_descriptions.items():
            if known.get(field):
                continue  # already answered; look for the next gap

            # Compile known patient information as context
            context = ", ".join(f"{key}: {value}" for key, value in known.items() if value)
            print("fetching for ", f'"{field}":"{description}"')

            history_context = "\n".join(
                f"{entry['role']}: {entry['content']}" for entry in self.chat_history
            )
            # A bare string given to from_messages is compiled as a template,
            # so braces typed by the patient must be escaped to stay literal.
            history_template = history_context.replace("{", "{{").replace("}", "}}")

            messages = ChatPromptTemplate.from_messages([question_prompt, history_template])
            messages = messages.format_messages(
                description=f'"{field}":"{description}"',
                context=context,
                patient_response=patient_response,
                field_descriptions=self.field_descriptions,
                time_now=self._timestamp_utc7(),
            )
            self.current_context = context
            self.current_prompt = messages[0].content

            start_time = time.time()
            response = self.client(messages=messages)
            print(f"Time after getting response from client: {time.time() - start_time} seconds")

            # Remember the generated question and return it
            self.current_question = response.content.strip()
            return self.current_question
        return None

    def refactor_ehr(self, current_question=None):
        """Run the final "refactor" pass over the collected EHR data.

        The model's reply is validated with ``EHRModel`` and stored in
        ``self.ehr_data``. If the reply cannot be validated, the data gathered
        so far is kept unchanged instead of aborting the conversation.

        Args:
            current_question: Text to hand back to the patient; defaults to
                ``self.ending_text`` when falsy.

        Returns:
            ``current_question`` if truthy, otherwise ``self.ending_text``.
        """
        patient_response = current_question or self.ending_text
        refactor_prompt = self.create_prompt("refactor")
        messages = ChatPromptTemplate.from_messages([refactor_prompt])
        messages = messages.format_messages(
            patient_response="",
            ehr_data=self.ehr_data,
            chat_history=self.chat_history,
            time_now=self._timestamp_utc7(),
        )
        response = self.client(messages=messages)
        json_content = self.extract_json_content(response.content)
        pprint(f"JSON after dumps:\n{json_content}\n")
        try:
            self.ehr_data = EHRModel.model_validate_json(json_content)
        except ValidationError as e:
            # The refactor pass is a clean-up step; losing it must not lose the record.
            print(f"Error parsing refactored EHR data: {e} Keeping existing EHR data.")
        else:
            print("Refactored EHR data ! Ending the process.")
        return patient_response

    def get_question(self, patient_response):
        """Update the EHR from ``patient_response`` and return the nurse's next utterance.

        Returns the next question, or the closing text once nothing is left to
        ask. The refactor pass runs exactly once when the interview ends.
        """
        question_prompt = self.create_prompt("question")

        # Update EHR data with the latest patient response
        start_time = time.time()
        ehr_data = self.gather_ehr(patient_response)
        print(f"Time after gathering EHR: {time.time() - start_time} seconds")
        if self.debug:
            pprint(ehr_data)

        question = self.fetching_chat(patient_response, question_prompt)
        if not question or self.ending_text in question:
            # Either nothing is left to ask or the model produced the closing
            # sentence itself: finalize the record a single time.
            question = self.refactor_ehr(question)
        self.current_question = question
        return self.current_question

    def invoke(self, patient_response):
        """Process one conversation turn and return the nurse's reply.

        A non-empty ``patient_response`` is appended to the chat history as a
        user message; the reply is always appended as an assistant message.
        """
        if patient_response:
            self.chat_history.append({"role": "user", "content": patient_response})
        question = self.get_question(patient_response)
        self.current_patient_response = patient_response
        self.chat_history.append({"role": "assistant", "content": question})
        return question

    def slim_invoke(self, patient_response):
        """Send ``patient_response`` to the model with no system prompt and return the reply text.

        Prints the time spent on each step; does not touch the chat history or EHR data.
        """
        start_time = time.time()
        user_message = HumanMessagePromptTemplate.from_template("response: {patient_response}")
        print(f"Time after creating user_message: {time.time() - start_time} seconds")

        start_time = time.time()
        messages = ChatPromptTemplate.from_messages([user_message]).format_messages(patient_response=patient_response)
        print(f"Time after formatting messages: {time.time() - start_time} seconds")

        start_time = time.time()
        response = self.client(messages=messages)
        print(f"Time after getting response from client: {time.time() - start_time} seconds")
        return response.content

    def extract_json_content(self, content):
        """Cut the outermost ``{...}`` span out of ``content`` and normalize it.

        Line breaks are removed, and bare ``None`` tokens outside of string
        literals are rewritten to ``null``. Text inside double-quoted strings
        is left untouched.

        Args:
            content: Raw model output.

        Returns:
            The JSON candidate string, or ``None`` when ``content`` is not a
            string or contains no ``{`` followed by a closing ``}``.
        """
        if isinstance(content, str):
            flattened = content.replace('\n', '').replace('\r', '')
            start = flattened.find('{')
            end = flattened.rfind('}')
            if start != -1 and end > start:
                candidate = flattened[start:end + 1]
                # Keep string literals (group 1) as they are; anything else
                # the pattern matches is a bare None token.
                return self._STRING_OR_NONE.sub(lambda m: m.group(1) or 'null', candidate)
        print("JSON Parsing Error Occured: ", content)
        print("No valid JSON found in response")
        return None

    def reset(self):
        """Clear the conversation so the instance can be reused for a new patient.

        The chat history is left empty (the greeting is not re-added).
        """
        self.ehr_data = {}
        self.chat_history = []
        self.current_question = None
        # Also drop per-turn state so nothing from the previous conversation lingers.
        self.current_patient_response = None
        self.current_context = None
        self.current_prompt = None
        self.current_prompt_ehr = None