mjpsm commited on
Commit
950cef0
·
verified ·
1 Parent(s): a69a96f

Upload 3 files

Browse files
Files changed (3) hide show
  1. Dockerfile +10 -0
  2. app.py +64 -0
  3. requirements.txt +4 -0
Dockerfile ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.10
2
+
3
+ WORKDIR /app
4
+
5
+ COPY . .
6
+
7
+ RUN pip install --no-cache-dir -r requirements.txt
8
+
9
+ # 🔥 Expose port (Spaces uses 7860)
10
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
app.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI
2
+ from pydantic import BaseModel
3
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
+ import torch
5
+
6
+ app = FastAPI(
7
+ title="Check-in Detail Classifier API",
8
+ description="Classifies check-ins as DETAILED or NOT_DETAILED",
9
+ version="1.0"
10
+ )
11
+
12
+ # Load model once (efficient)
13
+ MODEL_NAME = "mjpsm/checkin-detail-classifier"
14
+
15
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
16
+ model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
17
+
18
+ model.eval()
19
+
20
+ # Request schema
21
+ class Request(BaseModel):
22
+ text: str
23
+
24
+ # Root route
25
+ @app.get("/")
26
+ def root():
27
+ return {
28
+ "message": "Welcome to the Check-in Detail Classifier API"
29
+ }
30
+
31
+ # Classification logic
32
+ def classify(text: str):
33
+ inputs = tokenizer(
34
+ text,
35
+ return_tensors="pt",
36
+ truncation=True,
37
+ padding=True
38
+ )
39
+
40
+ # Remove token_type_ids (DistilBERT fix)
41
+ inputs.pop("token_type_ids", None)
42
+
43
+ with torch.no_grad():
44
+ outputs = model(**inputs)
45
+
46
+ probs = torch.nn.functional.softmax(outputs.logits, dim=1)
47
+
48
+ pred = torch.argmax(probs).item()
49
+ confidence = probs[0][pred].item()
50
+
51
+ label = model.config.id2label[pred]
52
+
53
+ return label, confidence
54
+
55
+ # Predict endpoint
56
+ @app.post("/predict")
57
+ def predict(req: Request):
58
+ label, confidence = classify(req.text)
59
+
60
+ return {
61
+ "input": req.text,
62
+ "prediction": label,
63
+ "confidence": round(confidence, 4)
64
+ }
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ fastapi
2
+ uvicorn
3
+ transformers
4
+ torch