arach committed on
Commit
b8b95e5
·
verified ·
1 Parent(s): 4cf8bba

Improve Colab runner progress visibility

Browse files
eval/news_summarization/run_hf_transformers.py CHANGED
@@ -29,7 +29,7 @@ from eval.news_summarization.run_news_summary_pilot import ( # noqa: E402
29
 
30
  def parse_args() -> argparse.Namespace:
31
  parser = argparse.ArgumentParser(description="Run the news summarization pilot on Hugging Face/Colab with transformers.")
32
- parser.add_argument("--model", required=True, help="Hub model id, e.g. Qwen/Qwen3.5-8B-Instruct-2507")
33
  parser.add_argument("--dataset", default="bbc2024_qwen_reference")
34
  parser.add_argument("--prompt-style", default="simple", choices=["simple", "helpful", "detailed"])
35
  parser.add_argument("--limit", type=int, default=50)
@@ -45,7 +45,7 @@ def parse_args() -> argparse.Namespace:
45
  parser.add_argument("--bertscore-model", default="roberta-large")
46
  parser.add_argument("--output")
47
  parser.add_argument("--resume", action="store_true")
48
- parser.add_argument("--save-every", type=int, default=5)
49
  parser.add_argument("--verbose", action="store_true")
50
  return parser.parse_args()
51
 
@@ -54,24 +54,21 @@ def build_generator(args: argparse.Namespace):
54
  import torch
55
  from transformers import pipeline
56
 
57
- torch_dtype = None
58
  if args.dtype != "auto":
59
- torch_dtype = getattr(torch, args.dtype)
60
 
61
  kwargs: dict[str, object] = {
62
  "model": args.model,
63
  "device_map": args.device_map,
64
  "trust_remote_code": args.trust_remote_code,
65
  }
66
- if torch_dtype is not None:
67
- kwargs["torch_dtype"] = torch_dtype
68
  if args.attn_implementation:
69
  kwargs["model_kwargs"] = {"attn_implementation": args.attn_implementation}
70
 
71
  generator = pipeline("text-generation", **kwargs)
72
- generation_config = getattr(generator.model, "generation_config", None)
73
- if generation_config is not None and getattr(generation_config, "max_length", None) == 20:
74
- generation_config.max_length = None
75
  return generator
76
 
77
 
@@ -113,6 +110,7 @@ def main() -> int:
113
  print(f"Resuming from {output_path} with {len(rows)} completed cases.")
114
 
115
  pending_cases = [case for case in cases if case.case_id not in completed_case_ids]
 
116
  for index, case in enumerate(pending_cases, start=len(rows) + 1):
117
  messages = build_messages(case, args.prompt_style)
118
  prompt = render_prompt(generator, messages)
@@ -139,6 +137,14 @@ def main() -> int:
139
  "provider_metadata": {"model": args.model},
140
  }
141
  rows.append(row)
 
 
 
 
 
 
 
 
142
  if args.save_every > 0 and len(rows) % args.save_every == 0:
143
  write_progress(output_path, rows, argparse.Namespace(
144
  provider="hf-transformers",
@@ -146,14 +152,11 @@ def main() -> int:
146
  dataset=args.dataset,
147
  prompt_style=args.prompt_style,
148
  ), final=False)
149
- print(f"Saved progress at {len(rows)}/{len(cases)} cases -> {output_path}")
150
  if args.verbose:
151
  print(
152
- f"[{index:03d}] {case.case_id} "
153
- f"token_f1={row['scores']['token_f1']:.4f} "
154
- f"rougeL={row['scores'].get('rougeL_f1') or 0:.4f} "
155
- f"words={row['scores']['word_count']} "
156
- f"latency_ms={row['latency_ms']:.2f}"
157
  )
158
 
159
  compute_bertscore(rows, args.disable_bertscore, args.bertscore_model)
 
29
 
30
  def parse_args() -> argparse.Namespace:
31
  parser = argparse.ArgumentParser(description="Run the news summarization pilot on Hugging Face/Colab with transformers.")
32
+ parser.add_argument("--model", required=True, help="Hub model id, e.g. Qwen/Qwen2.5-7B-Instruct")
33
  parser.add_argument("--dataset", default="bbc2024_qwen_reference")
34
  parser.add_argument("--prompt-style", default="simple", choices=["simple", "helpful", "detailed"])
35
  parser.add_argument("--limit", type=int, default=50)
 
45
  parser.add_argument("--bertscore-model", default="roberta-large")
46
  parser.add_argument("--output")
47
  parser.add_argument("--resume", action="store_true")
48
+ parser.add_argument("--save-every", type=int, default=1)
49
  parser.add_argument("--verbose", action="store_true")
50
  return parser.parse_args()
51
 
 
54
  import torch
55
  from transformers import pipeline
56
 
57
+ dtype = None
58
  if args.dtype != "auto":
59
+ dtype = getattr(torch, args.dtype)
60
 
61
  kwargs: dict[str, object] = {
62
  "model": args.model,
63
  "device_map": args.device_map,
64
  "trust_remote_code": args.trust_remote_code,
65
  }
66
+ if dtype is not None:
67
+ kwargs["dtype"] = dtype
68
  if args.attn_implementation:
69
  kwargs["model_kwargs"] = {"attn_implementation": args.attn_implementation}
70
 
71
  generator = pipeline("text-generation", **kwargs)
 
 
 
72
  return generator
73
 
74
 
 
110
  print(f"Resuming from {output_path} with {len(rows)} completed cases.")
111
 
112
  pending_cases = [case for case in cases if case.case_id not in completed_case_ids]
113
+ total_cases = len(cases)
114
  for index, case in enumerate(pending_cases, start=len(rows) + 1):
115
  messages = build_messages(case, args.prompt_style)
116
  prompt = render_prompt(generator, messages)
 
137
  "provider_metadata": {"model": args.model},
138
  }
139
  rows.append(row)
140
+ print(
141
+ f"[{index:03d}/{total_cases:03d}] {case.case_id} "
142
+ f"token_f1={row['scores']['token_f1']:.4f} "
143
+ f"rougeL={row['scores'].get('rougeL_f1') or 0:.4f} "
144
+ f"words={row['scores']['word_count']} "
145
+ f"latency_ms={row['latency_ms']:.2f}",
146
+ flush=True,
147
+ )
148
  if args.save_every > 0 and len(rows) % args.save_every == 0:
149
  write_progress(output_path, rows, argparse.Namespace(
150
  provider="hf-transformers",
 
152
  dataset=args.dataset,
153
  prompt_style=args.prompt_style,
154
  ), final=False)
155
+ print(f"Saved progress at {len(rows)}/{total_cases} cases -> {output_path}", flush=True)
156
  if args.verbose:
157
  print(
158
+ f"PRED: {prediction[:300]}",
159
+ flush=True,
 
 
 
160
  )
161
 
162
  compute_bertscore(rows, args.disable_bertscore, args.bertscore_model)