OpenTransformer committed on
Commit
a18d8d3
·
verified ·
1 Parent(s): 7814939

Stop distributed inference on EOS

Browse files
distributed/inference/agillm35_distributed_infer.py CHANGED
@@ -576,11 +576,13 @@ def cmd_infer(args: argparse.Namespace) -> None:
576
  prompt_tokens = runtime.tok.encode(args.prompt)
577
  if not prompt_tokens:
578
  prompt_tokens = [runtime.EOS]
579
- ids = torch.tensor([prompt_tokens], dtype=torch.long)
580
- prompt_len = ids.size(1)
581
- stage_stats: list[dict[str, Any]] = []
582
- session_id = args.session_id or f"agillm35-{uuid.uuid4().hex}"
583
- start = time.time()
 
 
584
  with torch.no_grad():
585
  if args.cache_mode == "kv":
586
  hidden = emb(ids.to(args.device)).detach().cpu()
@@ -595,12 +597,15 @@ def cmd_infer(args: argparse.Namespace) -> None:
595
  )
596
  stage_stats.extend(stats)
597
  for step in range(int(args.max_new)):
598
- h = ln(hidden.to(args.device))
599
- nxt = sample_next(runtime, ar_h, h, ids, args)
600
- ids = torch.cat([ids, nxt.detach().cpu()], dim=1)
601
- if step + 1 >= int(args.max_new):
602
- break
603
- hidden = emb(nxt.to(args.device)).detach().cpu()
 
 
 
604
  hidden, stats = run_stage_pipeline(
605
  stages,
606
  hidden,
@@ -616,9 +621,12 @@ def cmd_infer(args: argparse.Namespace) -> None:
616
  hidden = emb(ids.to(args.device)).detach().cpu()
617
  hidden, stats = run_stage_pipeline(stages, hidden, args, use_cache=False)
618
  stage_stats.extend(stats)
619
- h = ln(hidden.to(args.device))
620
- nxt = sample_next(runtime, ar_h, h, ids, args)
621
- ids = torch.cat([ids, nxt.detach().cpu()], dim=1)
 
 
 
622
  elapsed = time.time() - start
623
  all_ids = ids[0].tolist()
624
  prompt = runtime.tok.decode(all_ids[:prompt_len], skip_special_tokens=True)
@@ -631,13 +639,13 @@ def cmd_infer(args: argparse.Namespace) -> None:
631
  item["wall_sec"] += float(stat.get("wall_sec", stat.get("sec", 0.0)))
632
  result = {
633
  "event": "distributed_infer_done",
634
- "mode": args.mode,
635
- "cache_mode": args.cache_mode,
636
- "session_id": session_id if args.cache_mode == "kv" else None,
637
- "tokens": int(args.max_new),
638
- "elapsed_sec": round(elapsed, 3),
639
- "tok_per_sec": round(int(args.max_new) / max(elapsed, 1e-9), 3),
640
- "stages": by_stage,
641
  }
642
  if args.json:
643
  result["prompt"] = prompt
 
576
  prompt_tokens = runtime.tok.encode(args.prompt)
577
  if not prompt_tokens:
578
  prompt_tokens = [runtime.EOS]
579
+ ids = torch.tensor([prompt_tokens], dtype=torch.long)
580
+ prompt_len = ids.size(1)
581
+ stage_stats: list[dict[str, Any]] = []
582
+ session_id = args.session_id or f"agillm35-{uuid.uuid4().hex}"
583
+ eos_id = getattr(runtime, "EOS", None)
584
+ generated_tokens = 0
585
+ start = time.time()
586
  with torch.no_grad():
587
  if args.cache_mode == "kv":
588
  hidden = emb(ids.to(args.device)).detach().cpu()
 
597
  )
598
  stage_stats.extend(stats)
599
  for step in range(int(args.max_new)):
600
+ h = ln(hidden.to(args.device))
601
+ nxt = sample_next(runtime, ar_h, h, ids, args)
602
+ ids = torch.cat([ids, nxt.detach().cpu()], dim=1)
603
+ generated_tokens += 1
604
+ if eos_id is not None and int(nxt.reshape(-1)[0].item()) == int(eos_id):
605
+ break
606
+ if step + 1 >= int(args.max_new):
607
+ break
608
+ hidden = emb(nxt.to(args.device)).detach().cpu()
609
  hidden, stats = run_stage_pipeline(
610
  stages,
611
  hidden,
 
621
  hidden = emb(ids.to(args.device)).detach().cpu()
622
  hidden, stats = run_stage_pipeline(stages, hidden, args, use_cache=False)
623
  stage_stats.extend(stats)
624
+ h = ln(hidden.to(args.device))
625
+ nxt = sample_next(runtime, ar_h, h, ids, args)
626
+ ids = torch.cat([ids, nxt.detach().cpu()], dim=1)
627
+ generated_tokens += 1
628
+ if eos_id is not None and int(nxt.reshape(-1)[0].item()) == int(eos_id):
629
+ break
630
  elapsed = time.time() - start
631
  all_ids = ids[0].tolist()
632
  prompt = runtime.tok.decode(all_ids[:prompt_len], skip_special_tokens=True)
 
639
  item["wall_sec"] += float(stat.get("wall_sec", stat.get("sec", 0.0)))
640
  result = {
641
  "event": "distributed_infer_done",
642
+ "mode": args.mode,
643
+ "cache_mode": args.cache_mode,
644
+ "session_id": session_id if args.cache_mode == "kv" else None,
645
+ "tokens": generated_tokens,
646
+ "elapsed_sec": round(elapsed, 3),
647
+ "tok_per_sec": round(generated_tokens / max(elapsed, 1e-9), 3),
648
+ "stages": by_stage,
649
  }
650
  if args.json:
651
  result["prompt"] = prompt