| |
| import argparse |
| import json |
| from pathlib import Path |
| from tqdm import tqdm |
| import sys |
| import torch |
| from sdg_predict.inference import load_model, predict |
|
|
def main():
    """CLI entry point: run batch SDG inference over a JSONL file.

    Reads one JSON object per line from the input file, collects the text
    under ``--key`` (rows missing the key are silently skipped), runs the
    model in batches, attaches the result under ``"prediction"``, and writes
    the enriched rows as JSONL to ``--output`` or stdout.
    """
    parser = argparse.ArgumentParser(description="Batch inference using Hugging Face model.")
    parser.add_argument("input", type=Path, help="Input JSONL file")
    parser.add_argument("--key", type=str, required=True, help="JSON key with text input")
    parser.add_argument("--batch_size", type=int, default=8, help="Batch size")
    parser.add_argument("--model", type=str, default="simon-clmtd/sdg-scibert-zo_up", help="Model name on the Hub")
    parser.add_argument("--top1", action="store_true", help="Return only top prediction")
    parser.add_argument("--output", type=Path, help="Output file (optional, otherwise stdout)")
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer, model = load_model(args.model, device)

    # Collect texts and the rows they came from in lockstep so predictions
    # can be zipped back onto the originating rows afterwards.
    texts = []
    rows = []
    # Explicit encoding: JSON is UTF-8 by spec; don't depend on the locale.
    with args.input.open(encoding="utf-8") as f:
        for line in f:
            row = json.loads(line)
            # Best-effort: skip rows lacking the text key instead of aborting.
            if args.key not in row:
                continue
            texts.append(row[args.key])
            rows.append(row)

    predictions = predict(
        texts,
        tokenizer,
        model,
        device,
        batch_size=args.batch_size,
        return_all_scores=not args.top1,
    )

    output_stream = args.output.open("w", encoding="utf-8") if args.output else sys.stdout
    try:
        for row, pred in zip(rows, predictions):
            row["prediction"] = pred
            print(json.dumps(row, ensure_ascii=False), file=output_stream)
    finally:
        # Close only streams we opened ourselves — never sys.stdout. The
        # finally block guarantees the file is flushed/closed even when the
        # write loop raises (the original leaked the handle on error).
        if args.output:
            output_stream.close()
|
|