Simon Clematide committed on
Commit
9d36a4d
·
1 Parent(s): a36ddd3

Add CLI and inference modules for batch prediction using Hugging Face model

Browse files
sdg_predict/__init__.py ADDED
File without changes
sdg_predict/cli_predict.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # sdg_predict/cli_predict.py
2
+ import argparse
3
+ import json
4
+ from pathlib import Path
5
+ from tqdm import tqdm
6
+ import sys
7
+ import torch
8
+ from sdg_predict.inference import load_model, predict
9
+
10
def main():
    """Run batch inference over a JSONL file and emit one JSON object per line.

    Each input line is parsed as JSON; blank lines and rows lacking the
    ``--key`` field are skipped entirely (they do not appear in the output).
    The prediction for each remaining row is attached under ``"prediction"``
    and the row is written to ``--output`` (or stdout) as JSON.
    """
    parser = argparse.ArgumentParser(description="Batch inference using Hugging Face model.")
    parser.add_argument("input", type=Path, help="Input JSONL file")
    parser.add_argument("--key", type=str, required=True, help="JSON key with text input")
    parser.add_argument("--batch_size", type=int, default=8, help="Batch size")
    parser.add_argument("--model", type=str, default="simon-clmtd/sdg-scibert-zo_up", help="Model name on the Hub")
    parser.add_argument("--top1", action="store_true", help="Return only top prediction")
    parser.add_argument("--output", type=Path, help="Output file (optional, otherwise stdout)")
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer, model = load_model(args.model, device)

    texts = []
    rows = []
    with args.input.open() as f:
        for line in f:
            # Robustness: skip blank/whitespace-only lines so a trailing
            # newline or empty line doesn't crash json.loads.
            if not line.strip():
                continue
            row = json.loads(line)
            if args.key not in row:
                continue  # rows without the text key are dropped from the output
            texts.append(row[args.key])
            rows.append(row)

    predictions = predict(
        texts,
        tokenizer,
        model,
        device,
        batch_size=args.batch_size,
        return_all_scores=not args.top1,
    )

    # Bug fix: the original never closed the output file if writing raised;
    # try/finally guarantees the handle is released. stdout is left open.
    output_stream = args.output.open("w") if args.output else sys.stdout
    try:
        for row, pred in zip(rows, predictions):
            row["prediction"] = pred
            print(json.dumps(row, ensure_ascii=False), file=output_stream)
    finally:
        if args.output:
            output_stream.close()
sdg_predict/inference.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # sdg_predict/inference.py
2
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
+ import torch
4
+
5
def load_model(model_name, device):
    """Fetch a tokenizer and sequence-classification model from the Hub.

    The model is moved to *device* and switched to eval mode before being
    returned, so it is immediately ready for inference.
    """
    classifier = AutoModelForSequenceClassification.from_pretrained(model_name)
    classifier = classifier.to(device)
    classifier.eval()
    return AutoTokenizer.from_pretrained(model_name), classifier
10
+
11
def batched(iterable, batch_size):
    """Yield successive slices of *iterable*, each with at most *batch_size* items."""
    start = 0
    total = len(iterable)
    while start < total:
        yield iterable[start:start + batch_size]
        start += batch_size
14
+
15
def predict(texts, tokenizer, model, device, batch_size=8, return_all_scores=True):
    """Classify *texts* in batches with *model* and return one prediction per text.

    For each text the result is either a list of ``{"label", "score"}`` dicts
    covering every class (``return_all_scores=True``) or a single such dict
    for the highest-scoring class.
    """
    id2label = model.config.id2label  # constant across batches; hoisted once
    results = []
    for chunk in batched(texts, batch_size):
        encoded = tokenizer(
            chunk,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
        ).to(device)

        # No gradients needed at inference time.
        with torch.no_grad():
            batch_logits = model(**encoded).logits
            batch_probs = torch.nn.functional.softmax(batch_logits, dim=-1)

        for scores in batch_probs:
            if return_all_scores:
                results.append(
                    [
                        {"label": id2label[idx], "score": scores[idx].item()}
                        for idx in range(len(scores))
                    ]
                )
            else:
                best = torch.argmax(scores).item()
                results.append({"label": id2label[best], "score": scores[best].item()})
    return results
setup.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Packaging configuration for the sdg-predict command-line tool."""
from setuptools import find_packages, setup

setup(
    name="sdg-predict",
    version="0.1",
    author="Simon Clematide",
    description="Command-line prediction for SDG SciBERT classifier",
    python_requires=">=3.8",
    packages=find_packages(),
    # Runtime dependencies; tqdm is imported by the CLI module.
    install_requires=[
        "transformers>=4.36",
        "torch>=2.0",
        "tqdm",
    ],
    # Installs an `sdg-predict` executable wired to cli_predict.main.
    entry_points={
        "console_scripts": [
            "sdg-predict = sdg_predict.cli_predict:main"
        ]
    },
)