---
license: apache-2.0
language:
- en
pipeline_tag: text-generation
library_name: transformers
---
# TD-Llama-General

## TL;DR
**TD (ToolDial)-Llama-General** is a version of Llama-3.1-8B-Instruct that has been instruction-tuned on the ToolDial dataset in the ReAct format (Thought-Action-Observation). Unlike the original ToolDial paper, dialogue states are removed from the reasoning trace, and for Call actions the API call payload is concatenated directly after the action within the reasoning step.

## This model can be used as
- A baseline model for tool calling (function calling) and multi-turn dialogue agent tasks.
- A warm-start model for SFT or RL training of tool agents.

**[Caution]**
- To use the model for inference, it is important to carefully follow the **Input Format Example** and use the instruction provided under **Instruction for TD-Llama-General**.
- Refer to the **Example Code** below.

**[Model Summary]**
- Trained with QLoRA quantization. Our quantization configuration is shown in the **Example Code**.
- Trained for 1 epoch with the 8-bit Adam optimizer, learning rate 1e-5, and betas (0.9, 0.995).
- We only provide the trained LoRA adapter. Follow the **Example Code** and you can use it easily.
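The quantization setup above matches the inference config in the **Example Code**; a sketch of what the corresponding training-side configuration might look like is below. Note that `r`, `lora_alpha`, `target_modules`, and `lora_dropout` are illustrative placeholder values, not hyperparameters reported for this model.

```python
import torch
from transformers import BitsAndBytesConfig
from peft import LoraConfig

# 4-bit NF4 quantization, mirroring the inference configuration in the Example Code.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

# HYPOTHETICAL adapter settings: these values are common defaults for QLoRA
# fine-tuning, not the values used to train TD-Llama-General.
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    lora_dropout=0.05,
    task_type="CAUSAL_LM",
)
```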


**Input Format Example**
```
User: I need to analyze the intraday performance of a cryptocurrency pair and compare it with the intraday data of an FX currency pair I have already retrieved. Can you provide the detailed intraday time series data for the cryptocurrency pair?
System:
- Thought: The user seems to have intent. I will call the retriever.
- Action:retriever_call
- Observation: {'retriever_status': {'retriever_call': 'true', 'retrieved_api': {'Finance|quotient|/crypto/intraday': 0.71542144, 'Finance|eod_historical_data|Intraday Historical Data API': 0.67002374, 'Finance|quotient|/forex/daily': 0.6112553, 'Finance|quotient|/forex/intraday': 0.5370315, 'Finance|real_time_finance_data|Currency Time Series': 0.50981194}}, 'api_documentation': {'api_name': 'Finance|quotient|/crypto/intraday', 'api_description': 'Return intraday time series (Date, Open, High, Low, Close, Volume) given the input parameters.', 'required_parameters': [{'input_parameter_name': 'interval', 'description': 'the time interval for retrieving intraday market data, specifying the duration over which the intraday time series data (Date, Open, High, Low, Close, Volume) is requested.'}, {'input_parameter_name': 'to', 'description': "The end date for the intraday time series query is the specific date that marks the conclusion of the requested intraday data period for various assets, including cryptocurrencies, as defined by the API's functionality."}, {'input_parameter_name': 'symbol', 'description': "The symbol of the cryptocurrency pair for which intraday market data is requested, representing the specific digital assets being analyzed in the context of the API's functionality."}, {'input_parameter_name': 'from', 'description': 'The start date for the intraday time series is the specific date from which the intraday market data, including open, high, low, close, and volume, will be retrieved for various assets such as cryptocurrencies.'}], 'optional_parameters': []}}
- Thought: Based on results of the retriever, I need to call Finance|quotient|/crypto/intraday. To call it, I need to ask 'interval', 'to', 'symbol', and 'from' to the user.
- Action:request
- Observation: {}
- Message: Please tell me the interval, to, symbol, and from for the cryptocurrency pair.
User: Certainly! The interval is 15min, the to date is 2023-10-01 15:00, the symbol is ETH-USD, and the from date is 2023-10-01 09:00.
System:
- Thought: Based on the user's response, now I can call the API Finance|quotient|/crypto/intraday.
- Action:call{'api_name': 'Finance|quotient|/crypto/intraday', 'required_parameters': {'interval': '15min', 'to': '2023-10-01 15:00', 'symbol': 'ETH-USD', 'from': '2023-10-01 09:00'}, 'optional_parameters': {}}
- Observation: {'call_result': {'api_name': 'Finance|quotient|/crypto/intraday', 'result': {'symbol': 'ETH-USD', 'data': [{'date': '2023-10-01 09:00', 'open': 1800.0, 'high': 1820.0, 'low': 1790.0, 'close': 1815.0, 'volume': 1500}, {'date': '2023-10-01 09:15', 'open': 1815.0, 'high': 1830.0, 'low': 1805.0, 'close': 1820.0, 'volume': 1200}, {'date': '2023-10-01 09:30', 'open': 1820.0, 'high': 1840.0, 'low': 1810.0, 'close': 1835.0, 'volume': 1300}, {'date': '2023-10-01 09:45', 'open': 1835.0, 'high': 1850.0, 'low': 1825.0, 'close': 1845.0, 'volume': 1400}, {'date': '2023-10-01 10:00', 'open': 1845.0, 'high': 1860.0, 'low': 1835.0, 'close': 1855.0, 'volume': 1600}, {'date': '2023-10-01 10:15', 'open': 1855.0, 'high': 1870.0, 'low': 1845.0, 'close': 1865.0, 'volume': 1500}, {'date': '2023-10-01 10:30', 'open': 1865.0, 'high': 1880.0, 'low': 1855.0, 'close': 1875.0, 'volume': 1200}, {'date': '2023-10-01 10:45', 'open': 1875.0, 'high': 1890.0, 'low': 1865.0, 'close': 1885.0, 'volume': 1300}, {'date': '2023-10-01 11:00', 'open': 1885.0, 'high': 1900.0, 'low': 1875.0, 'close': 1895.0, 'volume': 1400}, {'date': '2023-10-01 11:15', 'open': 1895.0, 'high': 1910.0, 'low': 1885.0, 'close': 1905.0, 'volume': 1600}, {'date': '2023-10-01 11:30', 'open': 1905.0, 'high': 1920.0, 'low': 1895.0, 'close': 1915.0, 'volume': 1500}, {'date': '2023-10-01 11:45', 'open': 1915.0, 'high': 1930.0, 'low': 1905.0, 'close': 1925.0, 'volume': 1200}, {'date': '2023-10-01 12:00', 'open': 1925.0, 'high': 1940.0, 'low': 1915.0, 'close': 1935.0, 'volume': 1300}, {'date': '2023-10-01 12:15', 'open': 1935.0, 'high': 1950.0, 'low': 1925.0, 'close': 1945.0, 'volume': 1400}, {'date': '2023-10-01 12:30', 'open': 1945.0, 'high': 1960.0, 'low': 1935.0, 'close': 1955.0, 'volume': 1600}, {'date': '2023-10-01 12:45', 'open': 1955.0, 'high': 1970.0, 'low': 1945.0, 'close': 1965.0, 'volume': 1500}, {'date': '2023-10-01 13:00', 'open': 1965.0, 'high': 1980.0, 'low': 1955.0, 'close': 1975.0, 'volume': 1200}, 
{'date': '2023-10-01 13:15', 'open': 1975.0, 'high': 1990.0, 'low': 1965.0, 'close': 1985.0, 'volume': 1300}, {'date': '2023-10-01 13:30', 'open': 1985.0, 'high': 2000.0, 'low': 1975.0, 'close': 1995.0, 'volume': 1400}, {'date': '2023-10-01 13:45', 'open': 1995.0, 'high': 2010.0, 'low': 1985.0, 'close': 2005.0, 'volume': 1600}, {'date': '2023-10-01 14:00', 'open': 2005.0, 'high': 2020.0, 'low': 1995.0, 'close': 2015.0, 'volume': 1500}, {'date': '2023-10-01 14:15', 'open': 2015.0, 'high': 2030.0, 'low': 2005.0, 'close': 2025.0, 'volume': 1200}, {'date': '2023-10-01 14:30', 'open': 2025.0, 'high': 2040.0, 'low': 2015.0, 'close': 2035.0, 'volume': 1300}, {'date': '2023-10-01 14:45', 'open': 2035.0, 'high': 2050.0, 'low': 2025.0, 'close': 2045.0, 'volume': 1400}], 'interval': 15, 'from': '2023-10-01 09:00', 'to': '2023-10-01 15:00'}}}
- Thought: Based on the results of Finance|quotient|/crypto/intraday, I can respond to the user.
- Action:response
- Observation: {}
- Message: Here is the detailed intraday time series data for the cryptocurrency pair ETH-USD from 2023-10-01 09:00 to 2023-10-01 15:00 with a 15-minute interval. The data includes open, high, low, close prices, and volume for each interval.
User: Thank you for the detailed information. Goodbye!
- Thought: The user has said goodbye. I will respond politely.
- Action:system_bye
- Observation: {}
- Message: You're welcome! Have a great day! Goodbye!
```
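In the example above, the `call` action concatenates a Python-literal payload directly after the action name, while every other action stands alone. A minimal sketch of a parser for such `- Action:` lines (the function name `parse_action` is ours, not part of the model or dataset):

```python
import ast
import re

def parse_action(line):
    """Split a '- Action:...' line into (action_name, payload_dict).

    For 'call', the payload dict is concatenated after the action name;
    for all other actions the payload is empty.
    """
    body = line.replace("- Action:", "", 1).strip()
    match = re.match(r"([a-z_]+)\s*(\{.*\})?$", body, re.DOTALL)
    action = match.group(1)
    payload = ast.literal_eval(match.group(2)) if match.group(2) else {}
    return action, payload

action, payload = parse_action(
    "- Action:call{'api_name': 'Finance|quotient|/crypto/intraday', "
    "'required_parameters': {'interval': '15min'}, 'optional_parameters': {}}"
)
```

`ast.literal_eval` is used instead of `eval` so only literal dicts are accepted; malformed payloads raise an exception rather than executing arbitrary code.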

**Instruction for TD-Llama-General**
```
Instruction:
You will be provided with part of a dialogue between the user and the system. In this dialogue, the user is requesting information from the system, and the system will make an API call to retrieve the necessary information.
Your task is to generate an appropriate reasoning, and action based on the dialogue history and the user's most recent utterance. Additionally, generate an appropriate response message for the user's last utterance.

---------------Refer the internal rule----------------------

<Retriever status rules>
1. The system selects the API with the highest score from among the APIs in the retriever status that have a score of 0.6 or higher, which is suitable for processing the user's query.
2. If no APIs score is higher than 0.6, system cannot confirm the API to call.

<Action list>
The actions that the system can take are as follows.
- Request: In a situation where the API to be executed has been determined, ask the user questions to gather information on the input parameters.
- Response: Reply to user's request based on the result of API call.
- Clarify: If user's query is vague, re-ask user to get intent specifically. If there is no API in the most recently retrieved results with a score above 0.5, "clarify" is required.
- Suggest: Making a suggestion for an unclear user's intent and asking whether it satisfies the user. If there is an API in the most recently retrieved results with a score above 0.5 but none exceeding 0.6, a 'suggest' action is required.
- Response fail: Notify the user that the system cannot execute the request due to insufficient information.
- System bye: System says goodbye to user politely.
- Call: Call the API with collected information from user or else and don't reply to user yet.
(Unlike other actions, Call must be generated by concatenating it with the API call format. The required format is as follows.)
Call{'api_name': 'API1', 'required_parameters': {'param1': 'value1', 'param2': 'value2'}, 'optional_parameters': {'param3': 'value3'}}
- Retriever call: Call the retriever to find proper API to satisfy user's requests. The system should call the retriever in the following two situations:
1. When the user specifies a requirement, and the system needs to search for an API to fulfill it.
2. When the user does not provide input parameters required for an API call, and the system needs to search for another API to obtain those parameters.

----------------------------------

Based on the rules above, you will get an input as below:
Input:
<Dialogue history>
User: <User's last utterance>

And here is the format for what you should return:
- Thought: <System's thought process. You generate here>
- Action: <System's action. You generate here. Choose one from <Action list>>
- Observation: <This will be provided externally.>
- Thought: <System's thought process. You generate here, refering the dialogue history and the observation above>
- Action: <System's action. You generate here, refering the dialogue history and the observation above. Choose one from <Action list>>
- Observation: <This will be provided externally.>
- Message: <System's message>

Output Format Rules:
1. Multiple reasoning steps (Thought-Action pairs) may be present in the output if required
2. The call format for the Call action must be constructed by referring to the given API documentation and dialogue history.

Now, generate appropriate reasoning trace and responding message to user.

Dialogue History:
```
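The retriever-score thresholds in the instruction (0.6 to confirm an API, 0.5 as the floor for a suggestion) can be sketched as a small decision helper. `decide_retriever_action` is an illustrative name of ours, and "confirm" stands in for whichever follow-up action (request/call) applies once the API is fixed:

```python
def decide_retriever_action(retrieved_apis):
    """Pick the next action from retriever scores per the internal rules.

    - some API scores >= 0.6  -> confirm the top-scoring API
    - best score in (0.5, 0.6) -> 'suggest'
    - best score <= 0.5        -> 'clarify'
    """
    if not retrieved_apis:
        return "clarify", None
    best_api, best_score = max(retrieved_apis.items(), key=lambda kv: kv[1])
    if best_score >= 0.6:
        return "confirm", best_api
    if best_score > 0.5:
        return "suggest", best_api
    return "clarify", None

# Scores taken from the retriever observation in the Input Format Example.
scores = {
    "Finance|quotient|/crypto/intraday": 0.71542144,
    "Finance|quotient|/forex/intraday": 0.5370315,
}
```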

**Example Code**

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from prompts import overall_evaluation_gen_prompt  # local module holding the same prompt as "Instruction for TD-Llama-General"
import json, re, torch
from peft import PeftModel

def normalize_action(text):
    lower_text = text.lower()
    normalized_text = re.sub(r'[^a-z]', '', lower_text)
    return normalized_text.replace("action","").replace("api","")

def extract_observation(label):
    lines = label.split("\n")
    return_label = {}
    for idx,sentence in enumerate(lines):
        if "- Action:" in sentence:
            current_obs = lines[idx+1].replace("- Observation: ","")
            current_action = normalize_action(sentence.replace("- Action:","").strip())
            if current_action in return_label:
                return_label[current_action].insert(0,current_obs)
            else:
                return_label[current_action] = [current_obs]
    return return_label

def overall_form(dialogue_idx,full_dialogue,tokenizer_inf,is_val=False):
    prompt = ""
    lines = full_dialogue['dialogue'].split("\n")
    for idx,sentence in enumerate(lines):
        if "- Dialogue State: " in sentence:
            continue

        if "- Action:call" in sentence:
            if "- Dialogue State: " in lines[idx-2]:
                prompt+=sentence+str(eval(lines[idx-2].replace("- Dialogue State: ",""))['api_status'])+"\n"
            else:
                find_ds_idx=1
                while "- Dialogue State: " not in lines[idx+find_ds_idx]:
                    find_ds_idx+=1
                prompt+=sentence+str(eval(lines[idx+find_ds_idx].replace("- Dialogue State: ",""))['api_status'])+"\n"
        else:
            prompt+=sentence+"\n"
    
    lines = prompt.split("\n")
    label = []
    prompt = ""
    for idx,sentence in enumerate(lines):
        prompt+=sentence+"\n"
        if "User: " in sentence:
            current_dial = {"dial":prompt,"label":""}
            tmp_idx = 1
            while "User: " not in lines[idx+tmp_idx]:
                current_dial["label"]+=lines[idx+tmp_idx]+"\n"
                tmp_idx+=1
                if idx+tmp_idx==len(lines):
                    break
            current_dial['dial']+="System:"
            current_dial['label'] = current_dial['label'].replace("System:\n","")
            if is_val: ## test
                chat_eval = [{"role": "user", "content":overall_evaluation_gen_prompt+prompt}] #
                template_chat = tokenizer_inf.apply_chat_template(chat_eval, tokenize=False, add_generation_prompt=True)
                label.append({"dialogue_idx":dialogue_idx,
                              "dial":template_chat,
                              "label":current_dial['label'],
                              "obs":extract_observation(current_dial['label'])})
            else: ## train
                history_chat = [{"role": "user", "content": overall_evaluation_gen_prompt+prompt}, {"role": "assistant", "content": ""}]
                chat = [{"role": "user", "content": overall_evaluation_gen_prompt+prompt}, {"role": "assistant", "content": current_dial['label']}]
                template_history = tokenizer_inf.apply_chat_template(history_chat, tokenize=False, add_generation_prompt=False)
                template_chat = tokenizer_inf.apply_chat_template(chat, tokenize=False, add_generation_prompt=False)
                
                label.append({
                    "dialogue_idx":dialogue_idx,
                    "history": template_history.replace("\n<|eot_id|>", "\n"),
                    "dial": template_chat,
                    "label": current_dial['label']})
    return label

device = "cuda:0"
base_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
    ),
    device_map={"": device},
)
tokenizer = AutoTokenizer.from_pretrained("HOLILAB/td-llama-general")
model = PeftModel.from_pretrained(base_model, "HOLILAB/td-llama-general")

with open("test_dialogue_overall_obs.json",'r') as f: # https://drive.google.com/drive/folders/17IASNvCcJRkHlg2tzMPRlpPbo_MlHs94?hl=ko
    dial_list = json.load(f)

sample_dialogue_idx = 200
turn_idx = 2

input_prompt=overall_form(sample_dialogue_idx,dial_list[sample_dialogue_idx],tokenizer,is_val=True)[turn_idx]['dial']

inputs = tokenizer(input_prompt, return_tensors="pt")
inputs = {k: v.to(device) for k, v in inputs.items()}

outputs = model.generate(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    max_new_tokens=1000,
    pad_token_id=tokenizer.eos_token_id,
    do_sample=False,
)
# Decode only the newly generated tokens (the reasoning trace and message).
generated = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
print(generated)
```