# Source: MMedAgent_demo/src/ChatCAD_R/chat_bot_RAG.py
# Uploaded by YPan0 ("Upload folder using huggingface_hub"), commit b6deff2 (verified).
import sys, os
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "src"))
import openai
from transformers import pipeline
import json
from text2vec import SentenceModel
from ChatCAD_R.cxr.prompt import prob2text,prob2text_zh
from ChatCAD_R.r2g.report_generate import reportGen
from ChatCAD_R.cxr.diagnosis import getJFImg,JFinfer,JFinit
from ChatCAD_R.engine_LLM.api import answer_quest, query_range
from ChatCAD_R.modality_identify import ModalityClip
# Maps the Chinese names of five chest X-ray findings (cardiomegaly, pulmonary
# edema, consolidation, atelectasis, pleural effusion) to an integer index.
# gpt_bot.report_cxr_zh passes this table to prob2text_zh; the index presumably
# matches the column order of the classifier's probability output -- confirm
# against prob2text_zh.
fivedisease_zh = {
    finding: position
    for position, finding in enumerate(
        ("心脏肥大", "肺水肿", "肺实变", "肺不张", "胸腔积液")
    )
}
class base_bot:
def start(self):
"""为当前会话新建chatbot"""
pass
def reset(self):
"""删除当前会话的chatbot"""
pass
def chat(self,message: str):
pass
class gpt_bot(base_bot):
def __init__(self, engine: str,api_key: str,):
"""初始化模型"""
self.agent=None
self.engine=engine
self.api_key=api_key
config_path = os.path.join(os.path.dirname(__file__), 'cxr', 'config', 'JF.json')
weights_path = os.path.join(os.path.dirname(__file__), 'weights', 'JFchexpert.pth')
msd_path = os.path.join(os.path.dirname(__file__), 'engine_LLM', 'dataset', 'msd_dict.json')
img_model, imgcfg = JFinit(config_path, weights_path)
self.imgcfg=imgcfg
self.img_model=img_model
self.reporter=reportGen()
self.modality=["chest x-ray", "panoramic dental x-ray", "knee mri","Mammography"]
self.translator = pipeline(model="zhaozh/radiology-report-en-zh-ft-base",device=0,max_length=500,truncation=True)
self.identifier=ModalityClip(self.modality)
self.sent_model = SentenceModel()
self.msd_dict=json.load(open(msd_path,'r',encoding='utf-8'))
os.environ['HTTP_PROXY'] = 'socks5h://127.0.0.1:1080'
os.environ['HTTPS_PROXY'] = 'socks5h://127.0.0.1:1080'
def ret_local(self,query:str,mode=1):
topic_range, [_,_]=query_range(self.sent_model,query,k=1,bar=0.0)
# Chinese
if mode==0:
return "https://"+self.msd_dict[topic_range[0]]
# English
else:
return "https://"+self.msd_dict[topic_range[0]].replace('www.msdmanuals.cn','www.merckmanuals.com')
# translate chest X-ray reports generated by r2g into Chinese.
def radio_en_zh(self, content: str):
output=self.translator(content)
report_zh=output[0]['generated_text']
return report_zh
def chat_with_gpt(self, prompt):
messages = [
{"role": "user", "content": prompt}
]
request_params = {
"model": self.engine,
"messages": messages
}
response = openai.ChatCompletion.create(**request_params)
return response['choices'][0]['message']['content']
def report_cxr_zh(self,img, mode:str='run'):
img1,img2=getJFImg(img,self.imgcfg)
text_report=self.reporter.report(img1)[0]
text_report=self.radio_en_zh(text_report)
prob=JFinfer(self.img_model,img2,self.imgcfg)
converter=prob2text_zh(prob,fivedisease_zh)
# default setting: promptB
res=converter.promptB()
prompt_report_zh=res+" 网络B生成了诊断报告:"+text_report
awesomePrompt_en="\nRefine the report of Network B based on results from Network A using English.Please do not mention Network A and \
Network B. Suppose you are a doctor writing findings for a chest x-ray report."
prompt_report=prompt_report_zh+awesomePrompt_en
refined_report = self.chat_with_gpt(prompt_report)
if mode=='debug':
return text_report,refined_report,prob.detach().cpu().numpy().tolist()
else:
return refined_report
def report_zh(self,img, mode:str='run'):
# identify modality
index=self.identifier.identify(img)
# call ModalitySpecificModel
if index==0:
return self.report_cxr_zh(img,mode),self.modality[index]
elif index==1:
# return self.report_dental_zh(img_path)
# The source code of the CAD network for dental x-rays is currently not planned to be open-sourced
# You can try it in the future online version of ChatCAD+.
return "Tooth CAD not publicly available!", self.modality[index]
elif index==2:
return "Knee CAD not publicly available!", self.modality[index]
else:
print("error!")
return
def chat(self,message: str, ref_record: str, force_generate=False):
report = self.chat_with_gpt(ref_record+'\n'+"If the following message includes a medical report, answer \"1\"; else, answer\"0\""+message)
if str(report) == str(1):
return self.chat_report(message, ref_record, force_generate)
refine_prompt="请根据以下内容概括患者的提问并对所涉及的疾病指出其全称:\n"
refined_message=self.chat_with_gpt(ref_record+'\n'+refine_prompt+message)
topic_range, [raw_topic, cos_sim]=query_range(self.sent_model,refined_message,k=5,bar=0.6)
if len(topic_range)==0:
response = self.chat_with_gpt(f"{ref_record}\nuser:**Answer in English**\n"+message)
response +="\nNote: No definitive evidence was found in the Merck Manual Professional Edition. Please adopt with caution."
return response
refine_prompt="请根据以下内容概括患者的提问:\n"
refined_message=self.chat_with_gpt(ref_record+'\n'+refine_prompt+message)
ret=answer_quest(refined_message,api_key=self.api_key,topic_base_dict=topic_range)
if ret==None:
response = self.chat_with_gpt(f"{ref_record}\nuser:**Answer in English**\n"+message)
response +="\nNote: No definitive evidence was found in the Merck Manual Professional Edition. Please adopt with caution."
message=response
else:
query,knowledge=ret
knowledge=knowledge.replace("\n\n","\n")
# needed_site=ret_website(query)
needed_site=self.ret_local(query,1)
try:
index = knowledge.index(":")
except ValueError:
index = -1
knowledge = knowledge[index+1:]
chat_message=f"{ref_record}\nuser:**answer in English*\n请参考以下知识来解答病人的问题“{message}”并给出分析,请注意保持语句通顺\n[{knowledge}]"
response = self.chat_with_gpt(chat_message)
message= response+f"\nNote: Relevant information is sourced from the Merck Manual Professional Edition. ({needed_site})"+""
return message
def chat_report(self,message: str, ref_record: str, force_generate=False):
if force_generate: # for specific medical suggestion generation task, normally set to false
query = None
ref_record = "Here is a medical image report for the patient:\n" + ref_record
abnormality_check_prompt = "你的任务:【根据以下医学影像报告,你需要找出提及的可能的疾病,或者概括患者需要注意的与疾病有关的异常情况并对所涉及的疾病指出其全称。"
abnormality_check_prompt += "*以中文回答*】\n"
abnormality_check_prompt += "报告中有提及可能导致该结果的疾病吗?如有,则请回复且仅回复报告中提及的该疾病的名称以及其中文名称,之后结束对话\n"
abnormality_check_prompt += "若无提及具体疾病,则报告中有较为明显的异常吗?如有,则请回复且只回复异常情况相关内容 (即省略正常情况的内容),将所有语句合并精炼成一小段话,"
abnormality_check_prompt += "并回复且仅回复该一小段话,之后结束对话\n"
abnormality_check_prompt += "若无明显异常,则回复0,且不回复0以外任何内容。"
abnormality_check = self.chat_with_gpt(abnormality_check_prompt + '\n' + ref_record)
if str(abnormality_check) == str(0):
check = 0
term_select_prompt = "Based on this report, find all medical terms that relates to potential diseases and randomly select one term. "
term_select_prompt += "Give your answer ONLY with that term. Don't include any other contents except for that term you selected."
term = self.chat_with_gpt(ref_record+'\n'+term_select_prompt)
refine_prompt = f"简述{term}并对{term}所相关的疾病指出其全称*以中文回答*"
refined_message=self.chat_with_gpt(refine_prompt+message)
else:
refine_prompt = abnormality_check + "\n请就此医学报告进行分析,给出最有可能的医学结论*以中文回答,仅回复得出的医学结论,且尽量精简精炼你的回答*"
refined_message=self.chat_with_gpt(refine_prompt+message)
check = 1
topic_range, [raw_topic, cos_sim]=query_range(self.sent_model,refined_message,k=5,bar=0.6)
if len(topic_range)==0:
response = self.chat_with_gpt(f"{abnormality_check}\nuser:*Answer in English**\n"+"Give medical evaluation and suggestion")
response +="\nNote: No definitive evidence was found in the Merck Manual Professional Edition. Please adopt with caution."
return response, check, query, abnormality_check, [raw_topic, cos_sim], None
ret=answer_quest(refined_message,api_key=self.api_key,topic_base_dict=topic_range)
if ret==None:
ab_select_prompt = "Based on this report, find a medical term that indicates an abnormality, and explain this abnormality in brief."
response = self.chat_with_gpt(ref_record+'\n'+ab_select_prompt)
response +="\nNote: No definitive evidence was found in the Merck Manual Professional Edition. Please adopt with caution."
message=response
else:
query,knowledge=ret
knowledge=knowledge.replace("\n\n","\n")
needed_site=self.ret_local(query,1)
try:
index = knowledge.index(":")
except ValueError:
index = -1
knowledge = knowledge[index+1:]
if str(abnormality_check) == str(0):
chat_message=f"{ref_record}\nuser:*answer in English*\n以下知识可能有关于患者症状,请参考它来回答患者问题'for this medical report, is everything ok?'并给予患者医学建议并给出分析,请注意保持语句通顺\n[{knowledge}]"
else:
chat_message=f"{abnormality_check}\nuser:*answer in English*\n以下知识可能有关于患者症状,请参考它来回答患者问题'for this medical report, is everything ok?'并给予患者医学建议并给出分析,请注意保持语句通顺\n[{knowledge}]"
response = self.chat_with_gpt(chat_message)
message= response+f"\nNote: Relevant information is sourced from the Merck Manual Professional Edition. ({needed_site})"+""
return message, check, query, abnormality_check, [raw_topic, cos_sim], knowledge
else:
query = None
ref_record = "Here is a medical image report for the patient:\n" + ref_record
abnormality_check_prompt = "你的任务:【根据以下医学影像报告,你需要找出提及的可能的疾病,或者概括患者需要注意的与疾病有关的异常情况并对所涉及的疾病指出其全称。"
abnormality_check_prompt += "*以中文回答*】\n"
abnormality_check_prompt += "报告中有提及可能导致该结果的疾病吗?如有,则请回复且仅回复报告中提及的该疾病的名称以及其中文名称,之后结束对话\n"
abnormality_check_prompt += "若无提及具体疾病,则报告中有较为明显的异常吗?如有,则请回复且只回复异常情况相关内容 (即省略正常情况的内容),将所有语句合并精炼成一小段话,"
abnormality_check_prompt += "并回复且仅回复该一小段话,之后结束对话\n"
abnormality_check_prompt += "若无明显异常,则回复0,且不回复0以外任何内容。"
abnormality_check = self.chat_with_gpt(abnormality_check_prompt + '\n' + ref_record)
if str(abnormality_check) == str(0):
check = 0
response = "No abnormalities were found in your report. It seems you are all good!"
return response, check, query, abnormality_check, [None, None], None
else:
refine_prompt = abnormality_check + "\n请就此医学报告进行分析,给出最有可能的医学结论*以中文回答,仅回复得出的医学结论,且尽量精简精炼你的回答*"
refined_message=self.chat_with_gpt(refine_prompt+message)
check = 1
topic_range, [raw_topic, cos_sim]=query_range(self.sent_model,refined_message,k=5,bar=0.6)
if len(topic_range)==0:
response = self.chat_with_gpt(f"{abnormality_check}\nuser:*Answer in English**\n"+"Give medical evaluation and suggestion")
response +="\nNote: No definitive evidence was found in the Merck Manual Professional Edition. Please adopt with caution."
return response, check, query, abnormality_check, [raw_topic, cos_sim], None
ret=answer_quest(refined_message,api_key=self.api_key,topic_base_dict=topic_range)
if ret==None:
response = refined_message
response +="\nNote: Failed to retrieve relavant information from the Merck Manual Professional Edition. Please adopt with caution."
message=response
else:
query,knowledge=ret
knowledge=knowledge.replace("\n\n","\n")
needed_site=self.ret_local(query,1)
try:
index = knowledge.index(":")
except ValueError:
index = -1
knowledge = knowledge[index+1:]
chat_message=f"{abnormality_check}\nuser:*answer in English*\n以下知识可能有关于患者症状,请参考它来给予患者医学建议并给出分析,请注意保持语句通顺\n[{knowledge}]"
response = self.chat_with_gpt(chat_message)
message= response+f"\nNote: Relevant information is sourced from the Merck Manual Professional Edition. ({needed_site})"+""
return message, check, query, abnormality_check, [raw_topic, cos_sim], knowledge