import gc
import json
import logging
from difflib import SequenceMatcher

import gradio as gr
import torch
from gradio import ChatMessage
from jinja2 import Template
from tooluniverse import ToolUniverse
from vllm import LLM, SamplingParams

from .toolrag import ToolRAGModel
from .utils import NoRepeatSentenceProcessor, ReasoningTraceChecker, tool_result_format

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

class TxAgent:
    """Multi-step clinical reasoning agent that pairs a vLLM-served model with
    ToolUniverse tools and a RAG model for tool retrieval."""

    def __init__(self, model_name,
                 rag_model_name,
                 tool_files_dict=None,
                 enable_finish=True,
                 enable_rag=True,
                 enable_summary=False,
                 init_rag_num=2,
                 step_rag_num=4,
                 summary_mode='step',
                 summary_skip_last_k=0,
                 summary_context_length=None,
                 force_finish=True,
                 avoid_repeat=True,
                 seed=None,
                 enable_checker=False,
                 enable_chat=False,
                 additional_default_tools=None):
        self.model_name = model_name
        self.tokenizer = None
        self.terminators = None
        self.rag_model_name = rag_model_name
        self.tool_files_dict = tool_files_dict
        self.model = None
        self.rag_model = ToolRAGModel(rag_model_name)
        self.tooluniverse = None
        self.prompt_multi_step = "You are a medical assistant solving clinical oversight issues step-by-step using provided tools."
        self.self_prompt = "Follow instructions precisely."
        self.chat_prompt = "You are a helpful assistant for clinical queries."
        self.enable_finish = enable_finish
        self.enable_rag = enable_rag
        self.enable_summary = enable_summary
        self.summary_mode = summary_mode
        self.summary_skip_last_k = summary_skip_last_k
        self.summary_context_length = summary_context_length
        self.init_rag_num = init_rag_num
        self.step_rag_num = step_rag_num
        self.force_finish = force_finish
        self.avoid_repeat = avoid_repeat
        self.seed = seed
        self.enable_checker = enable_checker
        self.enable_chat = enable_chat
        self.additional_default_tools = additional_default_tools
        logger.debug("TxAgent initialized with parameters: %s", self.__dict__)

    def init_model(self):
        self.load_models()
        self.load_tooluniverse()
        self.load_tool_desc_embedding()

    def print_self_values(self):
        for attr, value in self.__dict__.items():
            logger.debug("%s: %s", attr, value)

    def load_models(self, model_name=None):
        # Skip reloading only if the requested model is already the active one.
        if model_name is not None and model_name == self.model_name and self.model is not None:
            return f"The model {model_name} is already loaded."
        if model_name:
            self.model_name = model_name

        self.model = LLM(model=self.model_name, dtype="float16")
        self.tokenizer = self.model.get_tokenizer()
        self.chat_template = Template(self.tokenizer.chat_template)
        logger.info("Model %s loaded successfully", self.model_name)
        return f"Model {self.model_name} loaded successfully."

    def load_tooluniverse(self):
        self.tooluniverse = ToolUniverse(tool_files=self.tool_files_dict)
        self.tooluniverse.load_tools()
        special_tools = self.tooluniverse.prepare_tool_prompts(
            self.tooluniverse.tool_category_dicts["special_tools"])
        self.special_tools_name = [tool['name'] for tool in special_tools]
        logger.debug("ToolUniverse loaded with %d special tools", len(self.special_tools_name))

    def load_tool_desc_embedding(self):
        self.rag_model.load_tool_desc_embedding(self.tooluniverse)
        logger.debug("Tool description embeddings loaded")

    def rag_infer(self, query, top_k=5):
        return self.rag_model.rag_infer(query, top_k)

    def initialize_tools_prompt(self, call_agent, call_agent_level, message):
        picked_tools_prompt = []
        # Expose sub-agent delegation only when the user explicitly asks for
        # external tools; otherwise only the special tools are offered.
        if "use external tools" not in message.lower():
            picked_tools_prompt = self.add_special_tools(picked_tools_prompt, call_agent=False)
        else:
            picked_tools_prompt = self.add_special_tools(picked_tools_prompt, call_agent=call_agent)
        if call_agent:
            call_agent_level += 1
            if call_agent_level >= 2:
                call_agent = False
        if self.enable_rag:
            picked_tools_prompt += self.tool_RAG(message=message, rag_num=self.init_rag_num)
        return picked_tools_prompt, call_agent_level

    def initialize_conversation(self, message, conversation=None, history=None):
        if conversation is None:
            conversation = []

        conversation = self.set_system_prompt(conversation, self.prompt_multi_step)
        if history:
            conversation.extend(
                {"role": h['role'], "content": h['content']}
                for h in history if h['role'] in ['user', 'assistant']
            )
        conversation.append({"role": "user", "content": message})
        logger.debug("Conversation initialized with %d messages", len(conversation))
        return conversation

    def tool_RAG(self, message=None, picked_tool_names=None,
                 existing_tools_prompt=None, rag_num=4, return_call_result=False):
        # Over-retrieve so that filtering out special tools still leaves
        # rag_num candidates.
        extra_factor = 10
        if picked_tool_names is None:
            picked_tool_names = self.rag_infer(message, top_k=rag_num * extra_factor)

        picked_tool_names = [
            tool for tool in picked_tool_names
            if tool not in self.special_tools_name
        ][:rag_num]
        picked_tools = self.tooluniverse.get_tool_by_name(picked_tool_names)
        picked_tools_prompt = self.tooluniverse.prepare_tool_prompts(picked_tools)
        logger.debug("RAG selected %d tools: %s", len(picked_tool_names), picked_tool_names)
        if return_call_result:
            return picked_tools_prompt, picked_tool_names
        return picked_tools_prompt

    def add_special_tools(self, tools, call_agent=False):
        if self.enable_finish:
            tools.append(self.tooluniverse.get_one_tool_by_one_name('Finish', return_prompt=True))
            logger.debug("Finish tool added")
        # CallAgent/Tool_RAG availability is decided by the caller
        # (initialize_tools_prompt), which gates on the user message.
        if call_agent:
            tools.append(self.tooluniverse.get_one_tool_by_one_name('CallAgent', return_prompt=True))
            logger.debug("CallAgent tool added")
        elif self.enable_rag:
            tools.append(self.tooluniverse.get_one_tool_by_one_name('Tool_RAG', return_prompt=True))
            logger.debug("Tool_RAG tool added")
        if self.additional_default_tools:
            for tool_name in self.additional_default_tools:
                tool_prompt = self.tooluniverse.get_one_tool_by_one_name(tool_name, return_prompt=True)
                if tool_prompt:
                    tools.append(tool_prompt)
                    logger.debug("%s tool added", tool_name)
        return tools

    def add_finish_tools(self, tools):
        tools.append(self.tooluniverse.get_one_tool_by_one_name('Finish', return_prompt=True))
        logger.debug("Finish tool added")
        return tools

    def set_system_prompt(self, conversation, sys_prompt):
        if not conversation:
            conversation.append({"role": "system", "content": sys_prompt})
        else:
            conversation[0] = {"role": "system", "content": sys_prompt}
        return conversation

    def run_function_call(self, fcall_str, return_message=False,
                          existing_tools_prompt=None, message_for_call_agent=None,
                          call_agent=False, call_agent_level=None, temperature=None):
        function_call_json, message = self.tooluniverse.extract_function_call_json(
            fcall_str, return_message=return_message, verbose=False)
        call_results = []
        special_tool_call = ''
        if function_call_json:
            for func in function_call_json if isinstance(function_call_json, list) else [function_call_json]:
                logger.debug("Tool Call: %s", func)
                if func["name"] == 'Finish':
                    special_tool_call = 'Finish'
                    break
                elif func["name"] == 'Tool_RAG':
                    new_tools_prompt, call_result = self.tool_RAG(
                        message=message, existing_tools_prompt=existing_tools_prompt,
                        rag_num=self.step_rag_num, return_call_result=True)
                    existing_tools_prompt = (existing_tools_prompt or []) + new_tools_prompt
                elif func["name"] == 'CallAgent' and call_agent and call_agent_level < 2:
                    solution_plan = func['arguments']['solution']
                    full_message = (
                        message_for_call_agent + "\nFollow this plan: " + str(solution_plan)
                    )
                    call_result = self.run_multistep_agent(
                        full_message, temperature=temperature, max_new_tokens=512,
                        max_token=2048, call_agent=False, call_agent_level=call_agent_level)
                    call_result = call_result.split('[FinalAnswer]')[-1].strip() if call_result else "⚠️ No content from sub-agent."
                else:
                    call_result = self.tooluniverse.run_one_function(func)

                call_id = self.tooluniverse.call_id_gen()
                func["call_id"] = call_id
                logger.debug("Tool Call Result: %s", call_result)
                call_results.append({
                    "role": "tool",
                    "content": json.dumps({"tool_name": func["name"], "content": call_result, "call_id": call_id})
                })
        else:
            call_results.append({
                "role": "tool",
                "content": json.dumps({"content": "Invalid function call format."})
            })

        revised_messages = [{
            "role": "assistant",
            "content": message.strip() if message else "",
            "tool_calls": json.dumps(function_call_json)
        }] + call_results
        return revised_messages, existing_tools_prompt, special_tool_call

    def run_function_call_stream(self, fcall_str, return_message=False,
                                 existing_tools_prompt=None, message_for_call_agent=None,
                                 call_agent=False, call_agent_level=None, temperature=None,
                                 return_gradio_history=True):
        function_call_json, message = self.tooluniverse.extract_function_call_json(
            fcall_str, return_message=return_message, verbose=False)
        call_results = []
        special_tool_call = ''
        gradio_history = [] if return_gradio_history else None
        if function_call_json:
            for func in function_call_json if isinstance(function_call_json, list) else [function_call_json]:
                if func["name"] == 'Finish':
                    special_tool_call = 'Finish'
                    break
                elif func["name"] == 'Tool_RAG':
                    new_tools_prompt, call_result = self.tool_RAG(
                        message=message, existing_tools_prompt=existing_tools_prompt,
                        rag_num=self.step_rag_num, return_call_result=True)
                    existing_tools_prompt = (existing_tools_prompt or []) + new_tools_prompt
                elif func["name"] == 'DirectResponse':
                    call_result = func['arguments']['response']
                    special_tool_call = 'DirectResponse'
                elif func["name"] == 'RequireClarification':
                    call_result = func['arguments']['unclear_question']
                    special_tool_call = 'RequireClarification'
                elif func["name"] == 'CallAgent' and call_agent and call_agent_level < 2:
                    solution_plan = func['arguments']['solution']
                    full_message = (
                        message_for_call_agent + "\nFollow this plan: " + str(solution_plan)
                    )
                    sub_agent_task = "Sub TxAgent plan: " + str(solution_plan)
                    # Delegate to a sub-agent; its streamed history is forwarded
                    # to the caller and its return value is the sub-answer.
                    call_result = yield from self.run_gradio_chat(
                        full_message, history=[], temperature=temperature,
                        max_new_tokens=512, max_token=2048, call_agent=False,
                        call_agent_level=call_agent_level, conversation=None,
                        sub_agent_task=sub_agent_task)
                    call_result = call_result.split('[FinalAnswer]')[-1] if call_result else "⚠️ No content from sub-agent."
                else:
                    call_result = self.tooluniverse.run_one_function(func)

                call_id = self.tooluniverse.call_id_gen()
                func["call_id"] = call_id
                call_results.append({
                    "role": "tool",
                    "content": json.dumps({"tool_name": func["name"], "content": call_result, "call_id": call_id})
                })
                if return_gradio_history and func["name"] != 'Finish':
                    title = f"{'🧰' if func['name'] == 'Tool_RAG' else '⚒️'} {func['name']}"
                    gradio_history.append(ChatMessage(
                        role="assistant", content=str(call_result),
                        metadata={"title": title, "log": str(func['arguments'])}
                    ))
        else:
            call_results.append({
                "role": "tool",
                "content": json.dumps({"content": "Invalid function call format."})
            })

        revised_messages = [{
            "role": "assistant",
            "content": message.strip() if message else "",
            "tool_calls": json.dumps(function_call_json)
        }] + call_results
        return revised_messages, existing_tools_prompt, special_tool_call, gradio_history

    def get_answer_based_on_unfinished_reasoning(self, conversation, temperature, max_new_tokens, max_token):
        if conversation[-1]['role'] == 'assistant':
            conversation.append(
                {'role': 'tool', 'content': 'Errors occurred; provide final answer with current info.'})
        finish_tools_prompt = self.add_finish_tools([])
        output = self.llm_infer(
            messages=conversation, temperature=temperature, tools=finish_tools_prompt,
            output_begin_string='[FinalAnswer]', max_new_tokens=max_new_tokens, max_token=max_token)
        logger.debug("Unfinished reasoning output: %s", output)
        return output

    def run_multistep_agent(self, message: str, temperature: float, max_new_tokens: int,
                            max_token: int, max_round: int = 3, call_agent=False, call_agent_level=0):
        logger.debug("Starting multistep agent for message: %s", message[:100])
        picked_tools_prompt, call_agent_level = self.initialize_tools_prompt(
            call_agent, call_agent_level, message)
        conversation = self.initialize_conversation(message)
        outputs = []
        last_outputs = []
        next_round = True
        current_round = 0
        token_overflow = False
        enable_summary = False
        last_status = {}

        if self.enable_checker:
            checker = ReasoningTraceChecker(message, conversation)

        # Heuristic: messages that already carry structured clinical data are
        # answered without tool prompts to keep the context focused.
        clinical_keywords = ['medication', 'symptom', 'evaluation', 'diagnosis']
        has_clinical_data = any(keyword in message.lower() for keyword in clinical_keywords)

        while next_round and current_round < max_round:
            current_round += 1
            if last_outputs:
                function_call_messages, picked_tools_prompt, special_tool_call = self.run_function_call(
                    last_outputs, return_message=True, existing_tools_prompt=picked_tools_prompt,
                    message_for_call_agent=message, call_agent=call_agent,
                    call_agent_level=call_agent_level, temperature=temperature)

                if special_tool_call == 'Finish':
                    next_round = False
                    conversation.extend(function_call_messages)
                    content = function_call_messages[0]['content']
                    return content.split('[FinalAnswer]')[-1] if content else "❌ No content after Finish."

                if (self.enable_summary or token_overflow) and not call_agent:
                    enable_summary = True
                last_status = self.function_result_summary(
                    conversation, status=last_status, enable_summary=enable_summary)

                if function_call_messages:
                    conversation.extend(function_call_messages)
                    outputs.append(tool_result_format(function_call_messages))
                else:
                    next_round = False
                    return ''.join(last_outputs).replace("</s>", "")

            if self.enable_checker:
                good_status, wrong_info = checker.check_conversation()
                if not good_status:
                    logger.warning("Checker error: %s", wrong_info)
                    break

            tools = [] if has_clinical_data else picked_tools_prompt
            last_outputs = []
            last_outputs_str, token_overflow = self.llm_infer(
                messages=conversation, temperature=temperature, tools=tools,
                max_new_tokens=max_new_tokens, max_token=max_token, check_token_status=True)
            if last_outputs_str is None:
                if self.force_finish:
                    return self.get_answer_based_on_unfinished_reasoning(
                        conversation, temperature, max_new_tokens, max_token)
                return "❌ Token limit exceeded."
            last_outputs.append(last_outputs_str)

        if current_round >= max_round:
            logger.warning("Max rounds exceeded")
        if self.force_finish:
            return self.get_answer_based_on_unfinished_reasoning(
                conversation, temperature, max_new_tokens, max_token)
        return None

    def build_logits_processor(self, messages, llm):
        tokenizer = llm.get_tokenizer()
        if self.avoid_repeat and len(messages) > 2:
            assistant_messages = [
                m['content'] for m in messages[-3:] if m['role'] == 'assistant'
            ][:2]
            # Collect sentences from the recent assistant turns, dropping
            # near-duplicates (>90% similar) so the ban list stays small.
            unique_sentences = []
            for msg in assistant_messages:
                for s in msg.split('. '):
                    if not s:
                        continue
                    if all(SequenceMatcher(None, s.lower(), seen.lower()).ratio() <= 0.9
                           for seen in unique_sentences):
                        unique_sentences.append(s)
            forbidden_ids = [tokenizer.encode(s, add_special_tokens=False) for s in unique_sentences]
            return [NoRepeatSentenceProcessor(forbidden_ids, 10)]
        return None

    def llm_infer(self, messages, temperature=0.1, tools=None, output_begin_string=None,
                  max_new_tokens=512, max_token=2048, skip_special_tokens=True,
                  model=None, tokenizer=None, terminators=None, seed=None, check_token_status=False):
        if model is None:
            model = self.model
        if tokenizer is None:
            tokenizer = self.tokenizer

        logits_processor = self.build_logits_processor(messages, model)
        sampling_params = SamplingParams(
            temperature=temperature,
            max_tokens=max_new_tokens,
            seed=seed if seed is not None else self.seed,
            skip_special_tokens=skip_special_tokens,
            logits_processors=logits_processor
        )

        prompt = self.chat_template.render(messages=messages, tools=tools, add_generation_prompt=True)
        if output_begin_string:
            prompt += output_begin_string

        if check_token_status and max_token:
            num_input_tokens = len(tokenizer.encode(prompt))
            if num_input_tokens > max_token:
                torch.cuda.empty_cache()
                gc.collect()
                logger.info("Token overflow: %d > %d", num_input_tokens, max_token)
                return None, True
            logger.debug("Input tokens: %d", num_input_tokens)

        output = model.generate(prompt, sampling_params=sampling_params)
        output = output[0].outputs[0].text
        logger.debug("Inference output: %s", output[:100])
        torch.cuda.empty_cache()
        if check_token_status:
            return output, False
        return output

    def run_self_agent(self, message: str, temperature: float, max_new_tokens: int, max_token: int):
        logger.debug("Starting self agent")
        conversation = self.set_system_prompt([], self.self_prompt)
        conversation.append({"role": "user", "content": message})
        return self.llm_infer(messages=conversation, temperature=temperature,
                              max_new_tokens=max_new_tokens, max_token=max_token)

    def run_chat_agent(self, message: str, temperature: float, max_new_tokens: int, max_token: int):
        logger.debug("Starting chat agent")
        conversation = self.set_system_prompt([], self.chat_prompt)
        conversation.append({"role": "user", "content": message})
        return self.llm_infer(messages=conversation, temperature=temperature,
                              max_new_tokens=max_new_tokens, max_token=max_token)

    def run_format_agent(self, message: str, answer: str, temperature: float, max_new_tokens: int, max_token: int):
        logger.debug("Starting format agent")
        if '[FinalAnswer]' in answer:
            possible_final_answer = answer.split("[FinalAnswer]")[-1]
        elif "\n\n" in answer:
            possible_final_answer = answer.split("\n\n")[-1]
        else:
            possible_final_answer = answer.strip()

        # Accept answers that already start with a choice letter (e.g. "B" or "B: ...").
        if possible_final_answer and possible_final_answer[0] in ['A', 'B', 'C', 'D', 'E']:
            return possible_final_answer[0]

        conversation = self.set_system_prompt(
            [], "Transform the answer to a single letter: 'A', 'B', 'C', 'D', or 'E'.")
        conversation.append({"role": "user", "content": f"Original: {message}\nAnswer: {answer}\nFinal answer (letter):"})
        return self.llm_infer(messages=conversation, temperature=temperature,
                              max_new_tokens=max_new_tokens, max_token=max_token)

    def run_summary_agent(self, thought_calls: str, function_response: str,
                          temperature: float, max_new_tokens: int, max_token: int):
        logger.debug("Starting summary agent")
        prompt = f"""Thought and function calls: {thought_calls}
Function responses: {function_response}
Summarize the function responses in one sentence with all necessary information."""
        conversation = [{"role": "user", "content": prompt}]
        output = self.llm_infer(messages=conversation, temperature=temperature,
                                max_new_tokens=max_new_tokens, max_token=max_token)
        # Trim anything after a bracketed control tag such as '[FinalAnswer]'.
        if '[' in output:
            output = output.split('[')[0]
        return output

    def function_result_summary(self, input_list, status, enable_summary):
        """Compress older tool responses in-place into one-sentence summaries,
        tracking progress in `status` across calls."""
        if 'tool_call_step' not in status:
            status['tool_call_step'] = 0
        if 'step' not in status:
            status['step'] = 0
        status['step'] += 1

        # Scan backwards for the latest assistant tool call; Tool_RAG calls
        # count separately so they are excluded from the summarization budget.
        for idx in range(len(input_list)):
            pos_id = len(input_list) - idx - 1
            if input_list[pos_id]['role'] == 'assistant' and 'tool_calls' in input_list[pos_id]:
                if 'Tool_RAG' in str(input_list[pos_id]['tool_calls']):
                    status['tool_call_step'] += 1
                break

        if not enable_summary:
            return status

        if 'summarized_index' not in status:
            status['summarized_index'] = 0
        if 'summarized_step' not in status:
            status['summarized_step'] = 0
        if 'previous_length' not in status:
            status['previous_length'] = 0
        if 'history' not in status:
            status['history'] = []

        status['history'].append(
            self.summary_mode == 'step'
            and status['summarized_step'] < status['step'] - status['tool_call_step'] - self.summary_skip_last_k)

        idx = status['summarized_index']
        function_response = ''
        this_thought_calls = None
        last_call_idx = None
        while idx < len(input_list):
            if (self.summary_mode == 'step' and status['summarized_step'] < status['step'] - status['tool_call_step'] - self.summary_skip_last_k) or \
                    (self.summary_mode == 'length' and status['previous_length'] > self.summary_context_length):
                if input_list[idx]['role'] == 'assistant':
                    if 'Tool_RAG' in str(input_list[idx].get('tool_calls', '')):
                        this_thought_calls = None
                    else:
                        if function_response:
                            status['summarized_step'] += 1
                            result_summary = self.run_summary_agent(
                                thought_calls=this_thought_calls, function_response=function_response,
                                temperature=0.1, max_new_tokens=512, max_token=2048)
                            input_list.insert(
                                last_call_idx + 1, {'role': 'tool', 'content': result_summary})
                            status['summarized_index'] = last_call_idx + 2
                            idx += 1
                        last_call_idx = idx
                        this_thought_calls = input_list[idx]['content'] + input_list[idx].get('tool_calls', '')
                        function_response = ''
                elif input_list[idx]['role'] == 'tool' and this_thought_calls:
                    function_response += input_list[idx]['content']
                    del input_list[idx]
                    idx -= 1
            else:
                break
            idx += 1

        if function_response:
            status['summarized_step'] += 1
            result_summary = self.run_summary_agent(
                thought_calls=this_thought_calls, function_response=function_response,
                temperature=0.1, max_new_tokens=512, max_token=2048)
            tool_calls = json.loads(input_list[last_call_idx]['tool_calls'])
            for tool_call in tool_calls:
                del tool_call['call_id']
            input_list[last_call_idx]['tool_calls'] = json.dumps(tool_calls)
            input_list.insert(
                last_call_idx + 1, {'role': 'tool', 'content': result_summary})
            status['summarized_index'] = last_call_idx + 2

        return status

    def update_parameters(self, **kwargs):
        updated_attributes = {}
        for key, value in kwargs.items():
            if hasattr(self, key):
                setattr(self, key, value)
                updated_attributes[key] = value
        logger.debug("Updated parameters: %s", updated_attributes)
        return updated_attributes

    def run_gradio_chat(self, message: str, history: list, temperature: float,
                        max_new_tokens: int, max_token: int, call_agent: bool,
                        conversation: gr.State, max_round: int = 3, seed: int = None,
                        call_agent_level: int = 0, sub_agent_task: str = None,
                        uploaded_files: list = None):
        """Streaming chat entry point for Gradio: yields updated ChatMessage
        histories and returns the final answer string."""
        logger.debug("Chat started, message: %s", message[:100])
        if not message or len(message.strip()) < 5:
            yield "Please provide a valid message or upload files to analyze."
            return

        # Ignore tool-call messages echoed back from the UI.
        if message.startswith("[\U0001f9f0 Tool_RAG") or message.startswith("⚒️"):
            return

        # Heuristic: messages that already carry structured clinical data are
        # answered directly, without sub-agent delegation or tool prompts.
        clinical_keywords = ['medication', 'symptom', 'evaluation', 'diagnosis']
        has_clinical_data = any(keyword in message.lower() for keyword in clinical_keywords)
        call_agent = call_agent and not has_clinical_data

        picked_tools_prompt, call_agent_level = self.initialize_tools_prompt(
            call_agent, call_agent_level, message)
        conversation = self.initialize_conversation(
            message, conversation, history)
        history = []

        next_round = True
        current_round = 0
        enable_summary = False
        last_status = {}
        token_overflow = False
        last_outputs = []

        if self.enable_checker:
            checker = ReasoningTraceChecker(message, conversation, init_index=len(conversation))

        try:
            while next_round and current_round < max_round:
                current_round += 1
                if last_outputs:
                    function_call_messages, picked_tools_prompt, special_tool_call, current_gradio_history = yield from self.run_function_call_stream(
                        last_outputs, return_message=True, existing_tools_prompt=picked_tools_prompt,
                        message_for_call_agent=message, call_agent=call_agent,
                        call_agent_level=call_agent_level, temperature=temperature)
                    history.extend(current_gradio_history)

                    if special_tool_call == 'Finish':
                        yield history
                        next_round = False
                        conversation.extend(function_call_messages)
                        return function_call_messages[0]['content']

                    if special_tool_call in ['RequireClarification', 'DirectResponse']:
                        last_msg = history[-1] if history else ChatMessage(role="assistant", content="Response needed.")
                        history.append(ChatMessage(role="assistant", content=last_msg.content))
                        yield history
                        next_round = False
                        return last_msg.content

                    if (self.enable_summary or token_overflow) and not call_agent:
                        enable_summary = True
                    last_status = self.function_result_summary(
                        conversation, status=last_status, enable_summary=enable_summary)

                    if function_call_messages:
                        conversation.extend(function_call_messages)
                        yield history
                    else:
                        next_round = False
                        return ''.join(last_outputs).replace("</s>", "")

                if self.enable_checker:
                    good_status, wrong_info = checker.check_conversation()
                    if not good_status:
                        logger.warning("Checker error: %s", wrong_info)
                        break

                tools = [] if has_clinical_data else picked_tools_prompt
                # Reset just before inference so the next round only parses the
                # newest model output for tool calls.
                last_outputs = []
                last_outputs_str, token_overflow = self.llm_infer(
                    messages=conversation, temperature=temperature, tools=tools,
                    max_new_tokens=max_new_tokens, max_token=max_token, seed=seed, check_token_status=True)

                if last_outputs_str is None:
                    if self.force_finish:
                        last_outputs_str = self.get_answer_based_on_unfinished_reasoning(
                            conversation, temperature, max_new_tokens, max_token)
                        history.append(ChatMessage(role="assistant", content=last_outputs_str.strip()))
                        yield history
                        return last_outputs_str
                    error_msg = "Token limit exceeded."
                    history.append(ChatMessage(role="assistant", content=error_msg))
                    yield history
                    return error_msg

                last_thought = last_outputs_str.split("[TOOL_CALLS]")[0]
                for msg in history:
                    if msg.metadata:
                        msg.metadata['status'] = 'done'

                if '[FinalAnswer]' in last_thought:
                    parts = last_thought.split('[FinalAnswer]', 1)
                    final_thought, final_answer = parts if len(parts) == 2 else (last_thought, "")
                    history.append(ChatMessage(role="assistant", content=final_thought.strip()))
                    yield history
                    history.append(ChatMessage(role="assistant", content="**🧠 Final Analysis:**\n" + final_answer.strip()))
                    yield history
                    # A final answer ends the episode; return it so callers
                    # (e.g. the CallAgent path) receive the full text, and so the
                    # force-finish block below does not produce a second answer.
                    return last_outputs_str
                else:
                    history.append(ChatMessage(role="assistant", content=last_thought))
                    yield history

                last_outputs.append(last_outputs_str)

            if next_round and self.force_finish:
                last_outputs_str = self.get_answer_based_on_unfinished_reasoning(
                    conversation, temperature, max_new_tokens, max_token)
                parts = last_outputs_str.split('[FinalAnswer]', 1)
                final_thought, final_answer = parts if len(parts) == 2 else (last_outputs_str, "")
                history.append(ChatMessage(role="assistant", content=final_thought.strip()))
                yield history
                history.append(ChatMessage(role="assistant", content="**🧠 Final Analysis:**\n" + final_answer.strip()))
                yield history

        except Exception as e:
            logger.error("Exception in run_gradio_chat: %s", e, exc_info=True)
            error_msg = f"Error: {e}"
            history.append(ChatMessage(role="assistant", content=error_msg))
            yield history
            if self.force_finish:
                last_outputs_str = self.get_answer_based_on_unfinished_reasoning(
                    conversation, temperature, max_new_tokens, max_token)
                parts = last_outputs_str.split('[FinalAnswer]', 1)
                final_thought, final_answer = parts if len(parts) == 2 else (last_outputs_str, "")
                history.append(ChatMessage(role="assistant", content=final_thought.strip()))
                yield history
                history.append(ChatMessage(role="assistant", content="**🧠 Final Analysis:**\n" + final_answer.strip()))
                yield history
            return error_msg
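

# ---------------------------------------------------------------------------
# Minimal usage sketch. Assumptions not taken from this module: the checkpoint
# names below are placeholders for whatever TxAgent/ToolRAG weights you serve,
# tool_files_dict is left at its default (point it at your tool JSONs if your
# ToolUniverse install needs explicit paths), and a CUDA GPU is assumed because
# load_models() builds a float16 vLLM engine.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    agent = TxAgent(
        model_name="your-org/txagent-checkpoint",      # placeholder checkpoint name
        rag_model_name="your-org/toolrag-checkpoint",  # placeholder checkpoint name
        enable_rag=True,
        force_finish=True)
    agent.init_model()  # loads the LLM, the tool universe, and tool embeddings
    answer = agent.run_multistep_agent(
        "What are the contraindications of warfarin?",
        temperature=0.3, max_new_tokens=512, max_token=2048)
    print(answer)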