gyung commited on
Commit
ce3d00b
·
verified ·
1 Parent(s): 06bafc4

Upload demo_linux_fc.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. demo_linux_fc.py +594 -0
demo_linux_fc.py ADDED
@@ -0,0 +1,594 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env python3
"""
HybriKo-117M Linux Function Calling Demo
Usage: python scripts/demo_linux_fc.py

Example:
[사용자] 현재 폴더의 파일 목록을 보여줘
[HybriKo] Thought: 디렉토리 내용을 확인합니다.
Action: ls_command
Action Input: {"path": ".", "options": "-l"}

(Example translation: the user asks "show me the file list of the current
folder"; the model's Thought is "I will check the directory contents.")
"""

import torch
import sentencepiece as spm
import sys
import json
import re
import argparse

# Make the local `hybridko` package importable when the script is launched
# from the repository root (e.g. `python scripts/demo_linux_fc.py`).
sys.path.insert(0, ".")
from hybridko.model import HybriKoModel, HybriKoConfig
24
# Exact system prompt used during training (with prettified JSON)
# NOTE(review): this literal must stay byte-identical to the training prompt;
# do not reformat it. The embedded JSON appears un-indented here — if the
# original training prompt had indentation it may have been lost in transfer;
# confirm against the training data before relying on exact-match behavior.
# NOTE(review): "Action" below lacks a colon in the format spec — left as-is
# since it presumably mirrors the training prompt; confirm.
SYSTEM_PROMPT = """You are a Linux command assistant. You can use many tools (functions) to help users with their Linux tasks.
At each step, you need to give your thought to analyze the status now and what to do next, with a function call to actually execute your step. Your output should follow this format:
Thought:
Action
Action Input:

After the call, you will get the call result, and you are now in a new state.
Then you will analyze your status now, then decide what to do next...
After many (Thought-call) pairs, you finally perform the task, then you can give your final answer.

Remember:
1. The state change is irreversible, you can't go back to one of the former state.
2. All the thought is short, at most in 5 sentences.
3. ALWAYS call "Finish" function at the end of the task.
4. If you cannot handle the task with the available tools, say you don't know and call Finish with give_answer.

You have access of the following tools:
[
{
"name": "ls_command",
"description": "List directory contents.",
"parameters": {
"type": "object",
"properties": {
"path": {
"type": "string"
},
"options": {
"type": "string"
}
},
"required": [
"path"
]
}
},
{
"name": "cd_command",
"description": "Change the current working directory.",
"parameters": {
"type": "object",
"properties": {
"path": {
"type": "string"
}
},
"required": [
"path"
]
}
},
{
"name": "mkdir_command",
"description": "Create a new directory.",
"parameters": {
"type": "object",
"properties": {
"path": {
"type": "string"
}
},
"required": [
"path"
]
}
},
{
"name": "rm_command",
"description": "Remove files or directories.",
"parameters": {
"type": "object",
"properties": {
"path": {
"type": "string"
},
"recursive": {
"type": "boolean"
}
},
"required": [
"path"
]
}
},
{
"name": "cp_command",
"description": "Copy files or directories.",
"parameters": {
"type": "object",
"properties": {
"source": {
"type": "string"
},
"destination": {
"type": "string"
}
},
"required": [
"source",
"destination"
]
}
},
{
"name": "mv_command",
"description": "Move or rename files.",
"parameters": {
"type": "object",
"properties": {
"source": {
"type": "string"
},
"destination": {
"type": "string"
}
},
"required": [
"source",
"destination"
]
}
},
{
"name": "find_command",
"description": "Find files by name pattern.",
"parameters": {
"type": "object",
"properties": {
"path": {
"type": "string"
},
"name": {
"type": "string"
}
},
"required": [
"path",
"name"
]
}
},
{
"name": "cat_command",
"description": "Display file contents.",
"parameters": {
"type": "object",
"properties": {
"file": {
"type": "string"
}
},
"required": [
"file"
]
}
},
{
"name": "grep_command",
"description": "Search for patterns in files.",
"parameters": {
"type": "object",
"properties": {
"pattern": {
"type": "string"
},
"file": {
"type": "string"
}
},
"required": [
"pattern",
"file"
]
}
},
{
"name": "head_command",
"description": "Display first lines of a file.",
"parameters": {
"type": "object",
"properties": {
"file": {
"type": "string"
},
"lines": {
"type": "integer"
}
},
"required": [
"file"
]
}
},
{
"name": "tail_command",
"description": "Display last lines of a file.",
"parameters": {
"type": "object",
"properties": {
"file": {
"type": "string"
},
"lines": {
"type": "integer"
}
},
"required": [
"file"
]
}
},
{
"name": "wc_command",
"description": "Count lines, words, and bytes.",
"parameters": {
"type": "object",
"properties": {
"file": {
"type": "string"
}
},
"required": [
"file"
]
}
},
{
"name": "ps_command",
"description": "Display running processes.",
"parameters": {
"type": "object",
"properties": {
"options": {
"type": "string"
}
},
"required": []
}
},
{
"name": "df_command",
"description": "Display disk space usage.",
"parameters": {
"type": "object",
"properties": {
"options": {
"type": "string"
}
},
"required": []
}
},
{
"name": "du_command",
"description": "Display directory space usage.",
"parameters": {
"type": "object",
"properties": {
"path": {
"type": "string"
},
"options": {
"type": "string"
}
},
"required": [
"path"
]
}
},
{
"name": "top_command",
"description": "Display system processes in real-time.",
"parameters": {
"type": "object",
"properties": {},
"required": []
}
},
{
"name": "ping_command",
"description": "Test network connectivity.",
"parameters": {
"type": "object",
"properties": {
"host": {
"type": "string"
},
"count": {
"type": "integer"
}
},
"required": [
"host"
]
}
},
{
"name": "curl_command",
"description": "Transfer data from URL.",
"parameters": {
"type": "object",
"properties": {
"url": {
"type": "string"
},
"options": {
"type": "string"
}
},
"required": [
"url"
]
}
},
{
"name": "chmod_command",
"description": "Change file permissions.",
"parameters": {
"type": "object",
"properties": {
"mode": {
"type": "string"
},
"file": {
"type": "string"
}
},
"required": [
"mode",
"file"
]
}
},
{
"name": "tar_command",
"description": "Archive or extract files.",
"parameters": {
"type": "object",
"properties": {
"options": {
"type": "string"
},
"archive": {
"type": "string"
},
"files": {
"type": "string"
}
},
"required": [
"options",
"archive"
]
}
},
{
"name": "Finish",
"description": "Complete the task.",
"parameters": {
"type": "object",
"properties": {
"give_answer": {
"type": "string"
}
},
"required": [
"give_answer"
]
}
}
]"""
397
+
398
+
399
def load_model(checkpoint_path="checkpoints/linux_fc_sft/checkpoint_epoch_15.pt"):
    """Load the SentencePiece tokenizer and a HybriKo checkpoint.

    Returns a ``(model, tokenizer, device)`` triple with the model in eval
    mode, placed on CUDA when available and on CPU otherwise.
    """
    print("Loading tokenizer...")
    tokenizer = spm.SentencePieceProcessor()
    tokenizer.Load("tokenizer/HybriKo_tok.model")

    print("Loading model...")
    cfg = HybriKoConfig(
        d_model=768,
        n_layers=12,
        vocab_size=32000,
        n_heads=12,
        n_kv_heads=3,
        ff_mult=3,
        max_seq_len=6144,
        dropout=0.0,
    )
    model = HybriKoModel(cfg)
    # NOTE(review): weights_only=False lets torch.load unpickle arbitrary
    # objects — only ever point this at trusted, locally produced checkpoints.
    state = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
    model.load_state_dict(state["model_state_dict"])

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device).eval()
    print(f"Model loaded on {device}\n")
    return model, tokenizer, device
