Update src/txagent/txagent.py
Browse files- src/txagent/txagent.py +53 -23
src/txagent/txagent.py
CHANGED
|
@@ -13,6 +13,7 @@ from gradio import ChatMessage
|
|
| 13 |
from .toolrag import ToolRAGModel
|
| 14 |
import torch
|
| 15 |
import logging
|
|
|
|
| 16 |
|
| 17 |
logger = logging.getLogger(__name__)
|
| 18 |
logging.basicConfig(level=logging.INFO)
|
|
@@ -26,15 +27,15 @@ class TxAgent:
|
|
| 26 |
enable_finish=True,
|
| 27 |
enable_rag=True,
|
| 28 |
enable_summary=False,
|
| 29 |
-
init_rag_num=2,
|
| 30 |
-
step_rag_num=4,
|
| 31 |
summary_mode='step',
|
| 32 |
summary_skip_last_k=0,
|
| 33 |
summary_context_length=None,
|
| 34 |
force_finish=True,
|
| 35 |
avoid_repeat=True,
|
| 36 |
seed=None,
|
| 37 |
-
enable_checker=False,
|
| 38 |
enable_chat=False,
|
| 39 |
additional_default_tools=None):
|
| 40 |
self.model_name = model_name
|
|
@@ -78,7 +79,7 @@ class TxAgent:
|
|
| 78 |
if model_name:
|
| 79 |
self.model_name = model_name
|
| 80 |
|
| 81 |
-
self.model = LLM(model=self.model_name, dtype="float16")
|
| 82 |
self.chat_template = Template(self.model.get_tokenizer().chat_template)
|
| 83 |
self.tokenizer = self.model.get_tokenizer()
|
| 84 |
logger.info("Model %s loaded successfully", self.model_name)
|
|
@@ -101,16 +102,17 @@ class TxAgent:
|
|
| 101 |
|
| 102 |
def initialize_tools_prompt(self, call_agent, call_agent_level, message):
|
| 103 |
picked_tools_prompt = []
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
|
|
|
| 114 |
return picked_tools_prompt, call_agent_level
|
| 115 |
|
| 116 |
def initialize_conversation(self, message, conversation=None, history=None):
|
|
@@ -129,7 +131,7 @@ class TxAgent:
|
|
| 129 |
|
| 130 |
def tool_RAG(self, message=None, picked_tool_names=None,
|
| 131 |
existing_tools_prompt=None, rag_num=4, return_call_result=False):
|
| 132 |
-
extra_factor = 10
|
| 133 |
if picked_tool_names is None:
|
| 134 |
picked_tool_names = self.rag_infer(message, top_k=rag_num * extra_factor)
|
| 135 |
|
|
@@ -148,10 +150,10 @@ class TxAgent:
|
|
| 148 |
if self.enable_finish:
|
| 149 |
tools.append(self.tooluniverse.get_one_tool_by_one_name('Finish', return_prompt=True))
|
| 150 |
logger.debug("Finish tool added")
|
| 151 |
-
if call_agent:
|
| 152 |
tools.append(self.tooluniverse.get_one_tool_by_one_name('CallAgent', return_prompt=True))
|
| 153 |
logger.debug("CallAgent tool added")
|
| 154 |
-
elif self.enable_rag:
|
| 155 |
tools.append(self.tooluniverse.get_one_tool_by_one_name('Tool_RAG', return_prompt=True))
|
| 156 |
logger.debug("Tool_RAG tool added")
|
| 157 |
if self.additional_default_tools:
|
|
@@ -301,7 +303,7 @@ class TxAgent:
|
|
| 301 |
return output
|
| 302 |
|
| 303 |
def run_multistep_agent(self, message: str, temperature: float, max_new_tokens: int,
|
| 304 |
-
max_token: int, max_round: int =
|
| 305 |
logger.debug("Starting multistep agent for message: %s", message[:100])
|
| 306 |
picked_tools_prompt, call_agent_level = self.initialize_tools_prompt(
|
| 307 |
call_agent, call_agent_level, message)
|
|
@@ -317,6 +319,10 @@ class TxAgent:
|
|
| 317 |
if self.enable_checker:
|
| 318 |
checker = ReasoningTraceChecker(message, conversation)
|
| 319 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 320 |
while next_round and current_round < max_round:
|
| 321 |
current_round += 1
|
| 322 |
if last_outputs:
|
|
@@ -349,9 +355,11 @@ class TxAgent:
|
|
| 349 |
logger.warning("Checker error: %s", wrong_info)
|
| 350 |
break
|
| 351 |
|
|
|
|
|
|
|
| 352 |
last_outputs = []
|
| 353 |
last_outputs_str, token_overflow = self.llm_infer(
|
| 354 |
-
messages=conversation, temperature=temperature, tools=
|
| 355 |
max_new_tokens=max_new_tokens, max_token=max_token, check_token_status=True)
|
| 356 |
if last_outputs_str is None:
|
| 357 |
if self.force_finish:
|
|
@@ -374,7 +382,22 @@ class TxAgent:
|
|
| 374 |
m['content'] for m in messages[-3:] if m['role'] == 'assistant'
|
| 375 |
][:2]
|
| 376 |
forbidden_ids = [tokenizer.encode(msg, add_special_tokens=False) for msg in assistant_messages]
|
| 377 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 378 |
return None
|
| 379 |
|
| 380 |
def llm_infer(self, messages, temperature=0.1, tools=None, output_begin_string=None,
|
|
@@ -407,7 +430,7 @@ class TxAgent:
|
|
| 407 |
output = model.generate(prompt, sampling_params=sampling_params)
|
| 408 |
output = output[0].outputs[0].text
|
| 409 |
logger.debug("Inference output: %s", output[:100])
|
| 410 |
-
torch.cuda.empty_cache()
|
| 411 |
if check_token_status:
|
| 412 |
return output, False
|
| 413 |
return output
|
|
@@ -544,7 +567,7 @@ Summarize the function responses in one sentence with all necessary information.
|
|
| 544 |
|
| 545 |
def run_gradio_chat(self, message: str, history: list, temperature: float,
|
| 546 |
max_new_tokens: int, max_token: int, call_agent: bool,
|
| 547 |
-
conversation: gr.State, max_round: int =
|
| 548 |
call_agent_level: int = 0, sub_agent_task: str = None,
|
| 549 |
uploaded_files: list = None):
|
| 550 |
logger.debug("Chat started, message: %s", message[:100])
|
|
@@ -555,6 +578,11 @@ Summarize the function responses in one sentence with all necessary information.
|
|
| 555 |
if message.startswith("[\U0001f9f0 Tool_RAG") or message.startswith("⚒️"):
|
| 556 |
return
|
| 557 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 558 |
picked_tools_prompt, call_agent_level = self.initialize_tools_prompt(
|
| 559 |
call_agent, call_agent_level, message)
|
| 560 |
conversation = self.initialize_conversation(
|
|
@@ -612,8 +640,10 @@ Summarize the function responses in one sentence with all necessary information.
|
|
| 612 |
logger.warning("Checker error: %s", wrong_info)
|
| 613 |
break
|
| 614 |
|
|
|
|
|
|
|
| 615 |
last_outputs_str, token_overflow = self.llm_infer(
|
| 616 |
-
messages=conversation, temperature=temperature, tools=
|
| 617 |
max_new_tokens=max_new_tokens, max_token=max_token, seed=seed, check_token_status=True)
|
| 618 |
|
| 619 |
if last_outputs_str is None:
|
|
|
|
| 13 |
from .toolrag import ToolRAGModel
|
| 14 |
import torch
|
| 15 |
import logging
|
| 16 |
+
from difflib import SequenceMatcher
|
| 17 |
|
| 18 |
logger = logging.getLogger(__name__)
|
| 19 |
logging.basicConfig(level=logging.INFO)
|
|
|
|
| 27 |
enable_finish=True,
|
| 28 |
enable_rag=True,
|
| 29 |
enable_summary=False,
|
| 30 |
+
init_rag_num=2,
|
| 31 |
+
step_rag_num=4,
|
| 32 |
summary_mode='step',
|
| 33 |
summary_skip_last_k=0,
|
| 34 |
summary_context_length=None,
|
| 35 |
force_finish=True,
|
| 36 |
avoid_repeat=True,
|
| 37 |
seed=None,
|
| 38 |
+
enable_checker=False,
|
| 39 |
enable_chat=False,
|
| 40 |
additional_default_tools=None):
|
| 41 |
self.model_name = model_name
|
|
|
|
| 79 |
if model_name:
|
| 80 |
self.model_name = model_name
|
| 81 |
|
| 82 |
+
self.model = LLM(model=self.model_name, dtype="float16")
|
| 83 |
self.chat_template = Template(self.model.get_tokenizer().chat_template)
|
| 84 |
self.tokenizer = self.model.get_tokenizer()
|
| 85 |
logger.info("Model %s loaded successfully", self.model_name)
|
|
|
|
| 102 |
|
| 103 |
def initialize_tools_prompt(self, call_agent, call_agent_level, message):
|
| 104 |
picked_tools_prompt = []
|
| 105 |
+
# Only add Finish tool unless prompt explicitly requires Tool_RAG or CallAgent
|
| 106 |
+
if "use external tools" not in message.lower():
|
| 107 |
+
picked_tools_prompt = self.add_special_tools(picked_tools_prompt, call_agent=False)
|
| 108 |
+
else:
|
| 109 |
+
picked_tools_prompt = self.add_special_tools(picked_tools_prompt, call_agent=call_agent)
|
| 110 |
+
if call_agent:
|
| 111 |
+
call_agent_level += 1
|
| 112 |
+
if call_agent_level >= 2:
|
| 113 |
+
call_agent = False
|
| 114 |
+
if self.enable_rag:
|
| 115 |
+
picked_tools_prompt += self.tool_RAG(message=message, rag_num=self.init_rag_num)
|
| 116 |
return picked_tools_prompt, call_agent_level
|
| 117 |
|
| 118 |
def initialize_conversation(self, message, conversation=None, history=None):
|
|
|
|
| 131 |
|
| 132 |
def tool_RAG(self, message=None, picked_tool_names=None,
|
| 133 |
existing_tools_prompt=None, rag_num=4, return_call_result=False):
|
| 134 |
+
extra_factor = 10
|
| 135 |
if picked_tool_names is None:
|
| 136 |
picked_tool_names = self.rag_infer(message, top_k=rag_num * extra_factor)
|
| 137 |
|
|
|
|
| 150 |
if self.enable_finish:
|
| 151 |
tools.append(self.tooluniverse.get_one_tool_by_one_name('Finish', return_prompt=True))
|
| 152 |
logger.debug("Finish tool added")
|
| 153 |
+
if call_agent and "use external tools" in self.prompt_multi_step.lower():
|
| 154 |
tools.append(self.tooluniverse.get_one_tool_by_one_name('CallAgent', return_prompt=True))
|
| 155 |
logger.debug("CallAgent tool added")
|
| 156 |
+
elif self.enable_rag and "use external tools" in self.prompt_multi_step.lower():
|
| 157 |
tools.append(self.tooluniverse.get_one_tool_by_one_name('Tool_RAG', return_prompt=True))
|
| 158 |
logger.debug("Tool_RAG tool added")
|
| 159 |
if self.additional_default_tools:
|
|
|
|
| 303 |
return output
|
| 304 |
|
| 305 |
def run_multistep_agent(self, message: str, temperature: float, max_new_tokens: int,
|
| 306 |
+
max_token: int, max_round: int = 3, call_agent=False, call_agent_level=0):
|
| 307 |
logger.debug("Starting multistep agent for message: %s", message[:100])
|
| 308 |
picked_tools_prompt, call_agent_level = self.initialize_tools_prompt(
|
| 309 |
call_agent, call_agent_level, message)
|
|
|
|
| 319 |
if self.enable_checker:
|
| 320 |
checker = ReasoningTraceChecker(message, conversation)
|
| 321 |
|
| 322 |
+
# Check if message contains clinical findings
|
| 323 |
+
clinical_keywords = ['medication', 'symptom', 'evaluation', 'diagnosis']
|
| 324 |
+
has_clinical_data = any(keyword in message.lower() for keyword in clinical_keywords)
|
| 325 |
+
|
| 326 |
while next_round and current_round < max_round:
|
| 327 |
current_round += 1
|
| 328 |
if last_outputs:
|
|
|
|
| 355 |
logger.warning("Checker error: %s", wrong_info)
|
| 356 |
break
|
| 357 |
|
| 358 |
+
# Skip tool calls if clinical data is present
|
| 359 |
+
tools = [] if has_clinical_data else picked_tools_prompt
|
| 360 |
last_outputs = []
|
| 361 |
last_outputs_str, token_overflow = self.llm_infer(
|
| 362 |
+
messages=conversation, temperature=temperature, tools=tools,
|
| 363 |
max_new_tokens=max_new_tokens, max_token=max_token, check_token_status=True)
|
| 364 |
if last_outputs_str is None:
|
| 365 |
if self.force_finish:
|
|
|
|
| 382 |
m['content'] for m in messages[-3:] if m['role'] == 'assistant'
|
| 383 |
][:2]
|
| 384 |
forbidden_ids = [tokenizer.encode(msg, add_special_tokens=False) for msg in assistant_messages]
|
| 385 |
+
# Enhance deduplication with similarity check
|
| 386 |
+
unique_sentences = []
|
| 387 |
+
for msg in assistant_messages:
|
| 388 |
+
sentences = msg.split('. ')
|
| 389 |
+
for s in sentences:
|
| 390 |
+
if not s:
|
| 391 |
+
continue
|
| 392 |
+
is_unique = True
|
| 393 |
+
for seen_s in unique_sentences:
|
| 394 |
+
if SequenceMatcher(None, s.lower(), seen_s.lower()).ratio() > 0.9:
|
| 395 |
+
is_unique = False
|
| 396 |
+
break
|
| 397 |
+
if is_unique:
|
| 398 |
+
unique_sentences.append(s)
|
| 399 |
+
forbidden_ids = [tokenizer.encode(s, add_special_tokens=False) for s in unique_sentences]
|
| 400 |
+
return [NoRepeatSentenceProcessor(forbidden_ids, 10)] # Increased penalty
|
| 401 |
return None
|
| 402 |
|
| 403 |
def llm_infer(self, messages, temperature=0.1, tools=None, output_begin_string=None,
|
|
|
|
| 430 |
output = model.generate(prompt, sampling_params=sampling_params)
|
| 431 |
output = output[0].outputs[0].text
|
| 432 |
logger.debug("Inference output: %s", output[:100])
|
| 433 |
+
torch.cuda.empty_cache()
|
| 434 |
if check_token_status:
|
| 435 |
return output, False
|
| 436 |
return output
|
|
|
|
| 567 |
|
| 568 |
def run_gradio_chat(self, message: str, history: list, temperature: float,
|
| 569 |
max_new_tokens: int, max_token: int, call_agent: bool,
|
| 570 |
+
conversation: gr.State, max_round: int = 3, seed: int = None,
|
| 571 |
call_agent_level: int = 0, sub_agent_task: str = None,
|
| 572 |
uploaded_files: list = None):
|
| 573 |
logger.debug("Chat started, message: %s", message[:100])
|
|
|
|
| 578 |
if message.startswith("[\U0001f9f0 Tool_RAG") or message.startswith("⚒️"):
|
| 579 |
return
|
| 580 |
|
| 581 |
+
# Check if message contains clinical findings
|
| 582 |
+
clinical_keywords = ['medication', 'symptom', 'evaluation', 'diagnosis']
|
| 583 |
+
has_clinical_data = any(keyword in message.lower() for keyword in clinical_keywords)
|
| 584 |
+
call_agent = call_agent and not has_clinical_data # Disable CallAgent for clinical data
|
| 585 |
+
|
| 586 |
picked_tools_prompt, call_agent_level = self.initialize_tools_prompt(
|
| 587 |
call_agent, call_agent_level, message)
|
| 588 |
conversation = self.initialize_conversation(
|
|
|
|
| 640 |
logger.warning("Checker error: %s", wrong_info)
|
| 641 |
break
|
| 642 |
|
| 643 |
+
# Skip tool calls if clinical data is present
|
| 644 |
+
tools = [] if has_clinical_data else picked_tools_prompt
|
| 645 |
last_outputs_str, token_overflow = self.llm_infer(
|
| 646 |
+
messages=conversation, temperature=temperature, tools=tools,
|
| 647 |
max_new_tokens=max_new_tokens, max_token=max_token, seed=seed, check_token_status=True)
|
| 648 |
|
| 649 |
if last_outputs_str is None:
|