Simon Clematide
Refactor CLI prediction script to enhance argument parsing and modularize inference logic. Add Excel generation.
f9c9b95
| # sdg_predict/inference.py | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline | |
| import torch | |
| import logging | |
| import json | |
def load_model(model_name, device):
    """Load a sequence-classification checkpoint and its tokenizer.

    Returns the ``(tokenizer, model)`` pair with the model moved to
    *device* and switched to evaluation mode.
    """
    # Case-sensitive tokenization: the checkpoint was trained without lowercasing.
    tok = AutoTokenizer.from_pretrained(model_name, do_lower_case=False)
    clf = AutoModelForSequenceClassification.from_pretrained(model_name)
    clf = clf.to(device)
    clf.eval()  # disable dropout etc. for inference
    return tok, clf
def batched(iterable, batch_size):
    """Yield successive slices of *iterable*, each at most *batch_size* long."""
    start = 0
    total = len(iterable)
    while start < total:
        yield iterable[start : start + batch_size]
        start += batch_size
def predict(texts, tokenizer, model, device, batch_size=8, return_all_scores=True):
    """Classify *texts* and round all scores to 3 decimal places.

    With *return_all_scores* true the pipeline returns every label's score
    per text (``top_k=None``); otherwise only the single best label is kept.
    """
    clf = pipeline(
        "text-classification",
        model=model,
        tokenizer=tokenizer,
        device=device,
        batch_size=batch_size,
        truncation=True,
        padding=True,
        max_length=512,
        top_k=None if return_all_scores else 1,
    )
    outputs = clf(texts)
    # Round in place: keeps downstream JSON serialization compact.
    if return_all_scores:
        # One inner list of {label, score} dicts per input text.
        for per_text in outputs:
            for entry in per_text:
                entry["score"] = round(entry["score"], 3)
    else:
        # NOTE(review): assumes a flat dict per text in top-1 mode — the exact
        # shape depends on the installed transformers version; confirm.
        for entry in outputs:
            entry["score"] = round(entry["score"], 3)
    return outputs
def binary_from_softmax(prediction, cap_class0=0.5, num_labels=17):
    """Turn a multi-class softmax output into per-label one-vs-background scores.

    Each positive label's score is renormalized against the background class
    "0": ``score / (score + score_0)``, so a label that beats class 0 scores
    above 0.5.

    Args:
        prediction: list of ``{"label": str, "score": float}`` dicts, as
            produced by a text-classification pipeline with all scores kept.
        cap_class0: upper bound applied to the class-"0" score so that a
            dominant background class cannot fully suppress positive labels.
        num_labels: number of positive labels; output keys are "1".."num_labels"
            (default 17, matching the original hard-coded SDG label set).

    Returns:
        dict mapping each positive label to its binary score, rounded to
        3 decimals; labels absent from *prediction* stay at 0.0.
    """
    # Background score; 0.0 if class "0" is missing from the prediction.
    score_0 = next((x["score"] for x in prediction if x["label"] == "0"), 0.0)
    score_0 = min(score_0, cap_class0)
    # Initialize every positive label to 0.0 so the output shape is stable.
    binary_predictions = {str(i): 0.0 for i in range(1, num_labels + 1)}
    for entry in prediction:
        label = entry["label"]
        if label == "0":
            continue  # the background class itself gets no binary score
        score = entry["score"]
        denom = score + score_0
        # Guard against a 0/0 division when both scores are zero.
        binary_predictions[label] = round(score / denom, 3) if denom > 0 else 0.0
    return binary_predictions
def setup_device():
    """Select the best available torch device: MPS, then CUDA, then CPU."""
    logging.info("Setting up device")
    if torch.backends.mps.is_available():
        logging.info("Using MPS device")
        chosen = "mps"
    elif torch.cuda.is_available():
        logging.info("Using CUDA device")
        chosen = "cuda"
    else:
        logging.info("Using CPU device")
        chosen = "cpu"
    return torch.device(chosen)
def load_model_and_tokenizer(model_name, device):
    """Log-wrapped call to load_model(); returns the (tokenizer, model) pair."""
    logging.info("Loading model: %s", model_name)
    pair = load_model(model_name, device)
    logging.info("Model loaded successfully")
    return pair
def load_input_data(input, key):
    """Read a JSON-lines file, collecting the text stored under *key*.

    Args:
        input: a ``pathlib.Path`` (or any object with ``.open()``) pointing
            at a JSONL file, one JSON object per line.
            (Name shadows the ``input`` builtin; kept for keyword-call
            compatibility with existing callers.)
        key: field holding the text to classify; rows lacking it are skipped.

    Returns:
        ``(texts, rows)`` — parallel lists of the extracted texts and the
        full parsed rows they came from.
    """
    logging.info("Loading input data from %s", input)
    texts = []
    rows = []
    # Explicit UTF-8: JSONL is UTF-8 by convention and the platform default
    # encoding may differ (e.g. cp1252 on Windows).
    with input.open(encoding="utf-8") as f:
        for line in f:
            # Tolerate blank lines (trailing newline, hand-edited files) —
            # json.loads("") would raise otherwise.
            if not line.strip():
                continue
            row = json.loads(line)
            if key not in row:
                continue
            texts.append(row[key])
            logging.debug("Text: %s", row[key])
            rows.append(row)
    logging.info("Loaded %d rows of input data", len(rows))
    return texts, rows
def perform_predictions(texts, tokenizer, model, device, batch_size, top1):
    """Log-wrapped call to predict(); *top1* keeps only the best label per text."""
    logging.info("Starting predictions on %d texts", len(texts))
    results = predict(
        texts,
        tokenizer,
        model,
        device,
        batch_size=batch_size,
        return_all_scores=not top1,
    )
    logging.info("Predictions completed")
    return results