Simon Clematide
committed on
Commit
·
9d36a4d
1
Parent(s):
a36ddd3
Add CLI and inference modules for batch prediction using Hugging Face model
Browse files- sdg_predict/__init__.py +0 -0
- sdg_predict/cli_predict.py +47 -0
- sdg_predict/inference.py +42 -0
- setup.py +20 -0
sdg_predict/__init__.py
ADDED
|
File without changes
|
sdg_predict/cli_predict.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# sdg_predict/cli_predict.py
|
| 2 |
+
import argparse
|
| 3 |
+
import json
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
import sys
|
| 7 |
+
import torch
|
| 8 |
+
from sdg_predict.inference import load_model, predict
|
| 9 |
+
|
| 10 |
+
def main():
    """CLI entry point: batch classification of a JSONL file.

    Each input line must be a JSON object; rows that lack the --key
    field are silently skipped. Every kept row is echoed back (to
    --output if given, otherwise stdout) with an added "prediction"
    field produced by the Hub model.
    """
    parser = argparse.ArgumentParser(description="Batch inference using Hugging Face model.")
    parser.add_argument("input", type=Path, help="Input JSONL file")
    parser.add_argument("--key", type=str, required=True, help="JSON key with text input")
    parser.add_argument("--batch_size", type=int, default=8, help="Batch size")
    parser.add_argument("--model", type=str, default="simon-clmtd/sdg-scibert-zo_up", help="Model name on the Hub")
    parser.add_argument("--top1", action="store_true", help="Return only top prediction")
    parser.add_argument("--output", type=Path, help="Output file (optional, otherwise stdout)")
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer, model = load_model(args.model, device)

    texts = []
    rows = []
    # Read as UTF-8 explicitly so results don't depend on the locale.
    with args.input.open(encoding="utf-8") as f:
        for line in f:
            row = json.loads(line)
            if args.key not in row:
                continue  # skip rows without the text field
            texts.append(row[args.key])
            rows.append(row)

    predictions = predict(
        texts,
        tokenizer,
        model,
        device,
        batch_size=args.batch_size,
        return_all_scores=not args.top1,
    )

    # try/finally guarantees the output file is closed even if a row
    # fails to serialize (the original leaked it on any exception).
    output_stream = args.output.open("w", encoding="utf-8") if args.output else sys.stdout
    try:
        for row, pred in zip(rows, predictions):
            row["prediction"] = pred
            print(json.dumps(row, ensure_ascii=False), file=output_stream)
    finally:
        if args.output:
            output_stream.close()
|
sdg_predict/inference.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# sdg_predict/inference.py
|
| 2 |
+
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
def load_model(model_name, device):
    """Fetch tokenizer and sequence-classification model from the Hub.

    The model is moved to *device* and switched to eval mode before it
    is returned, so it is ready for inference.

    Returns a ``(tokenizer, model)`` pair.
    """
    tok = AutoTokenizer.from_pretrained(model_name)
    clf = AutoModelForSequenceClassification.from_pretrained(model_name)
    clf = clf.to(device)
    clf.eval()
    return tok, clf
|
| 10 |
+
|
| 11 |
+
def batched(iterable, batch_size):
    """Yield consecutive slices of *iterable*, each at most *batch_size* long.

    Requires a sliceable sequence with a length (list, str, tuple, ...).
    """
    start = 0
    size = len(iterable)
    while start < size:
        yield iterable[start:start + batch_size]
        start += batch_size
|
| 14 |
+
|
| 15 |
+
def predict(texts, tokenizer, model, device, batch_size=8, return_all_scores=True):
    """Run sequence classification over *texts* in mini-batches.

    Parameters
    ----------
    texts : list[str]
        Input documents; an empty list yields an empty result.
    tokenizer, model :
        Hugging Face tokenizer and sequence-classification model; the
        model must expose ``config.id2label``.
    device : torch.device
        Device the encoded batches are moved to.
    batch_size : int
        Number of texts tokenized and classified per forward pass.
    return_all_scores : bool
        If True, each result is a list of {"label", "score"} dicts
        covering every class; otherwise a single dict for the argmax
        class.
    """
    # Hoist the repeated attribute lookup out of the per-item loop.
    id2label = model.config.id2label
    results = []
    # Slice the inputs directly so the function is self-contained and
    # does not shadow/depend on a local batching helper.
    for start in range(0, len(texts), batch_size):
        inputs = tokenizer(
            texts[start:start + batch_size],
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
        ).to(device)

        # Inference only: no autograd graph needed.
        with torch.no_grad():
            logits = model(**inputs).logits
        probs = torch.nn.functional.softmax(logits, dim=-1)

        for prob in probs:
            if return_all_scores:
                # enumerate instead of range(len(...)): same order, clearer.
                results.append([
                    {"label": id2label[i], "score": p.item()}
                    for i, p in enumerate(prob)
                ])
            else:
                top = torch.argmax(prob).item()
                results.append({
                    "label": id2label[top],
                    "score": prob[top].item(),
                })
    return results
|
setup.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# setup.py — packaging metadata for the sdg-predict command-line tool.
from setuptools import setup, find_packages

setup(
    name="sdg-predict",
    version="0.1",
    # Auto-discovers the sdg_predict package.
    packages=find_packages(),
    # Runtime dependencies used by the sdg_predict modules
    # (transformers/torch for inference, tqdm imported by the CLI).
    install_requires=[
        "transformers>=4.36",
        "torch>=2.0",
        "tqdm",
    ],
    # Installs an `sdg-predict` executable dispatching to cli_predict.main().
    entry_points={
        "console_scripts": [
            "sdg-predict = sdg_predict.cli_predict:main"
        ]
    },
    author="Simon Clematide",
    description="Command-line prediction for SDG SciBERT classifier",
    python_requires=">=3.8",
)
|