Stop distributed inference on EOS
Browse files
distributed/inference/agillm35_distributed_infer.py
CHANGED
|
@@ -576,11 +576,13 @@ def cmd_infer(args: argparse.Namespace) -> None:
|
|
| 576 |
prompt_tokens = runtime.tok.encode(args.prompt)
|
| 577 |
if not prompt_tokens:
|
| 578 |
prompt_tokens = [runtime.EOS]
|
| 579 |
-
ids = torch.tensor([prompt_tokens], dtype=torch.long)
|
| 580 |
-
prompt_len = ids.size(1)
|
| 581 |
-
stage_stats: list[dict[str, Any]] = []
|
| 582 |
-
session_id = args.session_id or f"agillm35-{uuid.uuid4().hex}"
|
| 583 |
-
start = time.time()
|
|
|
|
|
|
|
| 584 |
with torch.no_grad():
|
| 585 |
if args.cache_mode == "kv":
|
| 586 |
hidden = emb(ids.to(args.device)).detach().cpu()
|
|
@@ -595,12 +597,15 @@ def cmd_infer(args: argparse.Namespace) -> None:
|
|
| 595 |
)
|
| 596 |
stage_stats.extend(stats)
|
| 597 |
for step in range(int(args.max_new)):
|
| 598 |
-
h = ln(hidden.to(args.device))
|
| 599 |
-
nxt = sample_next(runtime, ar_h, h, ids, args)
|
| 600 |
-
ids = torch.cat([ids, nxt.detach().cpu()], dim=1)
|
| 601 |
-
if step + 1 >= int(args.max_new):
|
| 602 |
-
break
|
| 603 |
-
hidden = emb(nxt.to(args.device)).detach().cpu()
|
|
|
|
|
|
|
|
|
|
| 604 |
hidden, stats = run_stage_pipeline(
|
| 605 |
stages,
|
| 606 |
hidden,
|
|
@@ -616,9 +621,12 @@ def cmd_infer(args: argparse.Namespace) -> None:
|
|
| 616 |
hidden = emb(ids.to(args.device)).detach().cpu()
|
| 617 |
hidden, stats = run_stage_pipeline(stages, hidden, args, use_cache=False)
|
| 618 |
stage_stats.extend(stats)
|
| 619 |
-
h = ln(hidden.to(args.device))
|
| 620 |
-
nxt = sample_next(runtime, ar_h, h, ids, args)
|
| 621 |
-
ids = torch.cat([ids, nxt.detach().cpu()], dim=1)
|
|
|
|
|
|
|
|
|
|
| 622 |
elapsed = time.time() - start
|
| 623 |
all_ids = ids[0].tolist()
|
| 624 |
prompt = runtime.tok.decode(all_ids[:prompt_len], skip_special_tokens=True)
|
|
@@ -631,13 +639,13 @@ def cmd_infer(args: argparse.Namespace) -> None:
|
|
| 631 |
item["wall_sec"] += float(stat.get("wall_sec", stat.get("sec", 0.0)))
|
| 632 |
result = {
|
| 633 |
"event": "distributed_infer_done",
|
| 634 |
-
"mode": args.mode,
|
| 635 |
-
"cache_mode": args.cache_mode,
|
| 636 |
-
"session_id": session_id if args.cache_mode == "kv" else None,
|
| 637 |
-
"tokens":
|
| 638 |
-
"elapsed_sec": round(elapsed, 3),
|
| 639 |
-
"tok_per_sec": round(
|
| 640 |
-
"stages": by_stage,
|
| 641 |
}
|
| 642 |
if args.json:
|
| 643 |
result["prompt"] = prompt
|
|
|
|
| 576 |
prompt_tokens = runtime.tok.encode(args.prompt)
|
| 577 |
if not prompt_tokens:
|
| 578 |
prompt_tokens = [runtime.EOS]
|
| 579 |
+
ids = torch.tensor([prompt_tokens], dtype=torch.long)
|
| 580 |
+
prompt_len = ids.size(1)
|
| 581 |
+
stage_stats: list[dict[str, Any]] = []
|
| 582 |
+
session_id = args.session_id or f"agillm35-{uuid.uuid4().hex}"
|
| 583 |
+
eos_id = getattr(runtime, "EOS", None)
|
| 584 |
+
generated_tokens = 0
|
| 585 |
+
start = time.time()
|
| 586 |
with torch.no_grad():
|
| 587 |
if args.cache_mode == "kv":
|
| 588 |
hidden = emb(ids.to(args.device)).detach().cpu()
|
|
|
|
| 597 |
)
|
| 598 |
stage_stats.extend(stats)
|
| 599 |
for step in range(int(args.max_new)):
|
| 600 |
+
h = ln(hidden.to(args.device))
|
| 601 |
+
nxt = sample_next(runtime, ar_h, h, ids, args)
|
| 602 |
+
ids = torch.cat([ids, nxt.detach().cpu()], dim=1)
|
| 603 |
+
generated_tokens += 1
|
| 604 |
+
if eos_id is not None and int(nxt.reshape(-1)[0].item()) == int(eos_id):
|
| 605 |
+
break
|
| 606 |
+
if step + 1 >= int(args.max_new):
|
| 607 |
+
break
|
| 608 |
+
hidden = emb(nxt.to(args.device)).detach().cpu()
|
| 609 |
hidden, stats = run_stage_pipeline(
|
| 610 |
stages,
|
| 611 |
hidden,
|
|
|
|
| 621 |
hidden = emb(ids.to(args.device)).detach().cpu()
|
| 622 |
hidden, stats = run_stage_pipeline(stages, hidden, args, use_cache=False)
|
| 623 |
stage_stats.extend(stats)
|
| 624 |
+
h = ln(hidden.to(args.device))
|
| 625 |
+
nxt = sample_next(runtime, ar_h, h, ids, args)
|
| 626 |
+
ids = torch.cat([ids, nxt.detach().cpu()], dim=1)
|
| 627 |
+
generated_tokens += 1
|
| 628 |
+
if eos_id is not None and int(nxt.reshape(-1)[0].item()) == int(eos_id):
|
| 629 |
+
break
|
| 630 |
elapsed = time.time() - start
|
| 631 |
all_ids = ids[0].tolist()
|
| 632 |
prompt = runtime.tok.decode(all_ids[:prompt_len], skip_special_tokens=True)
|
|
|
|
| 639 |
item["wall_sec"] += float(stat.get("wall_sec", stat.get("sec", 0.0)))
|
| 640 |
result = {
|
| 641 |
"event": "distributed_infer_done",
|
| 642 |
+
"mode": args.mode,
|
| 643 |
+
"cache_mode": args.cache_mode,
|
| 644 |
+
"session_id": session_id if args.cache_mode == "kv" else None,
|
| 645 |
+
"tokens": generated_tokens,
|
| 646 |
+
"elapsed_sec": round(elapsed, 3),
|
| 647 |
+
"tok_per_sec": round(generated_tokens / max(elapsed, 1e-9), 3),
|
| 648 |
+
"stages": by_stage,
|
| 649 |
}
|
| 650 |
if args.json:
|
| 651 |
result["prompt"] = prompt
|