# sdg-scibert-zo_up / sdg_predict/cli_predict.py
# Author: Simon Clematide
# Commit 9d36a4d: Add CLI and inference modules for batch prediction using Hugging Face model
# sdg_predict/cli_predict.py
import argparse
import json
from pathlib import Path
from tqdm import tqdm
import sys
import torch
from sdg_predict.inference import load_model, predict
def main():
    """Batch-predict SDG labels for a JSONL file and emit augmented JSONL.

    Reads one JSON object per input line, runs the text under ``--key``
    through the Hub model, attaches the result under ``"prediction"``, and
    writes each record to ``--output`` (or stdout) as one JSON line.
    Records missing ``--key`` are silently skipped.
    """
    parser = argparse.ArgumentParser(description="Batch inference using Hugging Face model.")
    parser.add_argument("input", type=Path, help="Input JSONL file")
    parser.add_argument("--key", type=str, required=True, help="JSON key with text input")
    parser.add_argument("--batch_size", type=int, default=8, help="Batch size")
    parser.add_argument("--model", type=str, default="simon-clmtd/sdg-scibert-zo_up", help="Model name on the Hub")
    parser.add_argument("--top1", action="store_true", help="Return only top prediction")
    parser.add_argument("--output", type=Path, help="Output file (optional, otherwise stdout)")
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer, model = load_model(args.model, device)

    # Collect texts and their source rows in lockstep so predictions can be
    # zipped back onto the originating records afterwards.
    texts = []
    rows = []
    with args.input.open(encoding="utf-8") as f:  # explicit encoding: JSONL is UTF-8
        for line in f:
            row = json.loads(line)
            if args.key not in row:
                continue  # record has no text field for this key; skip it
            texts.append(row[args.key])
            rows.append(row)

    predictions = predict(
        texts,
        tokenizer,
        model,
        device,
        batch_size=args.batch_size,
        return_all_scores=not args.top1,
    )

    # try/finally guarantees the output file is closed even if a write fails;
    # stdout is deliberately left open.
    output_stream = args.output.open("w", encoding="utf-8") if args.output else sys.stdout
    try:
        for row, pred in zip(rows, predictions):
            row["prediction"] = pred
            print(json.dumps(row, ensure_ascii=False), file=output_stream)
    finally:
        if args.output:
            output_stream.close()


if __name__ == "__main__":
    # Entry-point guard was missing: main() was defined but never called.
    main()