419
+
420
+
421
def generate(model, tokenizer, prompt, device, max_new_tokens=150):
    """Greedy-decode a completion for ``prompt``, stopping early when possible.

    Stops on the ``<|im_end|>`` EOS token (when the tokenizer knows it) or as
    soon as a complete ``Action Input: {...}`` JSON object has been produced,
    in which case the text is truncated right after the closing brace.

    Returns the decoded completion (prompt excluded) as a string.
    """
    input_ids = tokenizer.EncodeAsIds(prompt)
    input_tensor = torch.tensor([input_ids], dtype=torch.long, device=device)
    prompt_len = len(input_ids)

    # Resolve the EOS id; SentencePiece maps unknown pieces to unk, which
    # means the token is not in the vocabulary and cannot be used to stop.
    eos_id = tokenizer.PieceToId("<|im_end|>")
    if eos_id == tokenizer.unk_id():
        eos_id = None

    marker = "Action Input:"
    with torch.no_grad():
        generated = input_tensor
        for _ in range(max_new_tokens):
            outputs = model(generated)
            logits = outputs["logits"] if isinstance(outputs, dict) else outputs.logits
            next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
            generated = torch.cat([generated, next_token], dim=1)

            # Stop on EOS. Compare with `is not None` (not truthiness) so a
            # legitimate EOS id of 0 would still trigger the stop.
            if eos_id is not None and next_token.item() == eos_id:
                break

            new_text = tokenizer.DecodeIds(generated[0, prompt_len:].tolist())

            # Stop once a complete JSON object follows "Action Input:".
            ai_idx = new_text.find(marker)
            if ai_idx == -1:
                continue
            # Work on the unstripped tail so offsets stay aligned with
            # new_text (the original strip()ed first, skewing the cut point).
            tail = new_text[ai_idx + len(marker):]
            brace_start = tail.find("{")
            if brace_start == -1:
                continue
            depth = 0
            for i in range(brace_start, len(tail)):
                c = tail[i]
                if c == "{":
                    depth += 1
                elif c == "}":
                    depth -= 1
                    if depth == 0:
                        # BUG FIX: the original sliced the decoded text at
                        # break_idx + ai_idx + 13 (its length arithmetic
                        # simplified to that), returning trailing garbage.
                        # Truncate exactly after the closing brace instead.
                        return new_text[:ai_idx + len(marker) + i + 1]

    return tokenizer.DecodeIds(generated[0, prompt_len:].tolist())
