cjell commited on
Commit
8a8fd9e
·
1 Parent(s): 94d05dd

fixingdgin

Browse files
Files changed (2) hide show
  1. app.py +13 -5
  2. test.py +11 -3
app.py CHANGED
@@ -1,19 +1,27 @@
1
  import os
2
- # Use HF_HOME instead of TRANSFORMERS_CACHE, and point it to /tmp
3
- os.environ["HF_HOME"] = "/tmp"
4
 
5
  from fastapi import FastAPI
6
  from pydantic import BaseModel
7
  from transformers import pipeline
8
 
 
9
  classifier = pipeline("text-classification", model="mrm8488/bert-tiny-finetuned-sms-spam-detection")
10
 
 
 
 
 
 
 
 
 
 
11
  class Query(BaseModel):
12
  text: str
13
 
14
- app = FastAPI()
15
-
16
- @app.post("/predict")
17
  def predict(query: Query):
18
  result = classifier(query.text)[0]
19
  return {"label": result["label"], "score": result["score"]}
 
1
  import os
2
+ os.environ["HF_HOME"] = "/tmp" # ensure Hugging Face cache is writable
 
3
 
4
  from fastapi import FastAPI
5
  from pydantic import BaseModel
6
  from transformers import pipeline
7
 
8
+ # Load the Hugging Face model
9
  classifier = pipeline("text-classification", model="mrm8488/bert-tiny-finetuned-sms-spam-detection")
10
 
11
+ # Define FastAPI app
12
+ app = FastAPI()
13
+
14
+ # Health check route
15
+ @app.get("/")
16
+ def read_root():
17
+ return {"status": "ok", "message": "API is running"}
18
+
19
+ # Define request schema
20
  class Query(BaseModel):
21
  text: str
22
 
23
+ # Prediction route (POST /)
24
+ @app.post("/")
 
25
  def predict(query: Query):
26
  result = classifier(query.text)[0]
27
  return {"label": result["label"], "score": result["score"]}
test.py CHANGED
@@ -1,9 +1,17 @@
1
  import requests
2
 
 
 
 
 
 
 
 
 
3
  resp = requests.post(
4
- "https://spam-fastapi-cjell.hf.space/predict",
5
  json={"text": "Congratulations! You've won a free cruise!"}
6
  )
7
 
8
- print("Status:", resp.status_code)
9
- print("Raw text:", resp.text)
 
1
  import requests
2
 
3
+ # Hugging Face Space base URL
4
+ BASE_URL = "https://spam-fastapi-cjell.hf.space"
5
+
6
+ # First check if API is live
7
+ resp = requests.get(BASE_URL)
8
+ print("Health check:", resp.status_code, resp.json())
9
+
10
+ # Now send prediction request
11
  resp = requests.post(
12
+ BASE_URL, # POST / (root) handles predictions
13
  json={"text": "Congratulations! You've won a free cruise!"}
14
  )
15
 
16
+ print("Prediction status:", resp.status_code)
17
+ print("Prediction response:", resp.json())