Simon Clematide committed on
Commit
f9c9b95
·
1 Parent(s): cef9aa2

Refactor CLI prediction script to enhance argument parsing and modularize inference logic. Add Excel generation.

Browse files
Files changed (2) hide show
  1. sdg_predict/cli_predict.py +157 -64
  2. sdg_predict/inference.py +70 -0
sdg_predict/cli_predict.py CHANGED
@@ -2,11 +2,17 @@
2
  import argparse
3
  import json
4
  from pathlib import Path
5
- from tqdm import tqdm
6
- import sys
7
- import torch
8
- from sdg_predict.inference import load_model, predict
 
 
 
 
 
9
  import logging
 
10
 
11
  # Set up logging
12
  logging.basicConfig(
@@ -14,80 +20,138 @@ logging.basicConfig(
14
  )
15
 
16
 
17
- def main():
 
 
 
 
 
 
18
  parser = argparse.ArgumentParser(
19
  description="Batch inference using Hugging Face model."
20
  )
21
- parser.add_argument("input", type=Path, help="Input JSONL file")
22
  parser.add_argument(
23
- "--key", type=str, default="text", help="JSON key with text input"
 
 
 
 
 
 
24
  )
25
- parser.add_argument("--batch_size", "-b", type=int, default=8, help="Batch size")
26
  parser.add_argument(
27
  "--model",
28
  type=str,
29
  default="simon-clmtd/sdg-scibert-zo_up",
30
- help="Model name on the Hub",
31
  )
32
  parser.add_argument(
33
- "--top1", action="store_true", help="Return only top prediction"
 
 
34
  )
35
  parser.add_argument(
36
- "--output", type=Path, help="Output file (optional, otherwise stdout)"
 
 
 
37
  )
38
- args = parser.parse_args()
39
-
40
- # -------------------------------
41
- # 1. Device Setup (MPS support for Apple Silicon)
42
- # -------------------------------
43
- if torch.backends.mps.is_available():
44
- device = torch.device("mps")
45
- logging.info("Using MPS device")
46
- elif torch.cuda.is_available():
47
- device = torch.device("cuda")
48
- logging.info("Using CUDA device")
49
- else:
50
- device = torch.device("cpu")
51
- logging.info("Using CPU device")
52
- # device = torch.device("cpu")
53
- logging.info("Loading model: %s", args.model)
54
- tokenizer, model = load_model(args.model, device)
55
- logging.info("Model loaded successfully")
56
-
57
- with args.input.open() as f:
58
- texts = []
59
- rows = []
60
- for line in f:
61
- row = json.loads(line)
62
- if args.key not in row:
63
- continue
64
- texts.append(row[args.key])
65
- logging.debug("Text: %s", row[args.key])
66
- rows.append(row)
67
-
68
- logging.info("Starting predictions on %d texts", len(texts))
69
- predictions = predict(
70
- texts,
71
- tokenizer,
72
- model,
73
- device,
74
- batch_size=args.batch_size,
75
- return_all_scores=not args.top1,
76
  )
77
- logging.info("Predictions completed")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
 
79
- output_stream = args.output.open("w") if args.output else sys.stdout
80
  for row, pred in zip(rows, predictions):
81
- # Compute binary probabilities for labels 1-17
82
- binary_predictions = {}
83
- for label_data in pred:
84
- label_data["score"] = round(
85
- label_data["score"], 3
86
- ) # Round prediction scores to 3 decimal places
87
- label = int(label_data["label"])
88
- if 1 <= label <= 17:
89
- binary_prob = label_data["score"] # Already rounded
90
- binary_predictions[str(label)] = binary_prob
91
 
92
  output_row = {
93
  "id": row.get("id"),
@@ -95,11 +159,40 @@ def main():
95
  "prediction": pred,
96
  "binary_predictions": binary_predictions,
97
  }
 
 
 
 
 
 
 
 
 
98
  print(json.dumps(output_row, ensure_ascii=False), file=output_stream)
99
- if args.output:
 
100
  output_stream.close()
101
- logging.info("Output written to %s", args.output)
 
 
 
 
 
 
 
 
102
 
103
 
104
  if __name__ == "__main__":
105
- main()
 
 
 
 
 
 
 
 
 
 
 
 
2
  import argparse
3
  import json
4
  from pathlib import Path
5
+ from typing import List, Dict, Union
6
+
7
+ from sdg_predict.inference import (
8
+ load_model_and_tokenizer,
9
+ load_input_data,
10
+ perform_predictions,
11
+ setup_device,
12
+ binary_from_softmax,
13
+ )
14
  import logging
15
+ import pandas as pd
16
 
17
  # Set up logging
18
  logging.basicConfig(
 
20
  )
21
 
22
 
23
def parse_arguments(argv: Union[List[str], None] = None) -> argparse.Namespace:
    """
    Parse command-line arguments for the script.

    Args:
        argv: Argument strings to parse. When None (the default) argparse
            reads ``sys.argv[1:]``, so calling ``parse_arguments()`` with no
            arguments behaves exactly as before. Passing an explicit list
            makes the parser usable from other code and from tests.

    Returns:
        Parsed arguments as a Namespace object with the attributes
        ``input``, ``key``, ``batch_size``, ``model``, ``top1``, ``output``,
        ``binarization``, ``sdg0_cap_prob`` and ``excel``.
    """
    parser = argparse.ArgumentParser(
        description="Batch inference using Hugging Face model."
    )
    parser.add_argument("input", type=Path, help="Input JSONL file (default: None)")
    parser.add_argument(
        "--key",
        type=str,
        default="text",
        help="JSON key with text input (default: 'text')",
    )
    parser.add_argument(
        "--batch_size", "-b", type=int, default=8, help="Batch size (default: 8)"
    )
    parser.add_argument(
        "--model",
        type=str,
        default="simon-clmtd/sdg-scibert-zo_up",
        help="Model name on the Hub (default: 'simon-clmtd/sdg-scibert-zo_up')",
    )
    parser.add_argument(
        "--top1",
        action="store_true",
        help="Return only top prediction (default: False)",
    )
    parser.add_argument(
        "--output",
        "-o",
        type=Path,
        help="Output file (default: None, otherwise stdout)",
    )
    parser.add_argument(
        "--binarization",
        type=str,
        choices=["one-vs-all", "one-vs-0"],
        default="one-vs-0",
        help="Binarization method: 'one-vs-all' or 'one-vs-0' (default: 'one-vs-0')",
    )
    parser.add_argument(
        "--sdg0-cap-prob",
        type=float,
        default=0.5,
        help=(
            "Maximum score allowed for class 0 in 'one-vs-0' binarization (default:"
            " 0.5)"
        ),
    )
    parser.add_argument(
        "--excel",
        "-e",
        type=Path,
        help="Path to the Excel file for binary predictions (optional)",
    )
    # argparse treats args=None as "use sys.argv[1:]".
    return parser.parse_args(argv)
83
+
84
+
85
def main(
    input: Path,
    key: str,
    batch_size: int,
    model: str,
    top1: bool,
    output: Union[Path, None],
    binarization: str,
    sdg0_cap_prob: float,
    excel: Union[Path, None],
) -> None:
    """
    Run the batch-inference pipeline end to end.

    The steps are: pick a device, load the tokenizer and model, read the
    input rows, run the predictions and hand everything to ``write_output``.

    Args:
        input: Path to the input JSONL file.
        key: JSON key containing the text input.
        batch_size: Batch size for inference.
        model: Model name or path.
        top1: Whether to return only the top prediction.
        output: Path to the output file (optional).
        binarization: Binarization method ('one-vs-all' or 'one-vs-0').
        sdg0_cap_prob: Maximum score allowed for class 0 in 'one-vs-0' binarization.
        excel: Path to the Excel file for binary predictions (optional).
    """
    logging.info("Starting main function")

    device = setup_device()
    # Keep the loaded model under its own name so the ``model`` argument
    # (the model name) is not rebound to a different kind of object.
    tokenizer, loaded_model = load_model_and_tokenizer(model, device)
    texts, rows = load_input_data(input, key)
    predictions = perform_predictions(
        texts, tokenizer, loaded_model, device, batch_size, top1
    )
    write_output(rows, predictions, output, binarization, sdg0_cap_prob, excel)

    logging.info("Main function completed")
120
+
121
+
122
+ def write_output(
123
+ rows: List[Dict],
124
+ predictions: List,
125
+ output: Union[Path, None],
126
+ binarization: str,
127
+ sdg0_cap_prob: float,
128
+ excel: Union[Path, None] = None,
129
+ ) -> None:
130
+ """
131
+ Write the predictions to the output file or stdout, and optionally to an Excel file.
132
+
133
+ Args:
134
+ rows: List of input rows.
135
+ predictions: List of predictions.
136
+ output: Path to the output file (optional).
137
+ binarization: Binarization method ('one-vs-all' or 'one-vs-0').
138
+ sdg0_cap_prob: Maximum score allowed for class 0 in 'one-vs-0' binarization.
139
+ excel: Path to the Excel file (optional).
140
+ """
141
+ logging.info("Writing output to %s", output or "stdout")
142
+ output_stream = output.open("w") if output else None
143
+ transformed_data = []
144
 
 
145
  for row, pred in zip(rows, predictions):
146
+ if binarization == "one-vs-all":
147
+ binary_predictions = {
148
+ str(label): round(
149
+ next((x["score"] for x in pred if int(x["label"]) == label), 0), 3
150
+ )
151
+ for label in range(1, 18)
152
+ }
153
+ elif binarization == "one-vs-0":
154
+ binary_predictions = binary_from_softmax(pred, sdg0_cap_prob)
 
155
 
156
  output_row = {
157
  "id": row.get("id"),
 
159
  "prediction": pred,
160
  "binary_predictions": binary_predictions,
161
  }
162
+ transformed_data.append(
163
+ {
164
+ "publication_zora_id": row.get("id"),
165
+ **{
166
+ f"dvdblk_sdg{sdg}": binary_predictions.get(str(sdg), 0)
167
+ for sdg in range(1, 18)
168
+ },
169
+ }
170
+ )
171
  print(json.dumps(output_row, ensure_ascii=False), file=output_stream)
172
+
173
+ if output:
174
  output_stream.close()
175
+ logging.info("Output written to %s", output)
176
+
177
+ if excel:
178
+ logging.info("Writing Excel output to %s", excel)
179
+ df_transformed = pd.DataFrame(transformed_data)
180
+ df_transformed.to_excel(excel, index=False)
181
+ logging.info("Excel output written to %s", excel)
182
+
183
+ logging.info("Output writing completed")
184
 
185
 
186
if __name__ == "__main__":
    # The parser's destinations (input, key, batch_size, model, top1, output,
    # binarization, sdg0_cap_prob, excel) match main()'s parameter names
    # one-to-one, so the namespace can be forwarded as keyword arguments.
    cli_options = vars(parse_arguments())
    main(**cli_options)
sdg_predict/inference.py CHANGED
@@ -2,6 +2,7 @@
2
  from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
3
  import torch
4
  import logging
 
5
 
6
 
7
  def load_model(model_name, device):
@@ -43,3 +44,72 @@ def predict(texts, tokenizer, model, device, batch_size=8, return_all_scores=Tru
43
  ) # Round top score to 3 decimal places
44
 
45
  return results
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
3
  import torch
4
  import logging
5
+ import json
6
 
7
 
8
  def load_model(model_name, device):
 
44
  ) # Round top score to 3 decimal places