471
+
472
+
473
def create_prompt(user_input):
    """Wrap the user turn in the ChatML template used at training time."""
    segments = [
        "<|im_start|>system\n", SYSTEM_PROMPT, "<|im_end|>\n",
        "<|im_start|>user\n", user_input, "<|im_end|>\n",
        "<|im_start|>assistant\n",
    ]
    return "".join(segments)
476
+
477
+
478
def parse_response(response):
    """Parse a model response into its Thought / Action / Action Input parts.

    The response is first truncated at the earliest chat-control token
    (``<|im_end|>`` or a spurious ``<|im_start|>`` from the model rambling
    on). Returns a dict with keys ``thought``, ``action``, ``action_input``
    (parsed dict when the JSON is valid, otherwise the raw matched string)
    and ``raw`` (the truncated response text).
    """
    # Clean up the response - stop at <|im_end|> or garbage
    if "<|im_end|>" in response:
        response = response.split("<|im_end|>")[0]

    # Also stop at <|im_start|> which indicates model continuing incorrectly
    if "<|im_start|>" in response:
        response = response.split("<|im_start|>")[0]

    result = {"thought": None, "action": None, "action_input": None, "raw": response}

    # Extract Thought (everything up to the next "Action:" or end of text)
    thought_match = re.search(r"Thought:\s*(.+?)(?=\s*Action:|\s*$)", response, re.DOTALL)
    if thought_match:
        result["thought"] = thought_match.group(1).strip()

    # Extract Action (a single identifier, e.g. "ls_command")
    action_match = re.search(r"Action:\s*(\w+)", response)
    if action_match:
        result["action"] = action_match.group(1)

    # Extract Action Input. NOTE: the pattern only matches a flat (non-nested)
    # JSON object, which covers every tool schema in SYSTEM_PROMPT.
    input_match = re.search(r"Action Input:\s*(\{[^}]+\})", response, re.DOTALL)
    if input_match:
        try:
            result["action_input"] = json.loads(input_match.group(1))
        # BUG FIX: was a bare `except:` which also swallowed KeyboardInterrupt
        # and SystemExit; only invalid JSON should fall back to the raw string.
        except json.JSONDecodeError:
            result["action_input"] = input_match.group(1)

    return result
