|
|
from typing import Dict, List, Optional, Tuple |
|
|
|
|
|
from swift.llm.template import split_str_parts_by |
|
|
|
|
|
|
|
|
def calculate_loss_scale(query: str,
                         response: str,
                         response_loss_scale_map: Dict[str, list],
                         query_loss_scale_map: Optional[Dict[str, list]] = None) -> Tuple[List[str], List[float]]:
    """Calculate the loss scale by splitting the agent response.

    This algorithm comes from paper: https://arxiv.org/pdf/2309.00986.pdf

    Agent response format:

    ```text
    Thought: you should always think about what to do
    Action: the action to take, should be one of the above tools[fire_recognition,
        fire_alert, call_police, call_fireman]
    Action Input: the input to the action
    Observation: the result of the action
    ... (this Thought/Action/Action Input/Observation can be repeated zero or more times)
    Thought: I now know the final answer
    Final Answer: the final answer to the original input question
    ```

    Args:
        query: The query text. It is only consulted when ``query_loss_scale_map``
            is given.
        response: The response text to split and weight.
        response_loss_scale_map: Maps a delimiter to its weights. A 2-element
            value ``[key_weight, content_weight]`` marks a literal delimiter whose
            own text and following content are emitted as two weighted parts. A
            1-element value ``[weight]`` marks a pattern used in regex mode. If
            any 2-element entry exists, only the literal delimiters are used for
            splitting and the 1-element entries are not used as split patterns.
        query_loss_scale_map: Optional map from a substring of ``query`` to a
            weight. The weight is either a bare number or a list whose first
            element is the weight. The first key (in dict iteration order) that
            occurs in ``query`` short-circuits splitting and assigns that weight
            to the whole response. This mapping is never modified.

    Returns:
        A tuple of agent response parts and their weights (lists of equal length).
    """
    if query_loss_scale_map is not None:
        for key, value in query_loss_scale_map.items():
            if key in query:
                # Accept a bare number as well as a list. Read it locally instead of
                # writing a normalized form back into the caller's dict.
                loss_scale_value = value if isinstance(value, (float, int)) else value[0]
                return [response], [float(loss_scale_value)]

    # Literal (2-weight) delimiters take precedence. Regex (1-weight) patterns are
    # used for splitting only when no literal delimiter is configured.
    delimiters = [k for k, v in response_loss_scale_map.items() if len(v) == 2]
    if delimiters:
        agent_parts = split_str_parts_by(response, delimiters)
    else:
        regex_delimiters = [k for k, v in response_loss_scale_map.items() if len(v) == 1]
        agent_parts = split_str_parts_by(response, regex_delimiters, regex_mode=True)

    weights: List[float] = []
    agent_content: List[str] = []
    for part in agent_parts:
        part_key = part['key']
        if part_key not in response_loss_scale_map:
            # Text not matched to any configured key keeps the default weight of 1.
            weights.append(1.)
            agent_content.append(part['content'])
            continue

        loss_scale = response_loss_scale_map[part_key]
        assert len(loss_scale) in {1, 2}, f'loss_scale: {loss_scale}'
        if len(loss_scale) == 1:
            weights.append(float(loss_scale[0]))
            agent_content.append(part['content'])
        else:
            # The delimiter text and the content after it become two parts,
            # weighted by loss_scale[0] and loss_scale[1] respectively.
            weights += [float(w) for w in loss_scale]
            agent_content += [part_key, part['content']]
    return agent_content, weights
|
|
|