45
 
46
  return results
47
+
48
+
49
def binary_from_softmax(prediction, cap_class0=0.5):
    """Turn multi-class scores into per-label "one-vs-0" binary scores.

    For every entry whose label is not class 0, the binary score is
    ``score / (score + score_0)``, where ``score_0`` is the class-0 score
    capped at ``cap_class0``. Results are rounded to 3 decimal places.

    Args:
        prediction: Iterable of dicts with a "label" and a "score" entry.
            Labels may be strings or integers; they are normalized with
            ``str()`` so that ``0`` and ``"0"`` are both recognized as
            class 0 and result keys are always strings.
        cap_class0: Upper bound applied to the class-0 score before it is
            used in the denominator.

    Returns:
        Dict mapping label strings to binary scores. Labels "1" to "17" are
        always present and default to 0.0 when ``prediction`` has no entry
        for them. Class 0 never appears as a key.
    """
    # Materialize once: the data is needed twice (class-0 lookup, then the
    # per-label pass), and a one-shot iterable would be empty the second time.
    scored = [(str(entry["label"]), entry["score"]) for entry in prediction]

    score_0 = next((score for label, score in scored if label == "0"), 0.0)
    score_0 = min(score_0, cap_class0)

    # Start every label 1..17 at 0.0 so missing labels still show up.
    binary_predictions = {str(label): 0.0 for label in range(1, 18)}

    for label, score in scored:
        if label == "0":
            continue
        denominator = score + score_0
        # Avoid dividing by zero when both scores are 0.
        binary_score = score / denominator if denominator > 0 else 0.0
        binary_predictions[label] = round(binary_score, 3)

    return binary_predictions