509
+
510
+
511
def run_single(model, tokenizer, device, user_input):
    """One-shot inference: build the ChatML prompt, decode, parse the result."""
    return parse_response(
        generate(model, tokenizer, create_prompt(user_input), device)
    )
516
+
517
+
518
def _print_result(result):
    """Print whichever of the parsed thought/action/input fields are present."""
    if result["thought"]:
        print(f"Thought: {result['thought']}")
    if result["action"]:
        print(f"Action: {result['action']}")
    if result["action_input"]:
        print(f"Input: {json.dumps(result['action_input'], ensure_ascii=False)}")


def main():
    """CLI entry point: single-query mode via --query, otherwise an interactive REPL."""
    import io

    # Re-wrap stdin/stdout as UTF-8 so Korean input/output works regardless of
    # the platform default. BUG FIX: the original compared against the exact
    # lowercase 'utf-8', so platforms reporting 'UTF-8' were needlessly
    # re-wrapped; the encoding can also be None on redirected streams, which
    # the `or ""` guards against.
    if (sys.stdin.encoding or "").lower() != "utf-8":
        sys.stdin = io.TextIOWrapper(sys.stdin.buffer, encoding="utf-8", errors="replace")
    if (sys.stdout.encoding or "").lower() != "utf-8":
        sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8", errors="replace")

    parser = argparse.ArgumentParser(description="HybriKo Linux FC Demo")
    parser.add_argument("--checkpoint", default="checkpoints/linux_fc_sft/checkpoint_epoch_15.pt")
    parser.add_argument("--query", type=str, help="Single query mode (non-interactive)")
    args = parser.parse_args()

    print("=" * 60)
    print(" HybriKo-117M Linux Function Calling Demo")
    print("=" * 60)

    model, tokenizer, device = load_model(args.checkpoint)

    # Single query mode
    if args.query:
        result = run_single(model, tokenizer, device, args.query)
        print(f"Input: {args.query}")
        print("-" * 40)
        _print_result(result)
        return

    # Interactive mode
    print("Supported commands:")
    print(" ls, cd, mkdir, rm, cp, mv, find, cat, grep, head,")
    print(" tail, wc, ps, df, du, top, ping, curl, chmod, tar")
    print("\nType 'quit' or 'exit' to exit")
    print("=" * 60)

    while True:
        try:
            print("\n[User] ", end="", flush=True)
            user_input = sys.stdin.readline()
            if not user_input:  # EOF
                break
            user_input = user_input.strip()
            if not user_input:
                continue
            if user_input.lower() in ["quit", "exit", "q"]:
                print("Goodbye!")
                break

            result = run_single(model, tokenizer, device, user_input)

            print("\n[HybriKo]")
            print("-" * 40)
            _print_result(result)
            print("-" * 40)

        except KeyboardInterrupt:
            print("\nGoodbye!")
            break
        except UnicodeDecodeError as e:
            print(f"\n[Error] Encoding issue: {e}")
            print('Try using --query option instead: python scripts/demo_linux_fc.py --query "your query"')
            continue


if __name__ == "__main__":
    main()