# src/txagent/txagent.py
import gc
import json
import logging
from difflib import SequenceMatcher

import gradio as gr
import torch
from gradio import ChatMessage
from jinja2 import Template
from tooluniverse import ToolUniverse
from vllm import LLM, SamplingParams

from .toolrag import ToolRAGModel
from .utils import NoRepeatSentenceProcessor, ReasoningTraceChecker, tool_result_format

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class TxAgent:
def __init__(self, model_name,
rag_model_name,
tool_files_dict=None,
enable_finish=True,
enable_rag=True,
enable_summary=False,
init_rag_num=2,
step_rag_num=4,
summary_mode='step',
summary_skip_last_k=0,
summary_context_length=None,
force_finish=True,
avoid_repeat=True,
seed=None,
enable_checker=False,
enable_chat=False,
additional_default_tools=None):
self.model_name = model_name
self.tokenizer = None
self.terminators = None
self.rag_model_name = rag_model_name
self.tool_files_dict = tool_files_dict
self.model = None
self.rag_model = ToolRAGModel(rag_model_name)
self.tooluniverse = None
self.prompt_multi_step = "You are a medical assistant solving clinical oversight issues step-by-step using provided tools."
self.self_prompt = "Follow instructions precisely."
self.chat_prompt = "You are a helpful assistant for clinical queries."
self.enable_finish = enable_finish
self.enable_rag = enable_rag
self.enable_summary = enable_summary
self.summary_mode = summary_mode
self.summary_skip_last_k = summary_skip_last_k
self.summary_context_length = summary_context_length
self.init_rag_num = init_rag_num
self.step_rag_num = step_rag_num
self.force_finish = force_finish
self.avoid_repeat = avoid_repeat
self.seed = seed
self.enable_checker = enable_checker
self.additional_default_tools = additional_default_tools
logger.debug("TxAgent initialized with parameters: %s", self.__dict__)
def init_model(self):
self.load_models()
self.load_tooluniverse()
self.load_tool_desc_embedding()
def print_self_values(self):
for attr, value in self.__dict__.items():
logger.debug("%s: %s", attr, value)
def load_models(self, model_name=None):
        if model_name is not None and model_name == self.model_name and self.model is not None:
return f"The model {model_name} is already loaded."
if model_name:
self.model_name = model_name
        self.model = LLM(model=self.model_name, dtype="float16")
        self.tokenizer = self.model.get_tokenizer()
        self.chat_template = Template(self.tokenizer.chat_template)
logger.info("Model %s loaded successfully", self.model_name)
return f"Model {self.model_name} loaded successfully."
def load_tooluniverse(self):
self.tooluniverse = ToolUniverse(tool_files=self.tool_files_dict)
self.tooluniverse.load_tools()
special_tools = self.tooluniverse.prepare_tool_prompts(
self.tooluniverse.tool_category_dicts["special_tools"])
self.special_tools_name = [tool['name'] for tool in special_tools]
logger.debug("ToolUniverse loaded with %d special tools", len(self.special_tools_name))
def load_tool_desc_embedding(self):
self.rag_model.load_tool_desc_embedding(self.tooluniverse)
logger.debug("Tool description embeddings loaded")
def rag_infer(self, query, top_k=5):
return self.rag_model.rag_infer(query, top_k)
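    # Illustrative usage (hypothetical tool names): rag_infer embeds the query
    # with the ToolRAG model and returns the top_k most similar tool names, e.g.
    #   agent.rag_infer("drug interactions with warfarin", top_k=5)
    #   -> ['get_drug_interactions', ...]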
def initialize_tools_prompt(self, call_agent, call_agent_level, message):
picked_tools_prompt = []
        # Add only the Finish tool unless the message explicitly asks to use external tools.
        if "use external tools" not in message.lower():
picked_tools_prompt = self.add_special_tools(picked_tools_prompt, call_agent=False)
else:
picked_tools_prompt = self.add_special_tools(picked_tools_prompt, call_agent=call_agent)
if call_agent:
call_agent_level += 1
if call_agent_level >= 2:
call_agent = False
if self.enable_rag:
picked_tools_prompt += self.tool_RAG(message=message, rag_num=self.init_rag_num)
return picked_tools_prompt, call_agent_level
def initialize_conversation(self, message, conversation=None, history=None):
if conversation is None:
conversation = []
conversation = self.set_system_prompt(conversation, self.prompt_multi_step)
if history:
conversation.extend(
{"role": h['role'], "content": h['content']}
for h in history if h['role'] in ['user', 'assistant']
)
conversation.append({"role": "user", "content": message})
logger.debug("Conversation initialized with %d messages", len(conversation))
return conversation
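    # Illustrative shape of the conversation this builds (contents abridged):
    #   [{"role": "system", "content": <prompt_multi_step>},
    #    {"role": "user", "content": "..."},       # prior history, if any
    #    {"role": "assistant", "content": "..."},
    #    {"role": "user", "content": <message>}]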
def tool_RAG(self, message=None, picked_tool_names=None,
existing_tools_prompt=None, rag_num=4, return_call_result=False):
        extra_factor = 10  # over-retrieve 10x candidates, then filter specials and truncate to rag_num
if picked_tool_names is None:
picked_tool_names = self.rag_infer(message, top_k=rag_num * extra_factor)
picked_tool_names = [
tool for tool in picked_tool_names
if tool not in self.special_tools_name
][:rag_num]
picked_tools = self.tooluniverse.get_tool_by_name(picked_tool_names)
picked_tools_prompt = self.tooluniverse.prepare_tool_prompts(picked_tools)
logger.debug("RAG selected %d tools: %s", len(picked_tool_names), picked_tool_names)
if return_call_result:
return picked_tools_prompt, picked_tool_names
return picked_tools_prompt
def add_special_tools(self, tools, call_agent=False):
if self.enable_finish:
tools.append(self.tooluniverse.get_one_tool_by_one_name('Finish', return_prompt=True))
logger.debug("Finish tool added")
if call_agent and "use external tools" in self.prompt_multi_step.lower():
tools.append(self.tooluniverse.get_one_tool_by_one_name('CallAgent', return_prompt=True))
logger.debug("CallAgent tool added")
elif self.enable_rag and "use external tools" in self.prompt_multi_step.lower():
tools.append(self.tooluniverse.get_one_tool_by_one_name('Tool_RAG', return_prompt=True))
logger.debug("Tool_RAG tool added")
if self.additional_default_tools:
for tool_name in self.additional_default_tools:
tool_prompt = self.tooluniverse.get_one_tool_by_one_name(tool_name, return_prompt=True)
if tool_prompt:
tools.append(tool_prompt)
logger.debug("%s tool added", tool_name)
return tools
def add_finish_tools(self, tools):
tools.append(self.tooluniverse.get_one_tool_by_one_name('Finish', return_prompt=True))
logger.debug("Finish tool added")
return tools
def set_system_prompt(self, conversation, sys_prompt):
if not conversation:
conversation.append({"role": "system", "content": sys_prompt})
else:
conversation[0] = {"role": "system", "content": sys_prompt}
return conversation
def run_function_call(self, fcall_str, return_message=False,
existing_tools_prompt=None, message_for_call_agent=None,
call_agent=False, call_agent_level=None, temperature=None):
function_call_json, message = self.tooluniverse.extract_function_call_json(
fcall_str, return_message=return_message, verbose=False)
call_results = []
special_tool_call = ''
if function_call_json:
for func in function_call_json if isinstance(function_call_json, list) else [function_call_json]:
logger.debug("Tool Call: %s", func)
if func["name"] == 'Finish':
special_tool_call = 'Finish'
break
elif func["name"] == 'Tool_RAG':
new_tools_prompt, call_result = self.tool_RAG(
message=message, existing_tools_prompt=existing_tools_prompt,
rag_num=self.step_rag_num, return_call_result=True)
                    existing_tools_prompt = (existing_tools_prompt or []) + new_tools_prompt
elif func["name"] == 'CallAgent' and call_agent and call_agent_level < 2:
solution_plan = func['arguments']['solution']
full_message = (
message_for_call_agent + "\nFollow this plan: " + str(solution_plan)
)
call_result = self.run_multistep_agent(
full_message, temperature=temperature, max_new_tokens=512,
max_token=2048, call_agent=False, call_agent_level=call_agent_level)
call_result = call_result.split('[FinalAnswer]')[-1].strip() if call_result else "⚠️ No content from sub-agent."
else:
call_result = self.tooluniverse.run_one_function(func)
call_id = self.tooluniverse.call_id_gen()
func["call_id"] = call_id
logger.debug("Tool Call Result: %s", call_result)
call_results.append({
"role": "tool",
"content": json.dumps({"tool_name": func["name"], "content": call_result, "call_id": call_id})
})
else:
call_results.append({
"role": "tool",
"content": json.dumps({"content": "Invalid function call format."})
})
revised_messages = [{
"role": "assistant",
"content": message.strip() if message else "",
"tool_calls": json.dumps(function_call_json)
}] + call_results
return revised_messages, existing_tools_prompt, special_tool_call
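    # Illustrative return value (abridged): one assistant message carrying the
    # raw tool_calls JSON, followed by one tool message per executed call:
    #   [{"role": "assistant", "content": "<thought>", "tool_calls": "[{...}]"},
    #    {"role": "tool", "content": '{"tool_name": "...", "content": "...", "call_id": "..."}'}]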
def run_function_call_stream(self, fcall_str, return_message=False,
existing_tools_prompt=None, message_for_call_agent=None,
call_agent=False, call_agent_level=None, temperature=None,
return_gradio_history=True):
function_call_json, message = self.tooluniverse.extract_function_call_json(
fcall_str, return_message=return_message, verbose=False)
call_results = []
special_tool_call = ''
gradio_history = [] if return_gradio_history else None
if function_call_json:
for func in function_call_json if isinstance(function_call_json, list) else [function_call_json]:
if func["name"] == 'Finish':
special_tool_call = 'Finish'
break
elif func["name"] == 'Tool_RAG':
new_tools_prompt, call_result = self.tool_RAG(
message=message, existing_tools_prompt=existing_tools_prompt,
rag_num=self.step_rag_num, return_call_result=True)
                    existing_tools_prompt = (existing_tools_prompt or []) + new_tools_prompt
elif func["name"] == 'DirectResponse':
call_result = func['arguments']['response']
special_tool_call = 'DirectResponse'
elif func["name"] == 'RequireClarification':
call_result = func['arguments']['unclear_question']
special_tool_call = 'RequireClarification'
elif func["name"] == 'CallAgent' and call_agent and call_agent_level < 2:
solution_plan = func['arguments']['solution']
full_message = (
message_for_call_agent + "\nFollow this plan: " + str(solution_plan)
)
sub_agent_task = "Sub TxAgent plan: " + str(solution_plan)
call_result = yield from self.run_gradio_chat(
full_message, history=[], temperature=temperature,
max_new_tokens=512, max_token=2048, call_agent=False,
call_agent_level=call_agent_level, conversation=None,
sub_agent_task=sub_agent_task)
call_result = call_result.split('[FinalAnswer]')[-1] if call_result else "⚠️ No content from sub-agent."
else:
call_result = self.tooluniverse.run_one_function(func)
call_id = self.tooluniverse.call_id_gen()
func["call_id"] = call_id
call_results.append({
"role": "tool",
"content": json.dumps({"tool_name": func["name"], "content": call_result, "call_id": call_id})
})
if return_gradio_history and func["name"] != 'Finish':
title = f"{'🧰' if func['name'] == 'Tool_RAG' else '⚒️'} {func['name']}"
gradio_history.append(ChatMessage(
role="assistant", content=str(call_result),
metadata={"title": title, "log": str(func['arguments'])}
))
else:
call_results.append({
"role": "tool",
"content": json.dumps({"content": "Invalid function call format."})
})
revised_messages = [{
"role": "assistant",
"content": message.strip() if message else "",
"tool_calls": json.dumps(function_call_json)
}] + call_results
return revised_messages, existing_tools_prompt, special_tool_call, gradio_history
def get_answer_based_on_unfinished_reasoning(self, conversation, temperature, max_new_tokens, max_token):
if conversation[-1]['role'] == 'assistant':
conversation.append(
{'role': 'tool', 'content': 'Errors occurred; provide final answer with current info.'})
finish_tools_prompt = self.add_finish_tools([])
output = self.llm_infer(
messages=conversation, temperature=temperature, tools=finish_tools_prompt,
output_begin_string='[FinalAnswer]', max_new_tokens=max_new_tokens, max_token=max_token)
logger.debug("Unfinished reasoning output: %s", output)
return output
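    # Note: output_begin_string='[FinalAnswer]' is appended to the prompt, so the
    # returned text is the answer body itself; callers that split on '[FinalAnswer]'
    # still work whether or not the model echoes the marker.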
def run_multistep_agent(self, message: str, temperature: float, max_new_tokens: int,
max_token: int, max_round: int = 3, call_agent=False, call_agent_level=0):
logger.debug("Starting multistep agent for message: %s", message[:100])
picked_tools_prompt, call_agent_level = self.initialize_tools_prompt(
call_agent, call_agent_level, message)
conversation = self.initialize_conversation(message)
outputs = []
last_outputs = []
next_round = True
current_round = 0
token_overflow = False
enable_summary = False
last_status = {}
if self.enable_checker:
checker = ReasoningTraceChecker(message, conversation)
# Check if message contains clinical findings
clinical_keywords = ['medication', 'symptom', 'evaluation', 'diagnosis']
has_clinical_data = any(keyword in message.lower() for keyword in clinical_keywords)
while next_round and current_round < max_round:
current_round += 1
if last_outputs:
function_call_messages, picked_tools_prompt, special_tool_call = self.run_function_call(
last_outputs, return_message=True, existing_tools_prompt=picked_tools_prompt,
message_for_call_agent=message, call_agent=call_agent,
call_agent_level=call_agent_level, temperature=temperature)
if special_tool_call == 'Finish':
next_round = False
conversation.extend(function_call_messages)
content = function_call_messages[0]['content']
return content.split('[FinalAnswer]')[-1] if content else "❌ No content after Finish."
if (self.enable_summary or token_overflow) and not call_agent:
enable_summary = True
last_status = self.function_result_summary(
conversation, status=last_status, enable_summary=enable_summary)
if function_call_messages:
conversation.extend(function_call_messages)
outputs.append(tool_result_format(function_call_messages))
else:
next_round = False
return ''.join(last_outputs).replace("</s>", "")
if self.enable_checker:
good_status, wrong_info = checker.check_conversation()
if not good_status:
logger.warning("Checker error: %s", wrong_info)
break
            # Skip tool prompts when the message already contains clinical findings.
            tools = [] if has_clinical_data else picked_tools_prompt
last_outputs = []
last_outputs_str, token_overflow = self.llm_infer(
messages=conversation, temperature=temperature, tools=tools,
max_new_tokens=max_new_tokens, max_token=max_token, check_token_status=True)
if last_outputs_str is None:
if self.force_finish:
return self.get_answer_based_on_unfinished_reasoning(
conversation, temperature, max_new_tokens, max_token)
return "❌ Token limit exceeded."
last_outputs.append(last_outputs_str)
if current_round >= max_round:
logger.warning("Max rounds exceeded")
if self.force_finish:
return self.get_answer_based_on_unfinished_reasoning(
conversation, temperature, max_new_tokens, max_token)
return None
def build_logits_processor(self, messages, llm):
tokenizer = llm.get_tokenizer()
        if self.avoid_repeat and len(messages) > 2:
            assistant_messages = [
                m['content'] for m in messages[-3:] if m['role'] == 'assistant'
            ][:2]
            # Deduplicate sentences across the recent assistant turns; sentences
            # with similarity ratio > 0.9 to one already seen count as repeats.
            unique_sentences = []
            for msg in assistant_messages:
                for s in msg.split('. '):
                    if not s:
                        continue
                    if all(SequenceMatcher(None, s.lower(), seen.lower()).ratio() <= 0.9
                           for seen in unique_sentences):
                        unique_sentences.append(s)
            forbidden_ids = [tokenizer.encode(s, add_special_tokens=False) for s in unique_sentences]
            return [NoRepeatSentenceProcessor(forbidden_ids, 10)]  # penalty strength 10
return None
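    # Minimal sketch of the similarity gate above; a trailing period alone is not
    # enough to make two sentences count as distinct:
    #   from difflib import SequenceMatcher
    #   SequenceMatcher(None, "the dose is 5 mg", "the dose is 5 mg.").ratio()  # ~0.97 > 0.9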
def llm_infer(self, messages, temperature=0.1, tools=None, output_begin_string=None,
max_new_tokens=512, max_token=2048, skip_special_tokens=True,
model=None, tokenizer=None, terminators=None, seed=None, check_token_status=False):
if model is None:
model = self.model
logits_processor = self.build_logits_processor(messages, model)
sampling_params = SamplingParams(
temperature=temperature,
max_tokens=max_new_tokens,
seed=seed if seed is not None else self.seed,
logits_processors=logits_processor
)
prompt = self.chat_template.render(messages=messages, tools=tools, add_generation_prompt=True)
if output_begin_string:
prompt += output_begin_string
if check_token_status and max_token:
num_input_tokens = len(self.tokenizer.encode(prompt, return_tensors="pt")[0])
if num_input_tokens > max_token:
torch.cuda.empty_cache()
gc.collect()
logger.info("Token overflow: %d > %d", num_input_tokens, max_token)
return None, True
logger.debug("Input tokens: %d", num_input_tokens)
output = model.generate(prompt, sampling_params=sampling_params)
output = output[0].outputs[0].text
logger.debug("Inference output: %s", output[:100])
torch.cuda.empty_cache()
if check_token_status:
return output, False
return output
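    # Usage sketch (assumes a loaded model): with check_token_status=True the
    # method returns (text, token_overflow), and (None, True) on overflow;
    # without it, the generated text alone:
    #   text = agent.llm_infer(messages=[{"role": "user", "content": "hi"}],
    #                          temperature=0.1, tools=None)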
def run_self_agent(self, message: str, temperature: float, max_new_tokens: int, max_token: int):
logger.debug("Starting self agent")
conversation = self.set_system_prompt([], self.self_prompt)
conversation.append({"role": "user", "content": message})
return self.llm_infer(messages=conversation, temperature=temperature,
max_new_tokens=max_new_tokens, max_token=max_token)
def run_chat_agent(self, message: str, temperature: float, max_new_tokens: int, max_token: int):
logger.debug("Starting chat agent")
conversation = self.set_system_prompt([], self.chat_prompt)
conversation.append({"role": "user", "content": message})
return self.llm_infer(messages=conversation, temperature=temperature,
max_new_tokens=max_new_tokens, max_token=max_token)
def run_format_agent(self, message: str, answer: str, temperature: float, max_new_tokens: int, max_token: int):
logger.debug("Starting format agent")
        if '[FinalAnswer]' in answer:
            possible_final_answer = answer.split("[FinalAnswer]")[-1].strip()
        elif "\n\n" in answer:
            possible_final_answer = answer.split("\n\n")[-1].strip()
        else:
            possible_final_answer = answer.strip()
        # A bare letter and an "X: ..." answer both start with the option letter.
        if possible_final_answer and possible_final_answer[0] in ['A', 'B', 'C', 'D', 'E']:
            return possible_final_answer[0]
conversation = self.set_system_prompt(
[], "Transform the answer to a single letter: 'A', 'B', 'C', 'D', or 'E'.")
conversation.append({"role": "user", "content": f"Original: {message}\nAnswer: {answer}\nFinal answer (letter):"})
return self.llm_infer(messages=conversation, temperature=temperature,
max_new_tokens=max_new_tokens, max_token=max_token)
def run_summary_agent(self, thought_calls: str, function_response: str,
temperature: float, max_new_tokens: int, max_token: int):
logger.debug("Starting summary agent")
prompt = f"""Thought and function calls: {thought_calls}
Function responses: {function_response}
Summarize the function responses in one sentence with all necessary information."""
conversation = [{"role": "user", "content": prompt}]
output = self.llm_infer(messages=conversation, temperature=temperature,
max_new_tokens=max_new_tokens, max_token=max_token)
if '[' in output:
output = output.split('[')[0]
return output
    def function_result_summary(self, input_list, status, enable_summary):
        """Compress older tool responses in `input_list` (in place) into
        one-sentence summaries, tracking progress in `status` so repeated
        calls only summarize content that has not been summarized yet."""
        if 'tool_call_step' not in status:
            status['tool_call_step'] = 0
        if 'step' not in status:
            status['step'] = 0
        status['step'] += 1
for idx in range(len(input_list)):
pos_id = len(input_list) - idx - 1
if input_list[pos_id]['role'] == 'assistant' and 'tool_calls' in input_list[pos_id]:
if 'Tool_RAG' in str(input_list[pos_id]['tool_calls']):
status['tool_call_step'] += 1
break
if not enable_summary:
return status
if 'summarized_index' not in status:
status['summarized_index'] = 0
if 'summarized_step' not in status:
status['summarized_step'] = 0
if 'previous_length' not in status:
status['previous_length'] = 0
if 'history' not in status:
status['history'] = []
status['history'].append(
self.summary_mode == 'step' and status['summarized_step'] < status['step'] - status['tool_call_step'] - self.summary_skip_last_k)
        idx = status['summarized_index']
        last_call_idx = None  # guards against summarizing before any assistant tool call is seen
        function_response = ''
        this_thought_calls = None
while idx < len(input_list):
            if (self.summary_mode == 'step' and status['summarized_step'] < status['step'] - status['tool_call_step'] - self.summary_skip_last_k) or \
                    (self.summary_mode == 'length' and self.summary_context_length is not None and status['previous_length'] > self.summary_context_length):
if input_list[idx]['role'] == 'assistant':
                    if 'Tool_RAG' in str(input_list[idx].get('tool_calls', '')):
this_thought_calls = None
else:
if function_response:
status['summarized_step'] += 1
result_summary = self.run_summary_agent(
thought_calls=this_thought_calls, function_response=function_response,
temperature=0.1, max_new_tokens=512, max_token=2048)
input_list.insert(
last_call_idx + 1, {'role': 'tool', 'content': result_summary})
status['summarized_index'] = last_call_idx + 2
idx += 1
last_call_idx = idx
                        this_thought_calls = str(input_list[idx].get('content', '')) + str(input_list[idx].get('tool_calls', ''))
function_response = ''
elif input_list[idx]['role'] == 'tool' and this_thought_calls:
function_response += input_list[idx]['content']
del input_list[idx]
idx -= 1
else:
break
idx += 1
if function_response:
status['summarized_step'] += 1
result_summary = self.run_summary_agent(
thought_calls=this_thought_calls, function_response=function_response,
temperature=0.1, max_new_tokens=512, max_token=2048)
tool_calls = json.loads(input_list[last_call_idx]['tool_calls'])
        for tool_call in tool_calls:
            tool_call.pop('call_id', None)
input_list[last_call_idx]['tool_calls'] = json.dumps(tool_calls)
input_list.insert(
last_call_idx + 1, {'role': 'tool', 'content': result_summary})
status['summarized_index'] = last_call_idx + 2
return status
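    # Illustrative `status` dict threaded through function_result_summary
    # (keys are created on first use; values abridged):
    #   {'step': 3, 'tool_call_step': 1, 'summarized_index': 4,
    #    'summarized_step': 1, 'previous_length': 0, 'history': [True, False]}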
def update_parameters(self, **kwargs):
updated_attributes = {}
for key, value in kwargs.items():
if hasattr(self, key):
setattr(self, key, value)
updated_attributes[key] = value
logger.debug("Updated parameters: %s", updated_attributes)
return updated_attributes
def run_gradio_chat(self, message: str, history: list, temperature: float,
max_new_tokens: int, max_token: int, call_agent: bool,
conversation: gr.State, max_round: int = 3, seed: int = None,
call_agent_level: int = 0, sub_agent_task: str = None,
uploaded_files: list = None):
logger.debug("Chat started, message: %s", message[:100])
if not message or len(message.strip()) < 5:
yield "Please provide a valid message or upload files to analyze."
return
if message.startswith("[\U0001f9f0 Tool_RAG") or message.startswith("⚒️"):
return
# Check if message contains clinical findings
clinical_keywords = ['medication', 'symptom', 'evaluation', 'diagnosis']
has_clinical_data = any(keyword in message.lower() for keyword in clinical_keywords)
call_agent = call_agent and not has_clinical_data # Disable CallAgent for clinical data
picked_tools_prompt, call_agent_level = self.initialize_tools_prompt(
call_agent, call_agent_level, message)
conversation = self.initialize_conversation(
message, conversation, history)
history = []
next_round = True
current_round = 0
enable_summary = False
last_status = {}
token_overflow = False
if self.enable_checker:
checker = ReasoningTraceChecker(message, conversation, init_index=len(conversation))
try:
            # Initialize once, outside the loop, so each round can process the
            # previous round's output below.
            last_outputs = []
            while next_round and current_round < max_round:
                current_round += 1
                if last_outputs:
function_call_messages, picked_tools_prompt, special_tool_call, current_gradio_history = yield from self.run_function_call_stream(
last_outputs, return_message=True, existing_tools_prompt=picked_tools_prompt,
message_for_call_agent=message, call_agent=call_agent,
call_agent_level=call_agent_level, temperature=temperature)
history.extend(current_gradio_history)
if special_tool_call == 'Finish':
yield history
next_round = False
conversation.extend(function_call_messages)
return function_call_messages[0]['content']
if special_tool_call in ['RequireClarification', 'DirectResponse']:
last_msg = history[-1] if history else ChatMessage(role="assistant", content="Response needed.")
history.append(ChatMessage(role="assistant", content=last_msg.content))
yield history
next_round = False
return last_msg.content
if (self.enable_summary or token_overflow) and not call_agent:
enable_summary = True
last_status = self.function_result_summary(
conversation, status=last_status, enable_summary=enable_summary)
if function_call_messages:
conversation.extend(function_call_messages)
yield history
else:
next_round = False
return ''.join(last_outputs).replace("</s>", "")
if self.enable_checker:
good_status, wrong_info = checker.check_conversation()
if not good_status:
logger.warning("Checker error: %s", wrong_info)
break
                # Skip tool prompts when the message already contains clinical findings.
                tools = [] if has_clinical_data else picked_tools_prompt
                last_outputs = []
                last_outputs_str, token_overflow = self.llm_infer(
messages=conversation, temperature=temperature, tools=tools,
max_new_tokens=max_new_tokens, max_token=max_token, seed=seed, check_token_status=True)
if last_outputs_str is None:
if self.force_finish:
last_outputs_str = self.get_answer_based_on_unfinished_reasoning(
conversation, temperature, max_new_tokens, max_token)
history.append(ChatMessage(role="assistant", content=last_outputs_str.strip()))
yield history
return last_outputs_str
error_msg = "Token limit exceeded."
history.append(ChatMessage(role="assistant", content=error_msg))
yield history
return error_msg
last_thought = last_outputs_str.split("[TOOL_CALLS]")[0]
for msg in history:
if msg.metadata:
msg.metadata['status'] = 'done'
if '[FinalAnswer]' in last_thought:
parts = last_thought.split('[FinalAnswer]', 1)
final_thought, final_answer = parts if len(parts) == 2 else (last_thought, "")
history.append(ChatMessage(role="assistant", content=final_thought.strip()))
yield history
history.append(ChatMessage(role="assistant", content="**🧠 Final Analysis:**\n" + final_answer.strip()))
yield history
else:
history.append(ChatMessage(role="assistant", content=last_thought))
yield history
last_outputs.append(last_outputs_str)
if next_round and self.force_finish:
last_outputs_str = self.get_answer_based_on_unfinished_reasoning(
conversation, temperature, max_new_tokens, max_token)
parts = last_outputs_str.split('[FinalAnswer]', 1)
final_thought, final_answer = parts if len(parts) == 2 else (last_outputs_str, "")
history.append(ChatMessage(role="assistant", content=final_thought.strip()))
yield history
history.append(ChatMessage(role="assistant", content="**🧠 Final Analysis:**\n" + final_answer.strip()))
yield history
except Exception as e:
logger.error("Exception in run_gradio_chat: %s", e, exc_info=True)
error_msg = f"Error: {e}"
history.append(ChatMessage(role="assistant", content=error_msg))
yield history
if self.force_finish:
last_outputs_str = self.get_answer_based_on_unfinished_reasoning(
conversation, temperature, max_new_tokens, max_token)
parts = last_outputs_str.split('[FinalAnswer]', 1)
final_thought, final_answer = parts if len(parts) == 2 else (last_outputs_str, "")
history.append(ChatMessage(role="assistant", content=final_thought.strip()))
yield history
history.append(ChatMessage(role="assistant", content="**🧠 Final Analysis:**\n" + final_answer.strip()))
yield history
return error_msg
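# Minimal smoke-test sketch, assuming GPU weights and default tool files are
# available; the model identifiers below are placeholders, not pinned by this module.
if __name__ == "__main__":
    agent = TxAgent(
        model_name="path/or/hub-id/of-txagent-model",      # placeholder
        rag_model_name="path/or/hub-id/of-toolrag-model",  # placeholder
        enable_rag=True,
        force_finish=True,
    )
    agent.init_model()
    answer = agent.run_multistep_agent(
        "A patient on warfarin reports new-onset headaches. What should be evaluated?",
        temperature=0.1, max_new_tokens=512, max_token=2048)
    print(answer)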