# Author: Simon Clematide
# Refactor CLI prediction script to enhance argument parsing and modularize
# inference logic. Add excel generation.
# Commit: f9c9b95
# sdg_predict/inference.py
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
import torch
import logging
import json
def load_model(model_name, device):
    """Load a sequence-classification checkpoint and its tokenizer onto *device*.

    Returns a ``(tokenizer, model)`` pair; the model is switched to eval mode
    since this module only performs inference.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name, do_lower_case=False)
    classifier = AutoModelForSequenceClassification.from_pretrained(model_name)
    classifier = classifier.to(device)
    classifier.eval()  # inference only: disables dropout and similar layers
    return tokenizer, classifier
def batched(iterable, batch_size):
    """Yield consecutive slices of *iterable* of length *batch_size*.

    The final slice may be shorter. Requires a sequence (supports ``len``
    and slicing), not an arbitrary iterator.
    """
    total = len(iterable)
    start = 0
    while start < total:
        yield iterable[start : start + batch_size]
        start += batch_size
def predict(texts, tokenizer, model, device, batch_size=8, return_all_scores=True):
    """Classify *texts* with a Hugging Face text-classification pipeline.

    When ``return_all_scores`` is True, every class score is returned per
    text (``top_k=None``); otherwise only the top-1 label is kept. All
    scores are rounded in place to 3 decimal places before returning.
    """
    classifier = pipeline(
        "text-classification",
        model=model,
        tokenizer=tokenizer,
        device=device,
        batch_size=batch_size,
        truncation=True,
        padding=True,
        max_length=512,
        top_k=None if return_all_scores else 1,
    )
    results = classifier(texts)
    # With all scores, `results` is a list of lists of {label, score} dicts;
    # with top-1 it is a flat list of dicts. Normalize to groups so one loop
    # handles both shapes, mutating the score dicts in place.
    groups = results if return_all_scores else (results,)
    for group in groups:
        for entry in group:
            entry["score"] = round(entry["score"], 3)  # 3-decimal rounding
    return results
def binary_from_softmax(prediction, cap_class0=0.5):
    """Turn multi-class softmax scores into per-label binary scores.

    For each non-"0" label the binary score is ``s / (s + s0)`` where *s0*
    is the score of the background class "0", capped at *cap_class0* so a
    dominant background cannot fully suppress the other labels. Labels
    "1".."17" are always present in the result (defaulting to 0.0).

    Args:
        prediction: List of ``{"label": str, "score": float}`` dicts.
        cap_class0: Upper bound applied to the class-"0" score.

    Returns:
        Dict mapping label strings to rounded (3 dp) binary scores.
    """
    # Locate the background ("0") score; absent means 0.0.
    class0_score = 0.0
    for entry in prediction:
        if entry["label"] == "0":
            class0_score = entry["score"]
            break
    class0_score = min(class0_score, cap_class0)

    # Start with every SDG label (1..17) at 0.0.
    binary_scores = {str(n): 0.0 for n in range(1, 18)}
    for entry in prediction:
        if entry["label"] == "0":
            continue
        s = entry["score"]
        denominator = s + class0_score
        value = s / denominator if denominator > 0 else 0.0
        binary_scores[entry["label"]] = round(value, 3)
    return binary_scores
def setup_device():
    """Select the fastest available torch device: MPS, then CUDA, then CPU.

    Checks are short-circuited: CUDA availability is only probed when MPS
    is absent.
    """
    logging.info("Setting up device")
    if torch.backends.mps.is_available():
        device_name = "mps"
        logging.info("Using MPS device")
    elif torch.cuda.is_available():
        device_name = "cuda"
        logging.info("Using CUDA device")
    else:
        device_name = "cpu"
        logging.info("Using CPU device")
    return torch.device(device_name)
def load_model_and_tokenizer(model_name, device):
    """Logging wrapper around :func:`load_model`.

    Returns the same ``(tokenizer, model)`` pair as ``load_model``.
    """
    logging.info("Loading model: %s", model_name)
    pair = load_model(model_name, device)
    logging.info("Model loaded successfully")
    return pair
def load_input_data(input, key):
    """Read a JSON-lines file and collect the rows that contain *key*.

    Args:
        input: ``pathlib.Path`` to a JSONL file (one JSON object per line).
            NOTE: the parameter name shadows the ``input`` builtin; kept for
            backward compatibility with existing callers.
        key: Field name whose value is the text to classify.

    Returns:
        ``(texts, rows)`` — parallel lists where ``texts[i] == rows[i][key]``.
        Rows missing *key* are skipped entirely.
    """
    logging.info("Loading input data from %s", input)
    texts = []
    rows = []
    # Explicit UTF-8 so decoding doesn't depend on the platform default.
    with input.open(encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue  # tolerate blank lines instead of crashing json.loads
            row = json.loads(line)
            if key not in row:
                continue
            texts.append(row[key])
            logging.debug("Text: %s", row[key])
            rows.append(row)
    logging.info("Loaded %d rows of input data", len(rows))
    return texts, rows
def perform_predictions(texts, tokenizer, model, device, batch_size, top1):
    """Logging wrapper around :func:`predict`.

    ``top1=True`` requests only the best label per text; otherwise all
    class scores are returned.
    """
    logging.info("Starting predictions on %d texts", len(texts))
    results = predict(
        texts,
        tokenizer,
        model,
        device,
        batch_size=batch_size,
        return_all_scores=not top1,
    )
    logging.info("Predictions completed")
    return results