Yousuf-Islam commited on
Commit
05b38bc
·
verified ·
1 Parent(s): c199aa3

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -0
app.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI
2
+ from pydantic import BaseModel
3
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
+ import torch
5
+ from fastapi.middleware.cors import CORSMiddleware
6
+
7
# FastAPI application instance; all routes below attach to this object.
app = FastAPI()

# Enable CORS (allows the React frontend to talk to this API).
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# wide open — fine for a demo, but lock origins down before production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
17
+
18
# Load model and tokenizer once at import time (module-level globals shared
# by every request handler).
# NOTE(review): the path suggests the container image bundles the model at
# /code/model — confirm against the Dockerfile.
MODEL_PATH = "/code/model"
print("Loading AI Model...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)
23
+
24
class InputData(BaseModel):
    """Request body for /api/predict: the sentence to classify."""

    # Raw input sentence (Bangla text expected by the downstream model).
    sentence: str
26
+
27
+ @app.get("/")
28
+ def home():
29
+ return {"status": "Online", "model": "BanglaBERT"}
30
+
31
+ @app.post("/api/predict")
32
+ def predict(data: InputData):
33
+ try:
34
+ # Tokenize
35
+ inputs = tokenizer(data.sentence, return_tensors="pt", padding=True, truncation=True, max_length=64)
36
+
37
+ # Predict
38
+ with torch.no_grad():
39
+ logits = model(**inputs).logits
40
+
41
+ # Calculate Confidence
42
+ probs = torch.nn.functional.softmax(logits, dim=1)
43
+ conf = torch.max(probs).item()
44
+ pred_id = torch.argmax(probs).item()
45
+
46
+ # Label Mapping (1=Shirk, 0=Not Shirk)
47
+ label = "shirk" if pred_id == 1 else "not shirk"
48
+
49
+ return {
50
+ "result": label,
51
+ "confidence": f"{conf:.2%}",
52
+ "cleaned_sentence": data.sentence
53
+ }
54
+ except Exception as e:
55
+ return {"error": str(e)}