user6295018 committed on
Commit
a67f7df
·
verified ·
1 Parent(s): 50f0528

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -27
app.py CHANGED
@@ -3,47 +3,60 @@
3
  # ===============================================
4
 
5
  from fastapi import FastAPI, Request
6
- from fastapi.middleware.cors import CORSMiddleware
7
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
8
- import torch
9
  import uvicorn
10
 
11
- # --- Initialize app ---
12
- app = FastAPI(title="Bloom Check-in Quality Classifier")
13
-
14
- # --- Allow web access from anywhere (CORS) ---
15
- app.add_middleware(
16
- CORSMiddleware,
17
- allow_origins=["*"],
18
- allow_credentials=True,
19
- allow_methods=["*"],
20
- allow_headers=["*"],
21
  )
22
 
23
- # --- Load model once on startup ---
24
- MODEL_NAME = "user6295018/checkin-quality-classifier"
25
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
26
- model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
27
- label_map = {0: "Vague", 1: "Neutral", 2: "Descriptive"}
28
 
29
  @app.get("/")
30
- async def root():
 
 
 
31
  return {"message": "Bloom Check-in Quality Classifier is running!"}
32
 
33
  @app.post("/api/predict")
34
  async def predict(request: Request):
35
- data = await request.json()
36
- text = data.get("text", "").strip()
 
 
 
 
 
 
 
 
 
37
  if not text:
38
  return {"error": "No text provided."}
39
 
40
- inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
41
- with torch.no_grad():
42
- outputs = model(**inputs)
43
- pred = torch.argmax(outputs.logits, dim=1).item()
 
 
 
 
 
 
 
 
 
 
44
 
45
- return {"label": label_map.get(pred, "Unknown")}
 
46
 
47
- # --- Explicit entry point so Spaces knows what to run ---
48
  if __name__ == "__main__":
49
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
3
  # ===============================================
4
 
5
  from fastapi import FastAPI, Request
6
+ from transformers import pipeline
 
 
7
  import uvicorn
8
 
9
+ app = FastAPI(
10
+ title="Bloom Check-in Quality Classifier",
11
+ description="A FastAPI app that classifies check-ins as vague, neutral, or descriptive.",
12
+ version="1.0.0"
 
 
 
 
 
 
13
  )
14
 
15
+ # Load the Hugging Face text classification model
16
+ # This automatically downloads your public model from the Hub
17
+ classifier = pipeline("text-classification", model="user6295018/checkin-quality-classifier")
 
 
18
 
19
  @app.get("/")
20
+ def read_root():
21
+ """
22
+ Root endpoint – simple health check
23
+ """
24
  return {"message": "Bloom Check-in Quality Classifier is running!"}
25
 
26
  @app.post("/api/predict")
27
  async def predict(request: Request):
28
+ """
29
+ Predict endpoint – classify input text using the fine-tuned model
30
+ """
31
+ try:
32
+ data = await request.json()
33
+ except Exception:
34
+ return {"error": "Invalid JSON body."}
35
+
36
+ # Accept either {"text": "..."} or {"inputs": "..."}
37
+ text = data.get("text") or data.get("inputs")
38
+
39
  if not text:
40
  return {"error": "No text provided."}
41
 
42
+ try:
43
+ # Run model inference
44
+ result = classifier(text)[0]
45
+
46
+ # Optional: Map model output labels to human-friendly names
47
+ label_map = {
48
+ "LABEL_0": "vague",
49
+ "LABEL_1": "neutral",
50
+ "LABEL_2": "descriptive"
51
+ }
52
+ label = label_map.get(result["label"], result["label"])
53
+ score = round(float(result["score"]), 3)
54
+
55
+ return {"label": label, "score": score}
56
 
57
+ except Exception as e:
58
+ return {"error": f"Inference failed: {str(e)}"}
59
 
60
+ # Optional: local testing entry point (ignored on Spaces)
61
  if __name__ == "__main__":
62
  uvicorn.run(app, host="0.0.0.0", port=7860)