Improve Colab runner progress visibility
Browse files
eval/news_summarization/run_hf_transformers.py
CHANGED
|
@@ -29,7 +29,7 @@ from eval.news_summarization.run_news_summary_pilot import ( # noqa: E402
|
|
| 29 |
|
| 30 |
def parse_args() -> argparse.Namespace:
|
| 31 |
parser = argparse.ArgumentParser(description="Run the news summarization pilot on Hugging Face/Colab with transformers.")
|
| 32 |
-
parser.add_argument("--model", required=True, help="Hub model id, e.g. Qwen/
|
| 33 |
parser.add_argument("--dataset", default="bbc2024_qwen_reference")
|
| 34 |
parser.add_argument("--prompt-style", default="simple", choices=["simple", "helpful", "detailed"])
|
| 35 |
parser.add_argument("--limit", type=int, default=50)
|
|
@@ -45,7 +45,7 @@ def parse_args() -> argparse.Namespace:
|
|
| 45 |
parser.add_argument("--bertscore-model", default="roberta-large")
|
| 46 |
parser.add_argument("--output")
|
| 47 |
parser.add_argument("--resume", action="store_true")
|
| 48 |
-
parser.add_argument("--save-every", type=int, default=
|
| 49 |
parser.add_argument("--verbose", action="store_true")
|
| 50 |
return parser.parse_args()
|
| 51 |
|
|
@@ -54,24 +54,21 @@ def build_generator(args: argparse.Namespace):
|
|
| 54 |
import torch
|
| 55 |
from transformers import pipeline
|
| 56 |
|
| 57 |
-
|
| 58 |
if args.dtype != "auto":
|
| 59 |
-
|
| 60 |
|
| 61 |
kwargs: dict[str, object] = {
|
| 62 |
"model": args.model,
|
| 63 |
"device_map": args.device_map,
|
| 64 |
"trust_remote_code": args.trust_remote_code,
|
| 65 |
}
|
| 66 |
-
if
|
| 67 |
-
kwargs["
|
| 68 |
if args.attn_implementation:
|
| 69 |
kwargs["model_kwargs"] = {"attn_implementation": args.attn_implementation}
|
| 70 |
|
| 71 |
generator = pipeline("text-generation", **kwargs)
|
| 72 |
-
generation_config = getattr(generator.model, "generation_config", None)
|
| 73 |
-
if generation_config is not None and getattr(generation_config, "max_length", None) == 20:
|
| 74 |
-
generation_config.max_length = None
|
| 75 |
return generator
|
| 76 |
|
| 77 |
|
|
@@ -113,6 +110,7 @@ def main() -> int:
|
|
| 113 |
print(f"Resuming from {output_path} with {len(rows)} completed cases.")
|
| 114 |
|
| 115 |
pending_cases = [case for case in cases if case.case_id not in completed_case_ids]
|
|
|
|
| 116 |
for index, case in enumerate(pending_cases, start=len(rows) + 1):
|
| 117 |
messages = build_messages(case, args.prompt_style)
|
| 118 |
prompt = render_prompt(generator, messages)
|
|
@@ -139,6 +137,14 @@ def main() -> int:
|
|
| 139 |
"provider_metadata": {"model": args.model},
|
| 140 |
}
|
| 141 |
rows.append(row)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 142 |
if args.save_every > 0 and len(rows) % args.save_every == 0:
|
| 143 |
write_progress(output_path, rows, argparse.Namespace(
|
| 144 |
provider="hf-transformers",
|
|
@@ -146,14 +152,11 @@ def main() -> int:
|
|
| 146 |
dataset=args.dataset,
|
| 147 |
prompt_style=args.prompt_style,
|
| 148 |
), final=False)
|
| 149 |
-
print(f"Saved progress at {len(rows)}/{
|
| 150 |
if args.verbose:
|
| 151 |
print(
|
| 152 |
-
f"
|
| 153 |
-
|
| 154 |
-
f"rougeL={row['scores'].get('rougeL_f1') or 0:.4f} "
|
| 155 |
-
f"words={row['scores']['word_count']} "
|
| 156 |
-
f"latency_ms={row['latency_ms']:.2f}"
|
| 157 |
)
|
| 158 |
|
| 159 |
compute_bertscore(rows, args.disable_bertscore, args.bertscore_model)
|
|
|
|
| 29 |
|
| 30 |
def parse_args() -> argparse.Namespace:
|
| 31 |
parser = argparse.ArgumentParser(description="Run the news summarization pilot on Hugging Face/Colab with transformers.")
|
| 32 |
+
parser.add_argument("--model", required=True, help="Hub model id, e.g. Qwen/Qwen2.5-7B-Instruct")
|
| 33 |
parser.add_argument("--dataset", default="bbc2024_qwen_reference")
|
| 34 |
parser.add_argument("--prompt-style", default="simple", choices=["simple", "helpful", "detailed"])
|
| 35 |
parser.add_argument("--limit", type=int, default=50)
|
|
|
|
| 45 |
parser.add_argument("--bertscore-model", default="roberta-large")
|
| 46 |
parser.add_argument("--output")
|
| 47 |
parser.add_argument("--resume", action="store_true")
|
| 48 |
+
parser.add_argument("--save-every", type=int, default=1)
|
| 49 |
parser.add_argument("--verbose", action="store_true")
|
| 50 |
return parser.parse_args()
|
| 51 |
|
|
|
|
| 54 |
import torch
|
| 55 |
from transformers import pipeline
|
| 56 |
|
| 57 |
+
dtype = None
|
| 58 |
if args.dtype != "auto":
|
| 59 |
+
dtype = getattr(torch, args.dtype)
|
| 60 |
|
| 61 |
kwargs: dict[str, object] = {
|
| 62 |
"model": args.model,
|
| 63 |
"device_map": args.device_map,
|
| 64 |
"trust_remote_code": args.trust_remote_code,
|
| 65 |
}
|
| 66 |
+
if dtype is not None:
|
| 67 |
+
kwargs["dtype"] = dtype
|
| 68 |
if args.attn_implementation:
|
| 69 |
kwargs["model_kwargs"] = {"attn_implementation": args.attn_implementation}
|
| 70 |
|
| 71 |
generator = pipeline("text-generation", **kwargs)
|
|
|
|
|
|
|
|
|
|
| 72 |
return generator
|
| 73 |
|
| 74 |
|
|
|
|
| 110 |
print(f"Resuming from {output_path} with {len(rows)} completed cases.")
|
| 111 |
|
| 112 |
pending_cases = [case for case in cases if case.case_id not in completed_case_ids]
|
| 113 |
+
total_cases = len(cases)
|
| 114 |
for index, case in enumerate(pending_cases, start=len(rows) + 1):
|
| 115 |
messages = build_messages(case, args.prompt_style)
|
| 116 |
prompt = render_prompt(generator, messages)
|
|
|
|
| 137 |
"provider_metadata": {"model": args.model},
|
| 138 |
}
|
| 139 |
rows.append(row)
|
| 140 |
+
print(
|
| 141 |
+
f"[{index:03d}/{total_cases:03d}] {case.case_id} "
|
| 142 |
+
f"token_f1={row['scores']['token_f1']:.4f} "
|
| 143 |
+
f"rougeL={row['scores'].get('rougeL_f1') or 0:.4f} "
|
| 144 |
+
f"words={row['scores']['word_count']} "
|
| 145 |
+
f"latency_ms={row['latency_ms']:.2f}",
|
| 146 |
+
flush=True,
|
| 147 |
+
)
|
| 148 |
if args.save_every > 0 and len(rows) % args.save_every == 0:
|
| 149 |
write_progress(output_path, rows, argparse.Namespace(
|
| 150 |
provider="hf-transformers",
|
|
|
|
| 152 |
dataset=args.dataset,
|
| 153 |
prompt_style=args.prompt_style,
|
| 154 |
), final=False)
|
| 155 |
+
print(f"Saved progress at {len(rows)}/{total_cases} cases -> {output_path}", flush=True)
|
| 156 |
if args.verbose:
|
| 157 |
print(
|
| 158 |
+
f"PRED: {prediction[:300]}",
|
| 159 |
+
flush=True,
|
|
|
|
|
|
|
|
|
|
| 160 |
)
|
| 161 |
|
| 162 |
compute_bertscore(rows, args.disable_bertscore, args.bertscore_model)
|