from typing import Dict, List, Optional, Tuple from swift.llm.template import split_str_parts_by def calculate_loss_scale(query: str, response: str, response_loss_scale_map: Dict[str, list], query_loss_scale_map: Optional[Dict[str, list]] = None) -> Tuple[List[str], List[float]]: """Calculate the loss scale by splitting the agent response. This algorithm comes from paper: https://arxiv.org/pdf/2309.00986.pdf Agent response format: ```text Thought: you should always think about what to do Action: the action to take, should be one of the above tools[fire_recognition, fire_alert, call_police, call_fireman] Action Input: the input to the action Observation: the result of the action ... (this Thought/Action/Action Input/Observation can be repeated zero or more times) Thought: I now know the final answer Final Answer: the final answer to the original input question ``` Returns: A tuple of agent response parts and their weights. """ # query loss scale map if query_loss_scale_map is not None: for key in query_loss_scale_map.keys(): if key in query: if isinstance(query_loss_scale_map[key], (float, int)): query_loss_scale_map[key] = [query_loss_scale_map[key]] loss_scale_value = query_loss_scale_map[key][0] return [response], [float(loss_scale_value)] delimiters = [k for k, v in response_loss_scale_map.items() if len(v) == 2] if delimiters: agent_parts = split_str_parts_by(response, delimiters) else: regex_delimiters = [k for k, v in response_loss_scale_map.items() if len(v) == 1] agent_parts = split_str_parts_by(response, regex_delimiters, regex_mode=True) weights = [] agent_content = [] for c in agent_parts: if c['key'] in response_loss_scale_map: loss_scale = response_loss_scale_map[c['key']] assert len(loss_scale) in {1, 2}, f'loss_scale: {loss_scale}' if len(loss_scale) == 1: weights += loss_scale agent_content.append(c['content']) else: weights += loss_scale agent_content += [c['key'], c['content']] else: weights.append(1.) agent_content.append(c['content']) return agent_content, weights