Update README.md
Browse files
README.md
CHANGED
|
@@ -21,19 +21,16 @@ It has been trained using [TRL](https://github.com/huggingface/trl).
|
|
| 21 |
|
| 22 |
## Quick start
|
| 23 |
|
| 24 |
-
|
|
|
|
|
|
|
| 25 |
|
| 26 |
````python
|
|
|
|
| 27 |
import json
|
| 28 |
|
| 29 |
-
import torch
|
| 30 |
-
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 31 |
from jinja2 import Template
|
| 32 |
|
| 33 |
-
model_name = "plaguss/Llama-3.2-1B-Instruct-APIGen-FC-v0.1"
|
| 34 |
-
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype="auto", trust_remote_code=True)
|
| 35 |
-
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 36 |
-
|
| 37 |
SYSTEM_PROMPT = """
|
| 38 |
You are an expert in composing functions. You are given a question and a set of possible functions.
|
| 39 |
Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
|
|
@@ -58,6 +55,63 @@ Please answer the following query:
|
|
| 58 |
{{ query }}
|
| 59 |
""".lstrip())
|
| 60 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
get_weather_api = {
|
| 62 |
"name": "get_weather",
|
| 63 |
"description": "Get the current weather for a location",
|
|
@@ -93,45 +147,147 @@ search_api = {
|
|
| 93 |
}
|
| 94 |
}
|
| 95 |
|
| 96 |
-
|
| 97 |
|
| 98 |
query = "What's the weather like in New York in fahrenheit?"
|
| 99 |
|
| 100 |
-
|
| 101 |
|
| 102 |
-
messages=[
|
| 103 |
-
{"role": "system", "content": SYSTEM_PROMPT},
|
| 104 |
-
{ 'role': 'user', 'content': user_prompt}
|
| 105 |
-
]
|
| 106 |
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(model.device)
|
| 107 |
|
| 108 |
outputs = model.generate(inputs, max_new_tokens=512, do_sample=False, num_return_sequences=1, eos_token_id=tokenizer.eos_token_id)
|
| 109 |
result = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=False)
|
| 110 |
|
| 111 |
-
|
| 112 |
-
import re
|
| 113 |
-
matches = re.findall(pattern, result, re.DOTALL)
|
| 114 |
-
response = json.loads(matches[0])
|
| 115 |
# [{'name': 'get_weather', 'arguments': {'location': 'New York', 'unit': 'fahrenheit'}}]
|
| 116 |
````
|
| 117 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
Example response with no tools available
|
| 119 |
|
| 120 |
```python
|
| 121 |
-
|
| 122 |
|
| 123 |
query = "What's the weather like in New York in fahrenheit?"
|
| 124 |
|
| 125 |
-
|
| 126 |
|
| 127 |
-
messages=[
|
| 128 |
-
{"role": "system", "content": SYSTEM_PROMPT},
|
| 129 |
-
{ 'role': 'user', 'content': user_prompt}
|
| 130 |
-
]
|
| 131 |
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(model.device)
|
| 132 |
|
| 133 |
outputs = model.generate(inputs, max_new_tokens=512, do_sample=False, num_return_sequences=1, eos_token_id=tokenizer.eos_token_id)
|
| 134 |
result = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True)
|
|
|
|
| 135 |
# 'The query cannot be answered, no tools were provided.'
|
| 136 |
```
|
| 137 |
|
|
@@ -151,23 +307,21 @@ cut_number = {
|
|
| 151 |
}
|
| 152 |
}
|
| 153 |
|
| 154 |
-
|
| 155 |
|
| 156 |
query = "What's the weather like in New York in fahrenheit?"
|
| 157 |
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
messages=[
|
| 161 |
-
{"role": "system", "content": SYSTEM_PROMPT},
|
| 162 |
-
{ 'role': 'user', 'content': user_prompt}
|
| 163 |
-
]
|
| 164 |
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(model.device)
|
| 165 |
|
| 166 |
outputs = model.generate(inputs, max_new_tokens=512, do_sample=False, num_return_sequences=1, eos_token_id=tokenizer.eos_token_id)
|
| 167 |
result = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True)
|
|
|
|
| 168 |
# "The query cannot be answered with the provided tools. The query lacks the parameters required by the function. Please provide the parameters, and I'll be happy to assist."
|
| 169 |
```
|
| 170 |
|
|
|
|
|
|
|
| 171 |
## Training procedure
|
| 172 |
|
| 173 |
[<img src="https://raw.githubusercontent.com/wandb/assets/main/wandb-github-badge-28.svg" alt="Visualize in Weights & Biases" width="150" height="24"/>](https://wandb.ai/plaguss/huggingface/runs/dw9q43g4)
|
|
|
|
| 21 |
|
| 22 |
## Quick start
|
| 23 |
|
| 24 |
+
See different examples of using the model:
|
| 25 |
+
|
| 26 |
+
<details><summary> Click to see the <code>prepare_messages</code> function </summary>
|
| 27 |
|
| 28 |
````python
|
| 29 |
+
from typing import Optional
|
| 30 |
import json
|
| 31 |
|
|
|
|
|
|
|
| 32 |
from jinja2 import Template
|
| 33 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
SYSTEM_PROMPT = """
|
| 35 |
You are an expert in composing functions. You are given a question and a set of possible functions.
|
| 36 |
Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
|
|
|
|
| 55 |
{{ query }}
|
| 56 |
""".lstrip())
|
| 57 |
|
| 58 |
+
def prepare_messages(
    query: str,
    tools: Optional[list[dict[str, object]]] = None,
    conversation_history: Optional[list[dict[str, str]]] = None,
) -> list[dict[str, str]]:
    """Prepare the chat messages for the given query and tools.

    Args:
        query: The query to be answered.
        tools: The tool definitions available to the model. Defaults to None,
            in which case an empty list is passed to the model.
        conversation_history: Messages exchanged so far, including the system
            and user messages of the first turn. Defaults to None, which
            starts a new conversation.

    Returns:
        The list of messages, ready for `tokenizer.apply_chat_template`.
    """
    if conversation_history:
        # Follow-up turn: keep the previous exchange untouched (a new list is
        # returned, the caller's history is not mutated) and append only the
        # new query. A history produced by an earlier call to this function
        # already carries the system prompt and the rendered tool list.
        return [*conversation_history, {"role": "user", "content": query}]

    if tools is None:
        tools = []

    # First turn: the tools are serialized to JSON and injected, together with
    # the query, into the user prompt template.
    return [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": prompt.render(tools=json.dumps(tools), query=query)},
    ]
|
| 73 |
+
|
| 74 |
+
````
|
| 75 |
+
|
| 76 |
+
</details>
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
<details><summary> Click to see the <code>parse_response</code> function </summary>
|
| 80 |
+
|
| 81 |
+
```python
|
| 82 |
+
import re
|
| 83 |
+
import json
|
| 84 |
+
|
| 85 |
+
def parse_response(text: str) -> str | dict[str, any]:
|
| 86 |
+
"""Parses a response from the model, returning either the
|
| 87 |
+
parsed list with the tool calls parsed, or the
|
| 88 |
+
model thought or response if couldn't generate one.
|
| 89 |
+
|
| 90 |
+
Args:
|
| 91 |
+
text: Response from the model.
|
| 92 |
+
"""
|
| 93 |
+
pattern = r"<tool_call>(.*?)</tool_call>"
|
| 94 |
+
matches = re.findall(pattern, text, re.DOTALL)
|
| 95 |
+
if matches:
|
| 96 |
+
return json.loads(matches[0])
|
| 97 |
+
return text
|
| 98 |
+
|
| 99 |
+
```
|
| 100 |
+
|
| 101 |
+
</details>
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
Example of *simple* function call:
|
| 105 |
+
|
| 106 |
+
````python
|
| 107 |
+
import torch
|
| 108 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 109 |
+
|
| 110 |
+
model_name = "plaguss/Llama-3.2-1B-Instruct-APIGen-FC-v0.1"
|
| 111 |
+
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype="auto", trust_remote_code=True)
|
| 112 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 113 |
+
|
| 114 |
+
|
| 115 |
get_weather_api = {
|
| 116 |
"name": "get_weather",
|
| 117 |
"description": "Get the current weather for a location",
|
|
|
|
| 147 |
}
|
| 148 |
}
|
| 149 |
|
| 150 |
+
available_tools = [get_weather_api, search_api]
|
| 151 |
|
| 152 |
query = "What's the weather like in New York in fahrenheit?"
|
| 153 |
|
| 154 |
+
messages = prepare_messages(query, tools=available_tools)
|
| 155 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 156 |
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(model.device)
|
| 157 |
|
| 158 |
outputs = model.generate(inputs, max_new_tokens=512, do_sample=False, num_return_sequences=1, eos_token_id=tokenizer.eos_token_id)
|
| 159 |
result = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=False)
|
| 160 |
|
| 161 |
+
response = parse_response(result)
|
|
|
|
|
|
|
|
|
|
| 162 |
# [{'name': 'get_weather', 'arguments': {'location': 'New York', 'unit': 'fahrenheit'}}]
|
| 163 |
````
|
| 164 |
|
| 165 |
+
<details><summary> Click to see an example of a <em>parallel</em> function call: </summary>
|
| 166 |
+
|
| 167 |
+
```python
|
| 168 |
+
available_tools = [{"name": "spotify.play", "description": "Play specific tracks from a given artist for a specific time duration.", "parameters": {"type": "dict", "properties": {"artist": {"type": "string", "description": "The artist whose songs you want to play."}, "duration": {"type": "integer", "description": "The duration for which the songs should be played, in minutes."}}, "required": ["artist", "duration"]}}]
|
| 169 |
+
query = "Play songs from the artists Taylor Swift and Maroon 5, with a play time of 20 minutes and 15 minutes respectively, on Spotify."
|
| 170 |
+
|
| 171 |
+
messages = prepare_messages(query, tools=available_tools)
|
| 172 |
+
|
| 173 |
+
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(model.device)
|
| 174 |
+
outputs = model.generate(inputs, max_new_tokens=512, do_sample=False, num_return_sequences=1, eos_token_id=tokenizer.eos_token_id)
|
| 175 |
+
result = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=False)
|
| 176 |
+
|
| 177 |
+
response = parse_response(result)
|
| 178 |
+
# [{'name': 'spotify.play', 'arguments': {'artist': 'Taylor Swift', 'duration': 20}}, {'name': 'spotify.play', 'arguments': {'artist': 'Maroon 5', 'duration': 15}}]
|
| 179 |
+
```
|
| 180 |
+
|
| 181 |
+
</details>
|
| 182 |
+
|
| 183 |
+
<details><summary> Click to see an example of a <em>multiple</em> function call: </summary>
|
| 184 |
+
|
| 185 |
+
```python
|
| 186 |
+
available_tools = [{"name": "country_info.largest_city", "description": "Fetch the largest city of a specified country.", "parameters": {"type": "dict", "properties": {"country": {"type": "string", "description": "Name of the country."}}, "required": ["country"]}}, {"name": "country_info.capital", "description": "Fetch the capital city of a specified country.", "parameters": {"type": "dict", "properties": {"country": {"type": "string", "description": "Name of the country."}}, "required": ["country"]}}, {"name": "country_info.population", "description": "Fetch the current population of a specified country.", "parameters": {"type": "dict", "properties": {"country": {"type": "string", "description": "Name of the country."}}, "required": ["country"]}}]
|
| 187 |
+
query = "What is the capital of Brazil?"
|
| 188 |
+
|
| 189 |
+
messages = prepare_messages(query, tools=available_tools)
|
| 190 |
+
|
| 191 |
+
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(model.device)
|
| 192 |
+
outputs = model.generate(inputs, max_new_tokens=512, do_sample=False, num_return_sequences=1, eos_token_id=tokenizer.eos_token_id)
|
| 193 |
+
result = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=False)
|
| 194 |
+
|
| 195 |
+
response = parse_response(result)
|
| 196 |
+
# [{'name': 'country_info.capital', 'arguments': {'country': 'Brazil'}}]
|
| 197 |
+
```
|
| 198 |
+
|
| 199 |
+
</details>
|
| 200 |
+
|
| 201 |
+
<details><summary> Click to see an example of a <em>parallel multiple</em> function call: </summary>
|
| 202 |
+
|
| 203 |
+
```python
|
| 204 |
+
available_tools = [{"name": "math_toolkit.sum_of_multiples", "description": "Find the sum of all multiples of specified numbers within a specified range.", "parameters": {"type": "dict", "properties": {"lower_limit": {"type": "integer", "description": "The start of the range (inclusive)."}, "upper_limit": {"type": "integer", "description": "The end of the range (inclusive)."}, "multiples": {"type": "array", "items": {"type": "integer"}, "description": "The numbers to find multiples of."}}, "required": ["lower_limit", "upper_limit", "multiples"]}}, {"name": "math_toolkit.product_of_primes", "description": "Find the product of the first n prime numbers.", "parameters": {"type": "dict", "properties": {"count": {"type": "integer", "description": "The number of prime numbers to multiply together."}}, "required": ["count"]}}]
|
| 205 |
+
query = "Find the sum of all the multiples of 3 and 5 between 1 and 1000. Also find the product of the first five prime numbers."
|
| 206 |
+
|
| 207 |
+
messages = prepare_messages(query, tools=available_tools)
|
| 208 |
+
|
| 209 |
+
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(model.device)
|
| 210 |
+
outputs = model.generate(inputs, max_new_tokens=512, do_sample=False, num_return_sequences=1, eos_token_id=tokenizer.eos_token_id)
|
| 211 |
+
result = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=False)
|
| 212 |
+
|
| 213 |
+
response = parse_response(result)
|
| 214 |
+
# [{'name': 'math_toolkit.sum_of_multiples', 'arguments': {'lower_limit': 1, 'upper_limit': 1000, 'multiples': [3, 5]}}, {'name': 'math_toolkit.product_of_primes', 'arguments': {'count': 5}}]
|
| 215 |
+
```
|
| 216 |
+
|
| 217 |
+
</details>
|
| 218 |
+
|
| 219 |
+
<details><summary> Click to see an example of a <em>multi-turn</em> function call: </summary>
|
| 220 |
+
|
| 221 |
+
```python
|
| 222 |
+
|
| 223 |
+
get_weather_api = {
|
| 224 |
+
"name": "get_weather",
|
| 225 |
+
"description": "Get the current weather for a location",
|
| 226 |
+
"parameters": {
|
| 227 |
+
"type": "object",
|
| 228 |
+
"properties": {
|
| 229 |
+
"location": {
|
| 230 |
+
"type": "string",
|
| 231 |
+
"description": "The city and state, e.g. San Francisco, New York"
|
| 232 |
+
},
|
| 233 |
+
"unit": {
|
| 234 |
+
"type": "string",
|
| 235 |
+
"enum": ["celsius", "fahrenheit"],
|
| 236 |
+
"description": "The unit of temperature to return"
|
| 237 |
+
}
|
| 238 |
+
},
|
| 239 |
+
"required": ["location"]
|
| 240 |
+
}
|
| 241 |
+
}
|
| 242 |
+
|
| 243 |
+
available_tools = [get_weather_api]
|
| 244 |
+
|
| 245 |
+
query = "What's the weather like in Madrid in celsius?"
|
| 246 |
+
|
| 247 |
+
messages = prepare_messages(query, tools=available_tools)
|
| 248 |
+
|
| 249 |
+
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(model.device)
|
| 250 |
+
|
| 251 |
+
outputs = model.generate(inputs, max_new_tokens=512, do_sample=False, num_return_sequences=1, eos_token_id=tokenizer.eos_token_id)
|
| 252 |
+
result = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=False)
|
| 253 |
+
|
| 254 |
+
response = parse_response(result)
|
| 255 |
+
|
| 256 |
+
# 2nd turn
|
| 257 |
+
conversation_history = messages.copy()
|
| 258 |
+
conversation_history.append({"role": "assistant", "content": json.dumps(response)})
|
| 259 |
+
|
| 260 |
+
new_query = "And in Edinburgh in celsius?"
|
| 261 |
+
|
| 262 |
+
new_messages = prepare_messages(new_query, tools=available_tools, conversation_history=conversation_history)
|
| 263 |
+
|
| 264 |
+
inputs = tokenizer.apply_chat_template(new_messages, add_generation_prompt=True, return_tensors="pt").to(model.device)
|
| 265 |
+
|
| 266 |
+
outputs = model.generate(inputs, max_new_tokens=512, do_sample=False, num_return_sequences=1, eos_token_id=tokenizer.eos_token_id)
|
| 267 |
+
result = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=False)
|
| 268 |
+
|
| 269 |
+
response = parse_response(result)
|
| 270 |
+
# [{'name': 'get_weather', 'arguments': {'location': 'Edinburgh', 'unit': 'celsius'}}]
|
| 271 |
+
```
|
| 272 |
+
|
| 273 |
+
</details>
|
| 274 |
+
|
| 275 |
+
<details><summary> Click to see an example of an <em>irrelevance</em> function call: </summary>
|
| 276 |
+
|
| 277 |
Example response with no tools available
|
| 278 |
|
| 279 |
```python
|
| 280 |
+
available_tools = []
|
| 281 |
|
| 282 |
query = "What's the weather like in New York in fahrenheit?"
|
| 283 |
|
| 284 |
+
messages = prepare_messages(query, tools=available_tools)
|
| 285 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 286 |
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(model.device)
|
| 287 |
|
| 288 |
outputs = model.generate(inputs, max_new_tokens=512, do_sample=False, num_return_sequences=1, eos_token_id=tokenizer.eos_token_id)
|
| 289 |
result = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True)
|
| 290 |
+
response = parse_response(result)
|
| 291 |
# 'The query cannot be answered, no tools were provided.'
|
| 292 |
```
|
| 293 |
|
|
|
|
| 307 |
}
|
| 308 |
}
|
| 309 |
|
| 310 |
+
available_tools = [cut_number]
|
| 311 |
|
| 312 |
query = "What's the weather like in New York in fahrenheit?"
|
| 313 |
|
| 314 |
+
messages = prepare_messages(query, tools=available_tools)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 315 |
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(model.device)
|
| 316 |
|
| 317 |
outputs = model.generate(inputs, max_new_tokens=512, do_sample=False, num_return_sequences=1, eos_token_id=tokenizer.eos_token_id)
|
| 318 |
result = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True)
|
| 319 |
+
response = parse_response(result)
|
| 320 |
# "The query cannot be answered with the provided tools. The query lacks the parameters required by the function. Please provide the parameters, and I'll be happy to assist."
|
| 321 |
```
|
| 322 |
|
| 323 |
+
</details>
|
| 324 |
+
|
| 325 |
## Training procedure
|
| 326 |
|
| 327 |
[<img src="https://raw.githubusercontent.com/wandb/assets/main/wandb-github-badge-28.svg" alt="Visualize in Weights & Biases" width="150" height="24"/>](https://wandb.ai/plaguss/huggingface/runs/dw9q43g4)
|