File size: 2,248 Bytes
63f5626 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 | import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, BertForSequenceClassification
from pathlib import Path
# Location of the saved tokenizer/model files: the "model" directory that sits
# beside the directory containing this file (i.e. two levels up, then "model").
MODEL_PATH = Path(__file__).parent.parent.joinpath("model")
# Display title for each SDG label, keyed "SDG 1" through "SDG 17".
# The titles are listed in goal order and numbered from 1, so the position in
# the tuple determines the key. SDGClassifier.predict looks titles up with the
# key "SDG <class index + 1>".
SDG_METADATA = {
    f"SDG {number}": title
    for number, title in enumerate(
        (
            "No Poverty",
            "Zero Hunger",
            "Good Health and Well-being",
            "Quality Education",
            "Gender Equality",
            "Clean Water and Sanitation",
            "Affordable and Clean Energy",
            "Decent Work and Economic Growth",
            "Industry, Innovation and Infrastructure",
            "Reduced Inequalities",
            "Sustainable Cities and Communities",
            "Responsible Consumption and Production",
            "Climate Action",
            "Life Below Water",
            "Life on Land",
            "Peace, Justice and Strong Institutions",
            "Partnerships for the Goals",
        ),
        start=1,
    )
}
class SDGClassifier:
    """BERT sequence classifier that ranks SDG labels for a piece of text.

    The tokenizer and model are loaded in ``__init__``, moved to the GPU when
    CUDA is available (CPU otherwise) and put into eval mode.
    """

    def __init__(self, model_path: "str | Path | None" = None):
        """Load the tokenizer and the classification model.

        Args:
            model_path: Location passed to ``from_pretrained`` for both the
                tokenizer and the model. When omitted, the module-level
                ``MODEL_PATH`` is used.
        """
        # Resolved at call time (not as a default argument) so the class can
        # be defined before MODEL_PATH and always picks up its current value.
        source = str(MODEL_PATH if model_path is None else model_path)

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Loading model on {self.device}...")
        self.tokenizer = AutoTokenizer.from_pretrained(source)
        self.model = BertForSequenceClassification.from_pretrained(source)
        self.model.to(self.device)
        self.model.eval()
        print("Model loaded successfully!")

    def predict(self, text: str, top_k: int = 3) -> list:
        """Return the most probable SDG labels for ``text``.

        The text is tokenized with truncation at 128 tokens, run through the
        model without gradient tracking, and the logits are turned into
        probabilities with a softmax over the class dimension.

        Args:
            text: The text to classify.
            top_k: Maximum number of labels to return. Values larger than the
                number of classes the model outputs are capped at that number
                instead of raising.

        Returns:
            A list of dicts ordered from most to least probable, each with:
            ``"sdg"`` (label ``"SDG <class index + 1>"``), ``"name"`` (title
            from ``SDG_METADATA``) and ``"confidence"`` (probability as a
            percentage, rounded to two decimals).

        Raises:
            KeyError: If a predicted label has no entry in ``SDG_METADATA``.
        """
        encoded = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=128,
            padding=True,
        )
        encoded = {name: tensor.to(self.device) for name, tensor in encoded.items()}

        with torch.no_grad():
            logits = self.model(**encoded).logits

        # Select the single batch row explicitly. A bare squeeze() would also
        # drop the class dimension when the model has exactly one class,
        # leaving a 0-d tensor that cannot be iterated.
        probs = F.softmax(logits, dim=-1)[0]

        # topk raises if k exceeds the size of the dimension, so cap it.
        k = min(top_k, probs.shape[-1])
        top = probs.topk(k)

        # One tolist() per tensor instead of an .item() call per element.
        results = []
        for class_index, prob in zip(top.indices.tolist(), top.values.tolist()):
            sdg_key = f"SDG {class_index + 1}"
            results.append({
                "sdg": sdg_key,
                "name": SDG_METADATA[sdg_key],
                "confidence": round(prob * 100, 2),
            })
        return results
# Singleton — constructed once, when this module is first imported (i.e. when
# the app starts). Constructing it loads the tokenizer and model, so reuse this
# instance rather than creating new SDGClassifier objects.
classifier = SDGClassifier()