| import torch |
| import torch.nn.functional as F |
| from transformers import AutoTokenizer, BertForSequenceClassification |
| from pathlib import Path |
|
|
# Location of the saved tokenizer/model files: the "model" directory that sits
# one level above the directory containing this file.
MODEL_PATH = Path(__file__).parent.parent / "model"


# Official goal titles, ordered so that position 0 is SDG 1 and position 16
# is SDG 17.
_SDG_TITLES = (
    "No Poverty",
    "Zero Hunger",
    "Good Health and Well-being",
    "Quality Education",
    "Gender Equality",
    "Clean Water and Sanitation",
    "Affordable and Clean Energy",
    "Decent Work and Economic Growth",
    "Industry, Innovation and Infrastructure",
    "Reduced Inequalities",
    "Sustainable Cities and Communities",
    "Responsible Consumption and Production",
    "Climate Action",
    "Life Below Water",
    "Life on Land",
    "Peace, Justice and Strong Institutions",
    "Partnerships for the Goals",
)

# Maps the display key "SDG <n>" (n = 1..17) to the goal's title.
SDG_METADATA = {
    f"SDG {number}": title
    for number, title in enumerate(_SDG_TITLES, start=1)
}
|
|
class SDGClassifier:
    """Rank the UN Sustainable Development Goals (SDGs) for a piece of text.

    Constructing an instance loads the tokenizer and the BERT
    sequence-classification model from ``MODEL_PATH``, moves the model to
    CUDA when it is available (CPU otherwise) and puts it in eval mode.
    """

    def __init__(self):
        """Load the tokenizer and model from ``MODEL_PATH`` onto the chosen device."""
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Loading model on {self.device}...")
        self.tokenizer = AutoTokenizer.from_pretrained(str(MODEL_PATH))
        self.model = BertForSequenceClassification.from_pretrained(str(MODEL_PATH))
        self.model.to(self.device)
        self.model.eval()
        print("Model loaded successfully!")

    def predict(self, text: str, top_k: int = 3) -> list:
        """Return the ``top_k`` most probable SDGs for ``text``.

        Args:
            text: The passage to classify. It is truncated to 128 tokens.
            top_k: How many predictions to return. Values larger than the
                number of classes produced by the model are clamped to that
                number; values below zero are treated as zero.

        Returns:
            A list of dicts ordered from most to least probable, each with
            ``"sdg"`` (e.g. ``"SDG 4"``), ``"name"`` (the title looked up in
            ``SDG_METADATA``) and ``"confidence"`` (softmax probability as a
            percentage rounded to two decimals).

        Raises:
            KeyError: If a predicted class index plus one has no matching
                ``"SDG <n>"`` entry in ``SDG_METADATA``.
        """
        encoded = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=128,
            padding=True,
        )
        encoded = {name: tensor.to(self.device) for name, tensor in encoded.items()}

        with torch.no_grad():
            logits = self.model(**encoded).logits

        # Remove only the batch axis. A bare squeeze() would also collapse the
        # class axis when the model has a single label, leaving a 0-d tensor
        # that cannot be iterated below.
        probs = F.softmax(logits, dim=-1).squeeze(0)

        # topk() raises if k is negative or exceeds the number of classes, so
        # clamp the request into the valid range instead of failing.
        k = max(0, min(top_k, probs.size(-1)))
        scores, class_ids = probs.topk(k)

        results = []
        for score, class_id in zip(scores.tolist(), class_ids.tolist()):
            # Class index i is reported as "SDG i+1".
            sdg_key = f"SDG {class_id + 1}"
            results.append({
                "sdg": sdg_key,
                "name": SDG_METADATA[sdg_key],
                "confidence": round(score * 100, 2),
            })

        return results
|
|
| |
# Shared module-level instance. It is created when this module is imported,
# which loads the tokenizer and model immediately.
classifier = SDGClassifier()