# sdg_predict/cli_predict.py
"""Command-line entry point for batch SDG prediction with a Hugging Face model."""

import argparse
import json
import logging
import sys
from contextlib import nullcontext
from pathlib import Path
from typing import Dict, List, Union

import pandas as pd

from sdg_predict.inference import (
    binary_from_softmax,
    load_input_data,
    load_model_and_tokenizer,
    perform_predictions,
    setup_device,
)

# Set up logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s", force=True
)

# SDG labels 1..17 that receive a key / column in the binarized output.
_SDG_IDS = range(1, 18)

# Supported binarization strategies, shared by the CLI and write_output().
_BINARIZATION_METHODS = ("one-vs-all", "one-vs-0")

# Fixed column order for the optional Excel export, so that an empty input
# still yields a sheet with a header row.
_EXCEL_COLUMNS = ["publication_zora_id"] + [f"dvdblk_sdg{sdg}" for sdg in _SDG_IDS]


def parse_arguments() -> argparse.Namespace:
    """
    Parse command-line arguments for the script.

    Returns:
        Parsed arguments as a Namespace object.
    """
    parser = argparse.ArgumentParser(
        description="Batch inference using Hugging Face model."
    )
    parser.add_argument("input", type=Path, help="Input JSONL file")
    parser.add_argument(
        "--key",
        type=str,
        default="text",
        help="JSON key with text input (default: 'text')",
    )
    parser.add_argument(
        "--batch_size", "-b", type=int, default=8, help="Batch size (default: 8)"
    )
    parser.add_argument(
        "--model",
        type=str,
        default="simon-clmtd/sdg-scibert-zo_up",
        help="Model name on the Hub (default: 'simon-clmtd/sdg-scibert-zo_up')",
    )
    parser.add_argument(
        "--top1",
        action="store_true",
        help="Return only top prediction (default: False)",
    )
    parser.add_argument(
        "--output",
        "-o",
        type=Path,
        help="Output file (default: None, otherwise stdout)",
    )
    parser.add_argument(
        "--binarization",
        type=str,
        choices=_BINARIZATION_METHODS,
        default="one-vs-0",
        help="Binarization method: 'one-vs-all' or 'one-vs-0' (default: 'one-vs-0')",
    )
    parser.add_argument(
        "--sdg0-cap-prob",
        type=float,
        default=0.5,
        help=(
            "Maximum score allowed for class 0 in 'one-vs-0' binarization (default:"
            " 0.5)"
        ),
    )
    parser.add_argument(
        "--excel",
        "-e",
        type=Path,
        help="Path to the Excel file for binary predictions (optional)",
    )
    return parser.parse_args()


def main(
    input: Path,
    key: str,
    batch_size: int,
    model: str,
    top1: bool,
    output: Union[Path, None],
    binarization: str,
    sdg0_cap_prob: float,
    excel: Union[Path, None],
) -> None:
    """
    Main function to perform batch inference using a Hugging Face model.

    Args:
        input: Path to the input JSONL file.
        key: JSON key containing the text input.
        batch_size: Batch size for inference.
        model: Model name or path.
        top1: Whether to return only the top prediction.
        output: Path to the output file (optional; stdout when None).
        binarization: Binarization method ('one-vs-all' or 'one-vs-0').
        sdg0_cap_prob: Maximum score allowed for class 0 in 'one-vs-0' binarization.
        excel: Path to the Excel file for binary predictions (optional).
    """
    logging.info("Starting main function")
    device = setup_device()
    # Keep the model *name* (the ``model`` argument) separate from the loaded
    # model object instead of rebinding the parameter to a different type.
    tokenizer, hf_model = load_model_and_tokenizer(model, device)
    texts, rows = load_input_data(input, key)
    predictions = perform_predictions(
        texts, tokenizer, hf_model, device, batch_size, top1
    )
    write_output(
        rows, predictions, output, binarization, sdg0_cap_prob, excel, key=key
    )
    logging.info("Main function completed")


