mjpsm committed on
Commit
132a058
·
verified ·
1 Parent(s): 9eb0b4d

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -0
app.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
2
+ from fastapi import FastAPI
3
+ from pydantic import BaseModel
4
+ import torch
5
+
6
# FastAPI application that serves the check-ins classifier.
app = FastAPI(title="Check-ins Classifier API", version="1.0")

# Hugging Face Hub id of the fine-tuned sequence-classification model.
MODEL_NAME = "mjpsm/check-ins-classifier"

# Load the tokenizer/model pair once at startup and put the model in
# inference mode (disables dropout and similar training-only layers).
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
model.eval()

# Human-readable names for the model's raw class indices.
id2label = {0: "Bad", 1: "Mediocre", 2: "Good"}
21
+
22
# Input schema
class InputText(BaseModel):
    """Request body for POST /predict.

    Validated by pydantic; ``text`` is the raw check-in text to classify.
    """

    # Raw text to classify (required).
    text: str
25
+
26
@app.post("/predict")
async def predict(data: InputText):
    """Classify a single check-in text.

    Returns the predicted label (Bad / Mediocre / Good), its numeric id,
    and the softmax probability for each class.
    """
    # Tokenize; truncation guards against inputs longer than the model's
    # maximum sequence length.
    inputs = tokenizer(data.text, return_tensors="pt", truncation=True, padding=True)

    # Inference only: no_grad skips building the autograd graph.
    with torch.no_grad():
        outputs = model(**inputs)
        probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
        predicted_label_id = torch.argmax(probs, dim=-1).item()

    return {
        "input_text": data.text,
        "predicted_label": id2label[predicted_label_id],
        "label_id": predicted_label_id,
        # probs has shape (1, num_labels); index [0] drops the singleton
        # batch dimension so clients get a flat per-class list.
        # (Bug fix: the original returned a nested [[...]] list.)
        "probabilities": probs[0].tolist(),
    }
44
+
45
@app.get("/")
async def home():
    """Root endpoint: returns a short usage hint for the API."""
    message = "Welcome to the Check-ins Classifier API. Use POST /predict to classify text."
    return {"message": message}