MakPr016 committed on
Commit
63f5626
·
0 Parent(s):
.gitattributes ADDED
@@ -0,0 +1 @@
 
 
1
+ model/*.safetensors filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ sdg/
Dockerfile ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
FROM python:3.10-slim

WORKDIR /app

# Unbuffered stdout/stderr so application logs show up in `docker logs`
# in real time; skip .pyc generation to keep layers slim.
ENV PYTHONUNBUFFERED=1 \
    PYTHONDONTWRITEBYTECODE=1

# gcc is needed to build any wheels that ship without prebuilt binaries.
# --no-install-recommends and clearing the apt cache in the same layer
# keep the image small.
RUN apt-get update && apt-get install -y --no-install-recommends \
    gcc \
    && rm -rf /var/lib/apt/lists/*

# Install dependencies before copying the code so this expensive layer
# is cached across source-only changes.
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

COPY app/ ./app/
COPY model/ ./model/
COPY run.py .

# 7860 matches the port run.py binds (Hugging Face Spaces convention).
EXPOSE 7860

CMD ["python", "run.py"]
app/__init__.py ADDED
File without changes
app/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (149 Bytes). View file
 
app/__pycache__/limiter.cpython-310.pyc ADDED
Binary file (275 Bytes). View file
 
app/__pycache__/main.cpython-310.pyc ADDED
Binary file (3.24 kB). View file
 
app/__pycache__/model.cpython-310.pyc ADDED
Binary file (2.66 kB). View file
 
app/limiter.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
"""Shared slowapi rate limiter for the API, keyed by client IP address."""
from slowapi import Limiter
from slowapi.util import get_remote_address

# Single Limiter instance imported by app.main and attached to app.state.
limiter = Limiter(key_func=get_remote_address)
app/main.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from fastapi import FastAPI, Request, HTTPException
from fastapi.middleware.cors import CORSMiddleware
# NOTE(review): JSONResponse appears unused in this module — confirm before removing.
from fastapi.responses import JSONResponse
from slowapi import _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from pydantic import BaseModel, field_validator
from app.limiter import limiter
from app.model import classifier
import time

# FastAPI application; the metadata below is surfaced in the generated
# OpenAPI docs (/docs).
app = FastAPI(
    title="SDG Classifier API",
    description="Classifies text into UN Sustainable Development Goals",
    version="1.0.0"
)

# Open CORS policy: any origin, method, and header may call the API.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Attach the shared IP-keyed limiter and convert rate-limit overruns
# into HTTP 429 responses via slowapi's stock handler.
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
26
+
27
+
28
class ClassifyRequest(BaseModel):
    """Request body for /classify.

    Attributes:
        text: the passage to classify; must be non-empty and at most
            2000 characters after surrounding whitespace is trimmed.
        top_k: number of SDG predictions to return (1-5).
    """

    text: str
    top_k: int = 3

    @field_validator("text")
    @classmethod
    def text_must_not_be_empty(cls, v):
        # Strip first, then validate: the original checked length on the
        # raw string but returned the stripped one, so whitespace padding
        # counted against the 2000-character limit.
        v = v.strip()
        if not v:
            raise ValueError("text must not be empty")
        if len(v) > 2000:
            raise ValueError("text must be under 2000 characters")
        return v

    @field_validator("top_k")
    @classmethod
    def top_k_must_be_valid(cls, v):
        # The API caps responses at 5 predictions per request.
        if not 1 <= v <= 5:
            raise ValueError("top_k must be between 1 and 5")
        return v
47
+
48
+
49
class SDGResult(BaseModel):
    """One SDG prediction: label key, human-readable name, confidence in %."""

    sdg: str
    name: str
    confidence: float
53
+
54
+
55
class ClassifyResponse(BaseModel):
    """Response body for /classify."""

    text: str                      # validated input text echoed back
    predictions: list[SDGResult]   # top-k predictions, best-first
    latency_ms: float              # inference wall-clock time in milliseconds
    warning: str | None = None     # set when the prediction distribution looks degenerate
60
+
61
+
62
@app.get("/")
def root():
    """Liveness banner for the API root."""
    payload = {"status": "ok", "message": "SDG Classifier API is running"}
    return payload
65
+
66
+
67
@app.get("/health")
def health():
    """Report service health."""
    return dict(status="healthy")
70
+
71
+
72
@app.post("/classify", response_model=ClassifyResponse, summary="Classify text into SDGs")
@limiter.limit("20/minute")
async def classify(request: Request, body: ClassifyRequest):
    """Classify `body.text` into its top-k UN SDGs.

    Rate-limited to 20 requests/minute per client IP. Returns the
    predictions plus the measured inference latency; attaches a warning
    when the prediction distribution looks degenerate.
    """
    # perf_counter is monotonic — interval timing cannot go negative if
    # the wall clock is adjusted mid-request (time.time can).
    start = time.perf_counter()

    try:
        predictions = classifier.predict(body.text, body.top_k)
    except Exception as e:
        # Surface inference failures as a 500 with the original cause chained.
        raise HTTPException(status_code=500, detail=f"Inference error: {str(e)}") from e

    latency = round((time.perf_counter() - start) * 1000, 2)

    warning = None
    # Guard on length: with top_k=1 there is only one prediction and the
    # original predictions[1] access raised IndexError.
    if (
        len(predictions) > 1
        and predictions[0]["confidence"] > 85
        and predictions[1]["confidence"] < 5
    ):
        warning = "Low prediction diversity — input may not be SDG-related text."

    return ClassifyResponse(
        text=body.text,
        predictions=predictions,
        latency_ms=latency,
        warning=warning
    )
app/model.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from transformers import AutoTokenizer, BertForSequenceClassification
4
+ from pathlib import Path
5
+
6
# Checkpoint directory: <repo root>/model, alongside the app/ package.
MODEL_PATH = Path(__file__).parent.parent / "model"

# Official short names of the 17 UN Sustainable Development Goals, in
# goal order; keyed below by the "SDG <n>" label strings the classifier
# head emits.
_SDG_NAMES = (
    "No Poverty",
    "Zero Hunger",
    "Good Health and Well-being",
    "Quality Education",
    "Gender Equality",
    "Clean Water and Sanitation",
    "Affordable and Clean Energy",
    "Decent Work and Economic Growth",
    "Industry, Innovation and Infrastructure",
    "Reduced Inequalities",
    "Sustainable Cities and Communities",
    "Responsible Consumption and Production",
    "Climate Action",
    "Life Below Water",
    "Life on Land",
    "Peace, Justice and Strong Institutions",
    "Partnerships for the Goals",
)
SDG_METADATA = {f"SDG {i}": name for i, name in enumerate(_SDG_NAMES, start=1)}
27
+
28
class SDGClassifier:
    """Inference wrapper around the fine-tuned BERT checkpoint in MODEL_PATH."""

    def __init__(self):
        # Prefer GPU when available; model and inputs both move to this device.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Loading model on {self.device}...")
        self.tokenizer = AutoTokenizer.from_pretrained(str(MODEL_PATH))
        self.model = BertForSequenceClassification.from_pretrained(str(MODEL_PATH))
        self.model.to(self.device)
        self.model.eval()  # disable dropout for deterministic inference
        print("Model loaded successfully!")

    def predict(self, text: str, top_k: int = 3) -> list:
        """Return the top_k SDG predictions for `text`, best-first.

        Each entry is a dict with keys "sdg" (label string), "name"
        (human-readable goal name), and "confidence" (percentage rounded
        to two decimals).
        """
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=128,  # API texts are short; 128 tokens keeps latency low
            padding=True
        )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            logits = self.model(**inputs).logits

        # squeeze(0) drops exactly the batch dimension -> shape (num_labels,).
        probs = F.softmax(logits, dim=-1).squeeze(0)
        # Clamp k so a top_k larger than the label count cannot make
        # topk raise a RuntimeError.
        k = max(1, min(top_k, probs.size(-1)))
        top = probs.topk(k)

        results = []
        for idx, score in zip(top.indices, top.values):
            # Label index i maps to "SDG i+1" (matches id2label in the
            # checkpoint's config.json).
            sdg_key = f"SDG {idx.item() + 1}"
            results.append({
                "sdg": sdg_key,
                "name": SDG_METADATA[sdg_key],
                "confidence": round(score.item() * 100, 2)
            })

        return results

# Singleton — loaded once when the app starts (at import of app.model).
classifier = SDGClassifier()
model/config.json ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_cross_attention": false,
3
+ "architectures": [
4
+ "BertForSequenceClassification"
5
+ ],
6
+ "attention_probs_dropout_prob": 0.1,
7
+ "bos_token_id": null,
8
+ "classifier_dropout": null,
9
+ "dtype": "float32",
10
+ "eos_token_id": null,
11
+ "gradient_checkpointing": false,
12
+ "hidden_act": "gelu",
13
+ "hidden_dropout_prob": 0.1,
14
+ "hidden_size": 768,
15
+ "id2label": {
16
+ "0": "SDG 1",
17
+ "1": "SDG 2",
18
+ "2": "SDG 3",
19
+ "3": "SDG 4",
20
+ "4": "SDG 5",
21
+ "5": "SDG 6",
22
+ "6": "SDG 7",
23
+ "7": "SDG 8",
24
+ "8": "SDG 9",
25
+ "9": "SDG 10",
26
+ "10": "SDG 11",
27
+ "11": "SDG 12",
28
+ "12": "SDG 13",
29
+ "13": "SDG 14",
30
+ "14": "SDG 15",
31
+ "15": "SDG 16",
32
+ "16": "SDG 17"
33
+ },
34
+ "initializer_range": 0.02,
35
+ "intermediate_size": 3072,
36
+ "is_decoder": false,
37
+ "label2id": {
38
+ "SDG 1": 0,
39
+ "SDG 10": 9,
40
+ "SDG 11": 10,
41
+ "SDG 12": 11,
42
+ "SDG 13": 12,
43
+ "SDG 14": 13,
44
+ "SDG 15": 14,
45
+ "SDG 16": 15,
46
+ "SDG 17": 16,
47
+ "SDG 2": 1,
48
+ "SDG 3": 2,
49
+ "SDG 4": 3,
50
+ "SDG 5": 4,
51
+ "SDG 6": 5,
52
+ "SDG 7": 6,
53
+ "SDG 8": 7,
54
+ "SDG 9": 8
55
+ },
56
+ "layer_norm_eps": 1e-12,
57
+ "max_position_embeddings": 512,
58
+ "model_type": "bert",
59
+ "num_attention_heads": 12,
60
+ "num_hidden_layers": 12,
61
+ "pad_token_id": 0,
62
+ "position_embedding_type": "absolute",
63
+ "problem_type": "single_label_classification",
64
+ "tie_word_embeddings": true,
65
+ "transformers_version": "5.0.0",
66
+ "type_vocab_size": 2,
67
+ "use_cache": false,
68
+ "vocab_size": 30522
69
+ }
model/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:65bfdd2b0083217dee9ebd9861cea316d212c88c0579a20aef56906a323948a9
3
+ size 438004764
model/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
model/tokenizer_config.json ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "backend": "tokenizers",
3
+ "cls_token": "[CLS]",
4
+ "do_lower_case": true,
5
+ "is_local": false,
6
+ "mask_token": "[MASK]",
7
+ "model_max_length": 512,
8
+ "pad_token": "[PAD]",
9
+ "sep_token": "[SEP]",
10
+ "strip_accents": null,
11
+ "tokenize_chinese_chars": true,
12
+ "tokenizer_class": "BertTokenizer",
13
+ "unk_token": "[UNK]"
14
+ }
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ fastapi==0.115.0
2
+ uvicorn==0.30.0
3
+ transformers==4.47.0  # NOTE(review): model/config.json records transformers_version 5.0.0 — confirm this checkpoint loads under 4.47.0
4
+ torch==2.5.1
5
+ slowapi==0.1.9
6
+ python-dotenv==1.0.0
run.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
"""Entry point: launch the SDG Classifier API with uvicorn."""
import uvicorn

if __name__ == "__main__":
    # Listen on all interfaces at port 7860 — the port the Dockerfile EXPOSEs.
    uvicorn.run("app.main:app", host="0.0.0.0", port=7860, reload=False)