def _binarize(pred: List[Dict], binarization: str, sdg0_cap_prob: float) -> Dict:
    """
    Turn one row's prediction into per-SDG scores.

    For 'one-vs-all', each SDG 1-17 maps (as a string key) to the score of the
    matching label in ``pred``, rounded to 3 decimals, or 0 if absent. For
    'one-vs-0', the work is delegated to ``binary_from_softmax``.

    Args:
        pred: Prediction for a single row; for 'one-vs-all' an iterable of
            dicts with ``label`` and ``score`` entries.
        binarization: A value from ``_BINARIZATION_METHODS`` (validated by caller).
        sdg0_cap_prob: Cap for class 0, forwarded to ``binary_from_softmax``.

    Returns:
        Mapping from SDG identifier to score.
    """
    if binarization == "one-vs-all":
        return {
            str(label): round(
                next((x["score"] for x in pred if int(x["label"]) == label), 0), 3
            )
            for label in _SDG_IDS
        }
    # "one-vs-0"; presumably returns string SDG keys like the branch above,
    # as the Excel lookup in _excel_row assumes -- TODO confirm.
    return binary_from_softmax(pred, sdg0_cap_prob)


def _excel_row(row_id, binary_predictions: Dict) -> Dict:
    """Build one Excel row: the publication id plus one column per SDG 1-17."""
    return {
        "publication_zora_id": row_id,
        **{
            f"dvdblk_sdg{sdg}": binary_predictions.get(str(sdg), 0)
            for sdg in _SDG_IDS
        },
    }


def write_output(
    rows: List[Dict],
    predictions: List,
    output: Union[Path, None],
    binarization: str,
    sdg0_cap_prob: float,
    excel: Union[Path, None] = None,
    key: str = "text",
) -> None:
    """
    Write the predictions as JSON lines to a file or stdout, and optionally to Excel.

    Each JSON line holds the row's ``id``, its input text (under ``text``), the
    raw prediction and the binarized per-SDG scores. When ``excel`` is given,
    one spreadsheet row per input is written with a ``publication_zora_id``
    column and one ``dvdblk_sdg<N>`` column for each SDG 1-17.

    Args:
        rows: List of input rows (dicts; ``id`` and ``key`` are read from each).
        predictions: List of predictions, one per row, in the same order.
        output: Path to the output file (optional; stdout when None).
        binarization: Binarization method ('one-vs-all' or 'one-vs-0').
        sdg0_cap_prob: Maximum score allowed for class 0 in 'one-vs-0' binarization.
        excel: Path to the Excel file (optional).
        key: Row key holding the input text, copied to the output's ``text``
            field (default: 'text').

    Raises:
        ValueError: If ``binarization`` is not a supported method.
    """
    # Validate before opening ``output`` so a bad call does not truncate the file.
    if binarization not in _BINARIZATION_METHODS:
        raise ValueError(
            f"Unknown binarization {binarization!r}; "
            f"expected one of {_BINARIZATION_METHODS}"
        )
    if len(rows) != len(predictions):
        # zip() below stops at the shorter sequence; make that visible.
        logging.warning(
            "Got %d input rows but %d predictions; unmatched items are skipped",
            len(rows),
            len(predictions),
        )

    logging.info("Writing output to %s", output or "stdout")
    transformed_data = []
    # json.dumps(ensure_ascii=False) emits raw non-ASCII text, so write UTF-8
    # explicitly instead of relying on the platform's locale encoding. The
    # context manager guarantees the file is closed even if writing fails.
    stream_cm = (
        output.open("w", encoding="utf-8") if output else nullcontext(sys.stdout)
    )
    with stream_cm as output_stream:
        for row, pred in zip(rows, predictions):
            binary_predictions = _binarize(pred, binarization, sdg0_cap_prob)
            output_row = {
                "id": row.get("id"),
                "text": row.get(key),
                "prediction": pred,
                "binary_predictions": binary_predictions,
            }
            print(json.dumps(output_row, ensure_ascii=False), file=output_stream)
            if excel:
                transformed_data.append(_excel_row(row.get("id"), binary_predictions))
    if output:
        logging.info("Output written to %s", output)

    if excel:
        logging.info("Writing Excel output to %s", excel)
        df_transformed = pd.DataFrame(transformed_data, columns=_EXCEL_COLUMNS)
        df_transformed.to_excel(excel, index=False)
        logging.info("Excel output written to %s", excel)
    logging.info("Output writing completed")


if __name__ == "__main__":
    args = parse_arguments()
    main(
        input=args.input,
        key=args.key,
        batch_size=args.batch_size,
        model=args.model,
        top1=args.top1,
        output=args.output,
        binarization=args.binarization,
        sdg0_cap_prob=args.sdg0_cap_prob,
        excel=args.excel,
    )