Simon Clematide
Refactor CLI prediction script to enhance argument parsing and modularize inference logic. Add Excel output generation.
f9c9b95
| # sdg_predict/cli_predict.py | |
| import argparse | |
| import json | |
| from pathlib import Path | |
| from typing import List, Dict, Union | |
| from sdg_predict.inference import ( | |
| load_model_and_tokenizer, | |
| load_input_data, | |
| perform_predictions, | |
| setup_device, | |
| binary_from_softmax, | |
| ) | |
| import logging | |
| import pandas as pd | |
# Configure root logging once at import time; force=True removes any handlers
# installed earlier (e.g. by imported libraries) so this format takes effect.
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s", force=True
)
def parse_arguments() -> argparse.Namespace:
    """
    Parse command-line arguments for the batch-inference CLI.

    Returns:
        Parsed arguments as a Namespace object.
    """
    parser = argparse.ArgumentParser(
        description="Batch inference using Hugging Face model."
    )
    # Positional arguments are required; the previous help text
    # "(default: None)" was misleading and has been removed.
    parser.add_argument("input", type=Path, help="Input JSONL file")
    parser.add_argument(
        "--key",
        type=str,
        default="text",
        help="JSON key with text input (default: 'text')",
    )
    parser.add_argument(
        "--batch_size", "-b", type=int, default=8, help="Batch size (default: 8)"
    )
    parser.add_argument(
        "--model",
        type=str,
        default="simon-clmtd/sdg-scibert-zo_up",
        help="Model name on the Hub (default: 'simon-clmtd/sdg-scibert-zo_up')",
    )
    parser.add_argument(
        "--top1",
        action="store_true",
        help="Return only top prediction (default: False)",
    )
    parser.add_argument(
        "--output",
        "-o",
        type=Path,
        help="Output file (default: None, otherwise stdout)",
    )
    parser.add_argument(
        "--binarization",
        type=str,
        choices=["one-vs-all", "one-vs-0"],
        default="one-vs-0",
        help="Binarization method: 'one-vs-all' or 'one-vs-0' (default: 'one-vs-0')",
    )
    parser.add_argument(
        "--sdg0-cap-prob",
        type=float,
        default=0.5,
        help=(
            "Maximum score allowed for class 0 in 'one-vs-0' binarization (default:"
            " 0.5)"
        ),
    )
    parser.add_argument(
        "--excel",
        "-e",
        type=Path,
        help="Path to the Excel file for binary predictions (optional)",
    )
    return parser.parse_args()
def main(
    input: Path,
    key: str,
    batch_size: int,
    model: str,
    top1: bool,
    output: Union[Path, None],
    binarization: str,
    sdg0_cap_prob: float,
    excel: Union[Path, None],
) -> None:
    """
    Main function to perform batch inference using a Hugging Face model.

    Args:
        input: Path to the input JSONL file.
        key: JSON key containing the text input.
        batch_size: Batch size for inference.
        model: Model name or path on the Hugging Face Hub.
        top1: Whether to return only the top prediction.
        output: Path to the output file (optional; stdout when None).
        binarization: Binarization method ('one-vs-all' or 'one-vs-0').
        sdg0_cap_prob: Maximum score allowed for class 0 in 'one-vs-0' binarization.
        excel: Path to the Excel file for binary predictions (optional).
    """
    logging.info("Starting main function")
    device = setup_device()
    # Bind the loaded model object to a distinct name instead of rebinding the
    # `model` parameter (the original shadowed the `model: str` argument).
    tokenizer, loaded_model = load_model_and_tokenizer(model, device)
    texts, rows = load_input_data(input, key)
    predictions = perform_predictions(
        texts, tokenizer, loaded_model, device, batch_size, top1
    )
    write_output(rows, predictions, output, binarization, sdg0_cap_prob, excel)
    logging.info("Main function completed")
def write_output(
    rows: List[Dict],
    predictions: List,
    output: Union[Path, None],
    binarization: str,
    sdg0_cap_prob: float,
    excel: Union[Path, None] = None,
) -> None:
    """
    Write the predictions to the output file or stdout, and optionally to an Excel file.

    Args:
        rows: List of input rows.
        predictions: List of predictions (parallel to ``rows``).
        output: Path to the output file (optional; stdout when None).
        binarization: Binarization method ('one-vs-all' or 'one-vs-0').
        sdg0_cap_prob: Maximum score allowed for class 0 in 'one-vs-0' binarization.
        excel: Path to the Excel file (optional).

    Raises:
        ValueError: If ``binarization`` is not a recognized method.
    """
    logging.info("Writing output to %s", output or "stdout")
    # Explicit encoding so output does not depend on the platform default.
    output_stream = output.open("w", encoding="utf-8") if output else None
    transformed_data = []
    try:
        for row, pred in zip(rows, predictions):
            if binarization == "one-vs-all":
                # Per-label score lookup; labels absent from `pred` score 0.
                binary_predictions = {
                    str(label): round(
                        next((x["score"] for x in pred if int(x["label"]) == label), 0), 3
                    )
                    for label in range(1, 18)
                }
            elif binarization == "one-vs-0":
                binary_predictions = binary_from_softmax(pred, sdg0_cap_prob)
            else:
                # Previously an unknown method raised an obscure NameError below.
                raise ValueError(f"Unknown binarization method: {binarization!r}")
            output_row = {
                "id": row.get("id"),
                "text": row.get("text"),
                "prediction": pred,
                "binary_predictions": binary_predictions,
            }
            transformed_data.append(
                {
                    "publication_zora_id": row.get("id"),
                    **{
                        f"dvdblk_sdg{sdg}": binary_predictions.get(str(sdg), 0)
                        for sdg in range(1, 18)
                    },
                }
            )
            print(json.dumps(output_row, ensure_ascii=False), file=output_stream)
    finally:
        # Close the file even if an exception interrupts the loop
        # (the original leaked the handle on error).
        if output_stream:
            output_stream.close()
    if output:
        logging.info("Output written to %s", output)
    if excel:
        logging.info("Writing Excel output to %s", excel)
        df_transformed = pd.DataFrame(transformed_data)
        df_transformed.to_excel(excel, index=False)
        logging.info("Excel output written to %s", excel)
    logging.info("Output writing completed")
if __name__ == "__main__":
    # The argparse dests (input, key, batch_size, model, top1, output,
    # binarization, sdg0_cap_prob, excel) match main()'s keyword parameters
    # one-for-one, so the namespace can be expanded directly.
    main(**vars(parse_arguments()))