# Author: Simon Clematide
# Refactor CLI prediction script to enhance argument parsing and modularize
# inference logic. Add Excel generation. (commit f9c9b95)
# sdg_predict/cli_predict.py
import argparse
import json
from pathlib import Path
from typing import List, Dict, Union
from sdg_predict.inference import (
load_model_and_tokenizer,
load_input_data,
perform_predictions,
setup_device,
binary_from_softmax,
)
import logging
import pandas as pd
# Configure root logging once at import time. force=True replaces any
# handlers that libraries imported above may already have installed, so
# this format/level reliably takes effect.
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s", force=True
)
def parse_arguments() -> argparse.Namespace:
    """
    Parse command-line arguments for the batch-inference script.

    Returns:
        Parsed arguments as a Namespace with attributes: input, key,
        batch_size, model, top1, output, binarization, sdg0_cap_prob, excel.
    """
    parser = argparse.ArgumentParser(
        description="Batch inference using Hugging Face model."
    )
    # Positional and therefore required — the previous "(default: None)"
    # help text was misleading.
    parser.add_argument(
        "input", type=Path, help="Input JSONL file (required, one JSON object per line)"
    )
    parser.add_argument(
        "--key",
        type=str,
        default="text",
        help="JSON key with text input (default: 'text')",
    )
    parser.add_argument(
        "--batch_size", "-b", type=int, default=8, help="Batch size (default: 8)"
    )
    parser.add_argument(
        "--model",
        type=str,
        default="simon-clmtd/sdg-scibert-zo_up",
        help="Model name on the Hub (default: 'simon-clmtd/sdg-scibert-zo_up')",
    )
    parser.add_argument(
        "--top1",
        action="store_true",
        help="Return only top prediction (default: False)",
    )
    parser.add_argument(
        "--output",
        "-o",
        type=Path,
        help="Output file (default: None, otherwise stdout)",
    )
    # choices= makes argparse reject any other value before main() runs.
    parser.add_argument(
        "--binarization",
        type=str,
        choices=["one-vs-all", "one-vs-0"],
        default="one-vs-0",
        help="Binarization method: 'one-vs-all' or 'one-vs-0' (default: 'one-vs-0')",
    )
    parser.add_argument(
        "--sdg0-cap-prob",
        type=float,
        default=0.5,
        help=(
            "Maximum score allowed for class 0 in 'one-vs-0' binarization (default:"
            " 0.5)"
        ),
    )
    parser.add_argument(
        "--excel",
        "-e",
        type=Path,
        help="Path to the Excel file for binary predictions (optional)",
    )
    return parser.parse_args()
def main(
    input: Path,
    key: str,
    batch_size: int,
    model: str,
    top1: bool,
    output: Union[Path, None],
    binarization: str,
    sdg0_cap_prob: float,
    excel: Union[Path, None],
) -> None:
    """
    Run batch inference with a Hugging Face model and write the results.

    Args:
        input: Path to the input JSONL file.
        key: JSON key containing the text input.
        batch_size: Batch size for inference.
        model: Model name or path on the Hub.
        top1: Whether to return only the top prediction.
        output: Path to the output file (None writes to stdout).
        binarization: Binarization method ('one-vs-all' or 'one-vs-0').
        sdg0_cap_prob: Maximum score allowed for class 0 in 'one-vs-0' binarization.
        excel: Path to the Excel file for binary predictions (optional).
    """
    logging.info("Starting main function")
    device = setup_device()
    # Bind the loaded model to a distinct name instead of rebinding the
    # `model` parameter (a string) with the model object — the original
    # reassignment shadowed the argument and obscured the types in use.
    tokenizer, loaded_model = load_model_and_tokenizer(model, device)
    texts, rows = load_input_data(input, key)
    predictions = perform_predictions(
        texts, tokenizer, loaded_model, device, batch_size, top1
    )
    write_output(rows, predictions, output, binarization, sdg0_cap_prob, excel)
    logging.info("Main function completed")
def write_output(
    rows: List[Dict],
    predictions: List,
    output: Union[Path, None],
    binarization: str,
    sdg0_cap_prob: float,
    excel: Union[Path, None] = None,
) -> None:
    """
    Write predictions as JSONL to the output file or stdout, and optionally
    write the per-SDG binary predictions to an Excel file.

    Args:
        rows: List of input rows (dicts; 'id' and 'text' keys are echoed).
        predictions: List of per-row predictions, each a list of
            {'label': ..., 'score': ...} dicts.
        output: Path to the output file; None writes to stdout.
        binarization: Binarization method ('one-vs-all' or 'one-vs-0').
        sdg0_cap_prob: Maximum score allowed for class 0 in 'one-vs-0' binarization.
        excel: Path to the Excel file (optional).

    Raises:
        ValueError: If `binarization` is not a recognized method.
    """
    logging.info("Writing output to %s", output or "stdout")
    # Explicit UTF-8: the JSON is dumped with ensure_ascii=False, so the
    # platform-default encoding could otherwise fail on non-ASCII text.
    output_stream = output.open("w", encoding="utf-8") if output else None
    transformed_data = []
    try:
        for row, pred in zip(rows, predictions):
            if binarization == "one-vs-all":
                # One score per SDG class 1..17; 0 when the class is absent
                # from the prediction list.
                binary_predictions = {
                    str(label): round(
                        next(
                            (x["score"] for x in pred if int(x["label"]) == label), 0
                        ),
                        3,
                    )
                    for label in range(1, 18)
                }
            elif binarization == "one-vs-0":
                binary_predictions = binary_from_softmax(pred, sdg0_cap_prob)
            else:
                # The original left binary_predictions unbound here,
                # producing a NameError for any other value.
                raise ValueError(f"Unknown binarization method: {binarization!r}")
            output_row = {
                "id": row.get("id"),
                "text": row.get("text"),
                "prediction": pred,
                "binary_predictions": binary_predictions,
            }
            # Flat per-SDG columns for the optional Excel export.
            transformed_data.append(
                {
                    "publication_zora_id": row.get("id"),
                    **{
                        f"dvdblk_sdg{sdg}": binary_predictions.get(str(sdg), 0)
                        for sdg in range(1, 18)
                    },
                }
            )
            print(json.dumps(output_row, ensure_ascii=False), file=output_stream)
    finally:
        # Close even if a row fails mid-loop (the original leaked the
        # handle on exceptions).
        if output_stream is not None:
            output_stream.close()
    if output:
        logging.info("Output written to %s", output)
    if excel:
        logging.info("Writing Excel output to %s", excel)
        df_transformed = pd.DataFrame(transformed_data)
        df_transformed.to_excel(excel, index=False)
        logging.info("Excel output written to %s", excel)
    logging.info("Output writing completed")
if __name__ == "__main__":
    # The argparse dest names match main()'s keyword parameters one-to-one
    # (input, key, batch_size, model, top1, output, binarization,
    # sdg0_cap_prob, excel), so the namespace can be splatted directly.
    cli_args = parse_arguments()
    main(**vars(cli_args))