suyash94 committed
Commit d46efc1 · verified · 1 Parent(s): 3050708

Upload folder using huggingface_hub
.gradio/certificate.pem ADDED
@@ -0,0 +1,31 @@
+ -----BEGIN CERTIFICATE-----
+ MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
+ TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
+ cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
+ WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
+ ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
+ MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
+ h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
+ 0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
+ A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
+ T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
+ B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
+ B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
+ KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
+ OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
+ jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
+ qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
+ rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
+ HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
+ hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
+ ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
+ 3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
+ NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
+ ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
+ TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
+ jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
+ oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
+ 4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
+ mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
+ emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
+ -----END CERTIFICATE-----
README.md CHANGED
@@ -1,12 +1,80 @@
  ---
- title: Slm Function Calling
- emoji: 🐠
- colorFrom: pink
- colorTo: red
  sdk: gradio
- sdk_version: 6.2.0
  app_file: app.py
- pinned: false
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
  ---
+ title: SLM Function Calling
+ emoji: "\U0001F697"
+ colorFrom: blue
+ colorTo: purple
  sdk: gradio
+ sdk_version: "4.44.0"
  app_file: app.py
+ license: mit
+ tags:
+ - function-calling
+ - gpt2
+ - lora
+ - car-control
+ - nlp
  ---

+ # SLM Function Calling - Car Control Demo
+
+ Convert natural language commands into structured function calls using a fine-tuned GPT-2 model with LoRA.
+
+ ## Demo
+
+ Try commands like:
+ - "Set the temperature to 22 degrees for the driver"
+ - "Turn up the heat"
+ - "Navigate to Central Park"
+ - "Play jazz music at volume 7"
+ - "Lock all the doors"
+
+ ## Model Details
+
+ | Property | Value |
+ |----------|-------|
+ | Base Model | GPT-2 (124M parameters) |
+ | Fine-tuning | LoRA (rank=32, alpha=32) |
+ | Training Data | ~156K car control command samples |
+ | Functions | 18 car control functions |
+ | Input | Natural language command |
+ | Output | Structured function call JSON |
+
+ ## Function Categories
+
+ The model supports 18 functions across these categories:
+
+ - **Climate Control:** set_temperature, adjust_temperature, set_fan_speed, adjust_fan_speed
+ - **Comfort:** adjust_seat, control_window
+ - **Wipers & Defroster:** set_wiper_speed, adjust_wiper_speed, activate_defroster
+ - **Engine & Security:** start_engine, lock_doors
+ - **Entertainment:** play_music
+ - **Navigation:** set_navigation_destination
+ - **Lighting:** toggle_headlights, control_ambient_lighting
+ - **Driving:** set_cruise_control, toggle_sport_mode
+ - **Maintenance:** check_battery_health
+
+ ## Output Format
+
+ The model generates function calls in this format:
+
+ ```
+ <functioncall> {"name": "function_name", "arguments": "{'param': value}"} <|im_end|>
+ ```
+
+ ## Training
+
+ The model was trained using:
+ - ChatML format with system/user/assistant messages
+ - Label masking (loss computed only on assistant response)
+ - LoRA adapters targeting attention layers (c_attn, c_proj)
+
+ ## Limitations
+
+ - Only trained on car control domain
+ - May produce incorrect outputs for ambiguous or out-of-domain queries
+ - Best results with clear, specific commands
+
+ ## Links
+
+ - [GitHub Repository](https://github.com/suyash94/slm-function-calling)
+ - [Model Weights on HuggingFace](https://huggingface.co/suyash94/gpt2-fc-adapter)
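The output format documented above can be decoded with a few lines of standard-library Python. The sketch below mirrors the parsing the Space's `inferencer.py` performs; `parse_call` is an illustrative helper, not part of this commit:

```python
import json

def parse_call(response: str) -> dict:
    """Extract and decode the payload between the two delimiters."""
    start = response.find("<functioncall>")
    end = response.find("<|im_end|>")
    if start == -1 or end == -1:
        return {"error": "delimiters not found"}
    payload = json.loads(response[start + len("<functioncall>"):end].strip())
    args = payload.get("arguments", "{}")
    if isinstance(args, str):
        # arguments are stored as a Python-style dict string in the training data
        args = json.loads(args.replace("'", '"'))
    return {"fn_name": payload["name"], "properties": args}

raw = "<functioncall> {\"name\": \"set_temperature\", \"arguments\": \"{'temperature': 22}\"} <|im_end|>"
print(parse_call(raw))  # {'fn_name': 'set_temperature', 'properties': {'temperature': 22}}
```

The quote-swap on the arguments string works for the simple values shown in the README, but would mangle values that themselves contain apostrophes.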
__pycache__/inferencer.cpython-310.pyc ADDED
Binary file (6.49 kB)
app.py ADDED
@@ -0,0 +1,256 @@
+ """SLM Function Calling - Gradio App for HuggingFace Spaces."""
+
+ import os
+
+ import gradio as gr
+ from inferencer import Inferencer
+
+ # Configuration
+ REPO_ID = os.environ.get("HF_MODEL_REPO", "suyash94/gpt2-fc-adapter")
+ LOCAL_DIR = os.environ.get("LOCAL_CHECKPOINT_DIR", None)
+ BASE_MODEL = os.environ.get("BASE_MODEL", "gpt2")
+
+ # Initialize inferencer (loads model on startup)
+ print("=" * 60)
+ print("SLM Function Calling - Car Control Demo")
+ print("=" * 60)
+ if LOCAL_DIR:
+     print(f"Loading from local: {LOCAL_DIR}")
+     inferencer = Inferencer(local_dir=LOCAL_DIR, base_model=BASE_MODEL)
+ else:
+     print(f"Loading from HuggingFace Hub: {REPO_ID}")
+     inferencer = Inferencer(repo_id=REPO_ID, base_model=BASE_MODEL)
+ print("Model ready!")
+
+
+ def predict_function_call(command: str) -> tuple[str, dict]:
+     """Predict function call from user command.
+
+     :param command: User's natural language command
+     :return: Tuple of (raw response, parsed function call dict)
+     """
+     if not command or not command.strip():
+         return "", {"info": "Please enter a command"}
+
+     result = inferencer.predict(command.strip())
+
+     raw_response = result["response"]
+     parsed = result["parsed"]
+
+     return raw_response, parsed
+
+
+ # Example commands covering all 18 functions
+ EXAMPLE_COMMANDS = [
+     # Climate Control
+     ["Set the temperature to 22 degrees for the driver"],
+     ["Turn up the heat"],
+     ["Set the fan to high"],
+     # Comfort
+     ["Move my seat forward"],
+     ["Close all the windows"],
+     # Wipers & Defroster
+     ["Set wipers to medium"],
+     ["Speed up the wipers"],
+     ["Turn on the defroster for 15 minutes"],
+     # Engine & Doors
+     ["Start the engine"],
+     ["Lock all the doors"],
+     # Entertainment
+     ["Play some jazz music at volume 7"],
+     # Navigation
+     ["Navigate to Central Park, New York"],
+     # Lighting
+     ["Turn on the headlights"],
+     ["Set ambient lighting to blue at intensity 8"],
+     # Driving
+     ["Set cruise control to 80"],
+     ["Activate sport mode"],
+     # Maintenance
+     ["Check the battery health with history"],
+ ]
+
+
+ # Main description (left column)
+ LEFT_DESCRIPTION = """
+ # SLM Function Calling
+
+ A **small language model (GPT-2, 124M params)** fine-tuned to convert natural language commands into structured function calls for car control.
+
+ ## What This Model Does
+
+ Given a command like *"Set the temperature to 22 degrees"*, outputs:
+
+ ```json
+ {"fn_name": "set_temperature", "properties": {"temperature": 22}}
+ ```
+
+ ## Available Functions (18 Total)
+
+ | Category | Functions |
+ |----------|-----------|
+ | **Climate** | `set_temperature`, `adjust_temperature`, `set_fan_speed`, `adjust_fan_speed` |
+ | **Comfort** | `adjust_seat`, `control_window` |
+ | **Wipers** | `set_wiper_speed`, `adjust_wiper_speed`, `activate_defroster` |
+ | **Engine** | `start_engine`, `lock_doors` |
+ | **Media** | `play_music` |
+ | **Nav** | `set_navigation_destination` |
+ | **Lights** | `toggle_headlights`, `control_ambient_lighting` |
+ | **Driving** | `set_cruise_control`, `toggle_sport_mode` |
+ | **Maintenance** | `check_battery_health` |
+
+ ## Example Commands
+
+ - *"Set the temperature to 25 degrees"*
+ - *"Turn up the heat"* / *"Make it cooler"*
+ - *"Move my seat forward"*
+ - *"Open all windows"*
+ - *"Turn on the wipers"*
+ - *"Start the car remotely"*
+ - *"Lock the doors"*
+ - *"Play jazz at volume 7"*
+ - *"Navigate to Central Park"*
+ - *"Turn on headlights"*
+ - *"Set ambient lighting to blue"*
+ - *"Set cruise control to 80"*
+ - *"Activate sport mode"*
+ - *"Check battery health"*
+ """
+
+ FUNCTION_REFERENCE = """
+ ## Function Reference
+
+ ### Climate Control
+
+ | Function | Description | Required Parameters | Optional Parameters |
+ |----------|-------------|---------------------|---------------------|
+ | `set_temperature` | Set temperature in a zone | `temperature` (1-80) | `area` (driver/front-passenger/rear-right/rear-left), `unit` (Celsius/Fahrenheit) |
+ | `adjust_temperature` | Increase/decrease temperature | `action` (increase/decrease) | `area` |
+ | `set_fan_speed` | Set fan to specific level | `speed` (LOW/MEDIUM/HIGH) | `area` |
+ | `adjust_fan_speed` | Increase/decrease fan speed | `speed` (increase/decrease) | `area` |
+
+ ### Comfort
+
+ | Function | Description | Required Parameters | Optional Parameters |
+ |----------|-------------|---------------------|---------------------|
+ | `adjust_seat` | Adjust seat position | `position` (forward/backward/up/down/tilt-forward/tilt-backward) | `seat_type` (driver/front-passenger/rear_right/rear_left) |
+ | `control_window` | Open/close windows | `window_position` (open/close) | `window_location` (driver/front-passenger/rear_right/rear_left) |
+
+ ### Wipers & Defroster
+
+ | Function | Description | Required Parameters | Optional Parameters |
+ |----------|-------------|---------------------|---------------------|
+ | `set_wiper_speed` | Set wiper speed | `speed` (LOW/MEDIUM/HIGH) | - |
+ | `adjust_wiper_speed` | Increase/decrease wipers | `speed` (INCREASE/DECREASE) | - |
+ | `activate_defroster` | Activate window defroster | - | `defroster_zone` (front/rear/all), `duration_minutes` (1-30) |
+
+ ### Engine & Security
+
+ | Function | Description | Required Parameters | Optional Parameters |
+ |----------|-------------|---------------------|---------------------|
+ | `start_engine` | Start the car's engine | - | `method` (remote/keyless/keyed) |
+ | `lock_doors` | Lock/unlock car doors | `lock_state` (lock/unlock) | - |
+
+ ### Entertainment
+
+ | Function | Description | Required Parameters | Optional Parameters |
+ |----------|-------------|---------------------|---------------------|
+ | `play_music` | Control music player | - | `track` (song name), `volume` (1-10) |
+
+ ### Navigation
+
+ | Function | Description | Required Parameters | Optional Parameters |
+ |----------|-------------|---------------------|---------------------|
+ | `set_navigation_destination` | Set GPS destination | `destination` (address/location) | - |
+
+ ### Lighting
+
+ | Function | Description | Required Parameters | Optional Parameters |
+ |----------|-------------|---------------------|---------------------|
+ | `toggle_headlights` | Turn headlights on/off | `light_state` (on/off) | - |
+ | `control_ambient_lighting` | Set interior lighting | `color` (warm/red/blue/dark/white) | `intensity` (1-10) |
+
+ ### Driving
+
+ | Function | Description | Required Parameters | Optional Parameters |
+ |----------|-------------|---------------------|---------------------|
+ | `set_cruise_control` | Set cruise control speed | `speed` (10-150 km/h) | - |
+ | `toggle_sport_mode` | Activate/deactivate sport mode | `action` (activate/deactivate) | - |
+
+ ### Maintenance
+
+ | Function | Description | Required Parameters | Optional Parameters |
+ |----------|-------------|---------------------|---------------------|
+ | `check_battery_health` | Check battery status | - | `include_history` (true/false) |
+ """
+
+
+ # Build Gradio interface
+ with gr.Blocks(
+     title="SLM Function Calling - Car Control",
+ ) as demo:
+     with gr.Row():
+         # Left column: Description
+         with gr.Column(scale=1):
+             gr.Markdown(LEFT_DESCRIPTION)
+
+         # Right column: Demo
+         with gr.Column(scale=1):
+             gr.Markdown("## Try It Out")
+             command_input = gr.Textbox(
+                 label="Your Command",
+                 placeholder="e.g., Set the temperature to 22 degrees",
+                 lines=2,
+             )
+             predict_btn = gr.Button("Predict Function Call", variant="primary")
+
+             raw_output = gr.Textbox(
+                 label="Raw Model Output",
+                 lines=3,
+                 interactive=False,
+             )
+             parsed_output = gr.JSON(
+                 label="Parsed Function Call",
+             )
+
+             gr.Examples(
+                 examples=EXAMPLE_COMMANDS,
+                 inputs=[command_input],
+                 label="Example Commands",
+             )
+
+     # Event handlers
+     predict_btn.click(
+         fn=predict_function_call,
+         inputs=[command_input],
+         outputs=[raw_output, parsed_output],
+     )
+
+     command_input.submit(
+         fn=predict_function_call,
+         inputs=[command_input],
+         outputs=[raw_output, parsed_output],
+     )
+
+     # Function reference accordion
+     gr.Markdown("---")
+
+     with gr.Accordion("Function Reference (All 18 Functions)", open=False):
+         gr.Markdown(FUNCTION_REFERENCE)
+
+     gr.Markdown(
+         """
+         ---
+         **Source Code:** [GitHub Repository](https://github.com/suyash94/slm-function-calling)
+
+         **Limitations:**
+         - Only trained on car control domain commands
+         - May produce incorrect outputs for ambiguous or out-of-domain queries
+         - Best results with clear, specific commands
+         """
+     )
+
+
+ if __name__ == "__main__":
+     # share=True creates a public URL (works for 72 hours)
+     demo.launch(share=True)
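The handler wired to both `click` and `submit` above has a simple contract: blank input short-circuits with an info payload, and any other input yields a `(raw_response, parsed)` tuple from the inferencer. A minimal sketch of that contract, using a hypothetical stub in place of the model-backed `Inferencer` (the real one is defined in `inferencer.py` below):

```python
from typing import Any


class StubInferencer:
    """Hypothetical stand-in for the model-backed Inferencer."""

    def predict(self, command: str) -> dict[str, Any]:
        return {
            "response": "<functioncall> ... <|im_end|>",
            "parsed": {"fn_name": "noop", "properties": {}},
        }


def predict_function_call(command: str, inferencer) -> tuple[str, dict]:
    # Blank input short-circuits with an info payload, as in app.py
    if not command or not command.strip():
        return "", {"info": "Please enter a command"}
    result = inferencer.predict(command.strip())
    return result["response"], result["parsed"]


print(predict_function_call("   ", StubInferencer()))
print(predict_function_call("Lock the doors", StubInferencer()))
```

Keeping the handler a pure function of its inputs like this makes it easy to test without starting the Gradio app or loading model weights.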
inferencer.py ADDED
@@ -0,0 +1,238 @@
+ """Self-contained inference for SLM Function Calling on HuggingFace Spaces."""
+
+ from __future__ import annotations
+
+ import json
+ import os
+ from pathlib import Path
+ from typing import Any
+
+ import torch
+ from huggingface_hub import snapshot_download
+ from peft import PeftModel
+ from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizer
+
+
+ # System prompt for function calling
+ SYSTEM_PROMPT = (
+     "You are a helpful assistant. You have to either provide a way to answer "
+     "user's request or answer user's query."
+ )
+
+
+ def parse_function_call(response: str) -> dict[str, Any]:
+     """Parse model output to extract function call.
+
+     Parses the model's response in the format:
+     <functioncall> {"name": "...", "arguments": "..."} <|im_end|>
+
+     :param response: Raw model output string
+     :return: Dict with 'fn_name' and 'properties' keys, or 'error' key if parsing fails
+     """
+     try:
+         # Define delimiters
+         start_delim = "<functioncall> "
+         end_delim = "<|im_end|>"
+
+         # Find the JSON portion between delimiters
+         start_idx = response.find(start_delim)
+         if start_idx == -1:
+             return {"error": "Start delimiter '<functioncall> ' not found"}
+
+         start_idx += len(start_delim)
+         end_idx = response.find(end_delim, start_idx)
+
+         if end_idx == -1:
+             return {"error": "End delimiter '<|im_end|>' not found"}
+
+         # Extract the JSON string
+         json_str = response[start_idx:end_idx].strip()
+
+         # Parse the outer JSON (contains name and arguments)
+         function_call_dict = json.loads(json_str)
+
+         # Extract function name and arguments
+         fn_name = function_call_dict.get("name")
+         if fn_name is None:
+             return {"error": "Function name not found in response"}
+
+         arguments_str = function_call_dict.get("arguments", "{}")
+
+         # Handle arguments - convert Python-style to JSON-style
+         if isinstance(arguments_str, str):
+             # Replace Python boolean/None syntax with JSON syntax
+             arguments_str = arguments_str.replace("'", '"')
+             arguments_str = arguments_str.replace("True", "true")
+             arguments_str = arguments_str.replace("False", "false")
+             arguments_str = arguments_str.replace("None", "null")
+
+             properties = json.loads(arguments_str)
+         elif isinstance(arguments_str, dict):
+             properties = arguments_str
+         else:
+             properties = {}
+
+         return {"fn_name": fn_name, "properties": properties}
+
+     except json.JSONDecodeError as e:
+         return {"error": f"JSON parsing error: {e}"}
+     except Exception as e:
+         return {"error": str(e)}
+
+
+ class Inferencer:
+     """Inference class for SLM Function Calling model.
+
+     Downloads LoRA adapter from HuggingFace Hub on initialization,
+     or loads from a local directory if specified.
+
+     Configuration via environment variables:
+     - HF_MODEL_REPO: HuggingFace Hub repo ID (e.g., 'username/gpt2-fc-adapter')
+     - LOCAL_CHECKPOINT_DIR: Local directory path (overrides HF_MODEL_REPO)
+     - BASE_MODEL: Base model name (default: 'gpt2')
+
+     Example::
+
+         # Set environment variable
+         os.environ["HF_MODEL_REPO"] = "suyash94/gpt2-fc-adapter"
+
+         inferencer = Inferencer()
+         result = inferencer.predict("Set the temperature to 22 degrees")
+         print(result["parsed"])  # {"fn_name": "set_temperature", "properties": {...}}
+     """
+
+     def __init__(
+         self,
+         repo_id: str | None = None,
+         local_dir: str | Path | None = None,
+         base_model: str | None = None,
+         device: torch.device | str | None = None,
+         cache_dir: str | None = None,
+     ) -> None:
+         """Initialize the inferencer.
+
+         :param repo_id: HuggingFace Hub repo ID for LoRA adapter
+         :param local_dir: Local directory containing adapter files
+         :param base_model: Base model name (default: gpt2)
+         :param device: Device for inference (auto-detected if None)
+         :param cache_dir: Cache directory for downloaded files
+         """
+         # Configuration from params or environment
+         self.local_dir = local_dir or os.environ.get("LOCAL_CHECKPOINT_DIR")
+         self.repo_id = repo_id or os.environ.get("HF_MODEL_REPO", "suyash94/gpt2-fc-adapter")
+         self.base_model = base_model or os.environ.get("BASE_MODEL", "gpt2")
+
+         if self.local_dir:
+             self.local_dir = Path(self.local_dir)
+
+         # Set device
+         if device is None:
+             self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         else:
+             self.device = torch.device(device) if isinstance(device, str) else device
+
+         self._model: torch.nn.Module | None = None
+         self._tokenizer: PreTrainedTokenizer | None = None
+
+         # Load model and tokenizer
+         self._load_model(cache_dir)
+
+     def _load_model(self, cache_dir: str | None = None) -> None:
+         """Load base model, tokenizer, and LoRA adapter.
+
+         :param cache_dir: Cache directory for HuggingFace downloads
+         """
+         # Get adapter path (local or download from Hub)
+         if self.local_dir:
+             print(f"Loading adapter from local: {self.local_dir}")
+             adapter_path = self.local_dir
+         else:
+             print(f"Downloading adapter from {self.repo_id}...")
+             adapter_path = Path(
+                 snapshot_download(
+                     repo_id=self.repo_id,
+                     cache_dir=cache_dir,
+                 )
+             )
+
+         # Load tokenizer from adapter (includes special tokens)
+         print("Loading tokenizer from adapter...")
+         self._tokenizer = AutoTokenizer.from_pretrained(
+             adapter_path,
+             trust_remote_code=True,
+         )
+
+         # Ensure pad token is set
+         if self._tokenizer.pad_token is None:
+             self._tokenizer.pad_token = self._tokenizer.eos_token
+
+         # Load base model
+         print(f"Loading base model: {self.base_model}...")
+         base_model = AutoModelForCausalLM.from_pretrained(
+             self.base_model,
+             torch_dtype=torch.float32,  # CPU-friendly
+             trust_remote_code=True,
+         )
+
+         # Resize embeddings if tokenizer has more tokens than model
+         if len(self._tokenizer) > base_model.get_input_embeddings().num_embeddings:
+             print(f"Resizing embeddings: {base_model.get_input_embeddings().num_embeddings} -> {len(self._tokenizer)}")
+             base_model.resize_token_embeddings(len(self._tokenizer))
+
+         # Load LoRA adapter
+         print("Loading LoRA adapter...")
+         self._model = PeftModel.from_pretrained(
+             base_model,
+             adapter_path,
+         )
+
+         # Move to device and set eval mode
+         self._model.to(self.device)
+         self._model.eval()
+
+         print(f"Model loaded on device: {self.device}")
+
+     def predict(self, user_query: str, max_new_tokens: int = 128) -> dict[str, Any]:
+         """Generate a function call prediction for a user query.
+
+         :param user_query: User's natural language command
+         :param max_new_tokens: Maximum new tokens to generate
+         :return: Dict with 'response' and 'parsed' (function call info)
+         """
+         if self._model is None or self._tokenizer is None:
+             raise RuntimeError("Model not loaded")
+
+         # Format as chat
+         messages = [
+             {"role": "system", "content": SYSTEM_PROMPT},
+             {"role": "user", "content": user_query},
+         ]
+
+         # Apply chat template
+         input_text = self._tokenizer.apply_chat_template(messages, tokenize=False)
+
+         # Tokenize
+         inputs = self._tokenizer(input_text, return_tensors="pt")
+         inputs = {k: v.to(self.device) for k, v in inputs.items()}
+
+         # Generate
+         with torch.no_grad():
+             outputs = self._model.generate(
+                 **inputs,
+                 max_new_tokens=max_new_tokens,
+                 pad_token_id=self._tokenizer.pad_token_id,
+                 eos_token_id=self._tokenizer.eos_token_id,
+                 do_sample=False,  # Deterministic
+             )
+
+         # Decode response (only the generated part)
+         full_response = self._tokenizer.decode(outputs[0], skip_special_tokens=False)
+         response = full_response[len(input_text):]
+
+         # Parse function call
+         parsed = parse_function_call(response)
+
+         return {
+             "response": response,
+             "parsed": parsed,
+         }
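One caveat with `parse_function_call` above: the blanket `replace("'", '"')` on the arguments string corrupts values that contain apostrophes (e.g. a navigation destination like "O'Hare Airport"). Since the training data stores arguments as Python-style dict strings, `ast.literal_eval` is a more robust alternative. A sketch (`normalize_arguments` is an illustrative helper, not part of this commit):

```python
import ast
import json

def normalize_arguments(arguments_str: str) -> dict:
    """Parse a Python-literal arguments string into a dict.

    Unlike quote replacement, ast.literal_eval handles apostrophes
    inside values, plus True/False/None, without string surgery.
    """
    try:
        value = ast.literal_eval(arguments_str)
        return value if isinstance(value, dict) else {}
    except (ValueError, SyntaxError):
        # Fall back to JSON in case the model emitted JSON-style literals
        try:
            return json.loads(arguments_str)
        except json.JSONDecodeError:
            return {}

print(normalize_arguments("{'destination': \"O'Hare Airport\"}"))
print(normalize_arguments("{'include_history': True}"))
```

`literal_eval` only evaluates Python literals, so it is safe to run on untrusted model output, unlike `eval`.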
requirements.txt ADDED
@@ -0,0 +1,12 @@
+ # Core ML
+ torch>=2.1.0
+ transformers>=4.46.0
+
+ # LoRA adapter loading
+ peft>=0.13.0
+
+ # Gradio UI
+ gradio>=4.0.0
+
+ # HuggingFace Hub for model download
+ huggingface_hub>=0.20.0