# sdg_predict/inference.py
"""Inference helpers for SDG (Sustainable Development Goals) text classification.

Loads a Hugging Face sequence-classification model, runs batched predictions
over JSONL input, and converts multi-class softmax scores into per-label
binary probabilities for SDG labels "1"–"17" against the "no SDG" class "0".
"""
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
import torch
import logging
import json


def load_model(model_name, device):
    """Load tokenizer and sequence-classification model onto *device*.

    Args:
        model_name: Hugging Face model id or local checkpoint path.
        device: torch.device the model weights are moved to.

    Returns:
        (tokenizer, model) tuple, with the model in eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name, do_lower_case=False)
    model = AutoModelForSequenceClassification.from_pretrained(model_name).to(device)
    model.eval()  # inference mode: disables dropout / batch-norm updates
    return tokenizer, model


def batched(iterable, batch_size):
    """Yield successive slices of *iterable* with at most *batch_size* items.

    The final slice may be shorter. *iterable* must support len() and slicing.
    """
    for start in range(0, len(iterable), batch_size):
        yield iterable[start : start + batch_size]


def predict(texts, tokenizer, model, device, batch_size=8, return_all_scores=True):
    """Classify *texts* and return scores rounded to 3 decimal places.

    Args:
        texts: list of input strings.
        tokenizer / model: as returned by load_model().
        device: torch.device used by the pipeline.
        batch_size: pipeline batch size.
        return_all_scores: if True, every class score is returned per text;
            if False, only the top-1 prediction per text.

    Returns:
        If return_all_scores: a list (one entry per text) of lists of
        {"label", "score"} dicts. Otherwise one top-1 prediction per text.
    """
    classifier = pipeline(
        "text-classification",
        model=model,
        tokenizer=tokenizer,
        device=device,
        batch_size=batch_size,
        truncation=True,
        padding=True,
        max_length=512,
        # top_k=None means "all classes"; top_k=1 keeps only the best one.
        top_k=None if return_all_scores else 1,
    )
    results = classifier(texts)
    if return_all_scores:
        # Each result is a list of {label, score} dicts, one per class.
        for result in results:
            for entry in result:
                entry["score"] = round(entry["score"], 3)
    else:
        for result in results:
            # NOTE(review): with top_k=1, recent transformers versions return a
            # one-element list per input instead of a bare dict; unwrap so the
            # in-place rounding below works in either case — confirm against
            # the pinned transformers version.
            if isinstance(result, list):
                result = result[0]
            result["score"] = round(result["score"], 3)
    return results


def binary_from_softmax(prediction, cap_class0=0.5):
    """Convert one multi-class softmax prediction to per-SDG binary scores.

    For each SDG label "1"–"17" the binary score is the label's softmax score
    renormalized against the "no SDG" class "0":
    score / (score + score_0), where score_0 is capped at *cap_class0* so a
    dominant class 0 cannot suppress every label.

    Args:
        prediction: list of {"label", "score"} dicts for one text.
        cap_class0: upper bound applied to the class-"0" score.

    Returns:
        dict mapping each label "1"–"17" to a score rounded to 3 decimals;
        labels absent from *prediction* stay at 0.0.
    """
    score_0 = next((x["score"] for x in prediction if x["label"] == "0"), 0.0)
    score_0 = min(score_0, cap_class0)

    # Initialize every SDG label so the output schema is stable.
    binary_predictions = {label: 0.0 for label in map(str, range(1, 18))}
    for entry in prediction:
        label = entry["label"]
        if label == "0":
            continue  # class "0" is the reference, not an output label
        score = entry["score"]
        denominator = score + score_0
        if denominator > 0:
            binary_predictions[label] = round(score / denominator, 3)
    return binary_predictions


def setup_device():
    """Pick the best available torch device: MPS > CUDA > CPU."""
    logging.info("Setting up device")
    if torch.backends.mps.is_available():
        logging.info("Using MPS device")
        return torch.device("mps")
    if torch.cuda.is_available():
        logging.info("Using CUDA device")
        return torch.device("cuda")
    logging.info("Using CPU device")
    return torch.device("cpu")


def load_model_and_tokenizer(model_name, device):
    """Wrap load_model() with progress logging.

    Returns the same (tokenizer, model) tuple as load_model().
    """
    logging.info("Loading model: %s", model_name)
    tokenizer, model = load_model(model_name, device)
    logging.info("Model loaded successfully")
    return tokenizer, model


def load_input_data(input, key):
    """Read a JSONL file and collect the texts stored under *key*.

    Args:
        input: path-like object with an .open() method (e.g. pathlib.Path).
            (Parameter name kept for backward compatibility even though it
            shadows the builtin.)
        key: field name holding the text in each JSON row; rows missing it
            are skipped entirely.

    Returns:
        (texts, rows): parallel lists of the extracted texts and the full
        parsed rows they came from.
    """
    logging.info("Loading input data from %s", input)
    texts = []
    rows = []
    with input.open() as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank lines / trailing newline at EOF
            row = json.loads(line)
            if key not in row:
                continue
            texts.append(row[key])
            logging.debug("Text: %s", row[key])
            rows.append(row)
    logging.info("Loaded %d rows of input data", len(rows))
    return texts, rows


def perform_predictions(texts, tokenizer, model, device, batch_size, top1):
    """Run predict() over *texts* with progress logging.

    Args:
        top1: if True, request only the top-1 label per text
            (i.e. return_all_scores=False).

    Returns:
        The prediction structure produced by predict().
    """
    logging.info("Starting predictions on %d texts", len(texts))
    predictions = predict(
        texts,
        tokenizer,
        model,
        device,
        batch_size=batch_size,
        return_all_scores=not top1,
    )
    logging.info("Predictions completed")
    return predictions