66
+
67
+
68
def setup_device():
    """Pick the torch device to run on and log the choice.

    MPS is checked first, then CUDA; CPU is the fallback when neither
    reports itself as available.

    Returns:
        A ``torch.device`` for "mps", "cuda" or "cpu".
    """
    logging.info("Setting up device")

    if torch.backends.mps.is_available():
        backend, message = "mps", "Using MPS device"
    elif torch.cuda.is_available():
        backend, message = "cuda", "Using CUDA device"
    else:
        backend, message = "cpu", "Using CPU device"

    logging.info(message)
    return torch.device(backend)
79
+
80
+
81
def load_model_and_tokenizer(model_name, device):
    """Load the tokenizer and model via ``load_model``, logging before and after.

    Args:
        model_name: Model identifier passed through to ``load_model``.
        device: Device passed through to ``load_model``.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    logging.info("Loading model: %s", model_name)
    tok, mdl = load_model(model_name, device)
    logging.info("Model loaded successfully")
    return tok, mdl
86
+
87
+
88
def load_input_data(input, key):
    """Read a JSONL file and collect the rows that contain ``key``.

    Args:
        input: Path-like object with an ``open()`` method (e.g.
            ``pathlib.Path``) pointing at a file with one JSON value per line.
        key: Key whose value is taken as the text of each row.

    Returns:
        A ``(texts, rows)`` tuple of equal-length lists: ``texts`` holds
        ``row[key]`` for each kept row and ``rows`` the parsed rows
        themselves. Rows without ``key`` are dropped.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    logging.info("Loading input data from %s", input)
    texts = []
    rows = []
    # Decode as UTF-8 explicitly instead of relying on the platform default
    # encoding, which is not UTF-8 everywhere.
    with input.open(encoding="utf-8") as f:
        for line in f:
            # Blank lines (e.g. a trailing newline at the end of the file)
            # carry no record; skip them instead of failing in json.loads.
            if not line.strip():
                continue
            row = json.loads(line)
            if key not in row:
                continue
            texts.append(row[key])
            logging.debug("Text: %s", row[key])
            rows.append(row)
    logging.info("Loaded %d rows of input data", len(rows))
    return texts, rows
102
+
103
+
104
def perform_predictions(texts, tokenizer, model, device, batch_size, top1):
    """Run ``predict`` over ``texts``, logging the start and the end.

    Args:
        texts: Texts to classify; its length is logged.
        tokenizer: Tokenizer passed through to ``predict``.
        model: Model passed through to ``predict``.
        device: Device passed through to ``predict``.
        batch_size: Forwarded as ``batch_size``.
        top1: When true, ``predict`` is called with
            ``return_all_scores=False``; otherwise with ``True``.

    Returns:
        Whatever ``predict`` returns.
    """
    logging.info("Starting predictions on %d texts", len(texts))
    # The CLI flag is phrased as "top-1 only", predict() as "all scores".
    want_all_scores = not top1
    results = predict(
        texts,
        tokenizer,
        model,
        device,
        batch_size=batch_size,
        return_all_scores=want_all_scores,
    )
    logging.info("Predictions completed")
    return results