cjell committed on
Commit 10f7d04 · 1 Parent(s): d8e3053

fixing model output formats

Files changed (3)
  1. app.py +67 -24
  2. test_health.py +9 -3
  3. test_spam.py +5 -1
app.py CHANGED
@@ -1,57 +1,100 @@
 from fastapi import FastAPI
 from pydantic import BaseModel
 from transformers import pipeline
+from datetime import datetime
 import os
 
-
 os.environ["HF_HOME"] = "/tmp"
 
 SPAM_MODEL = "valurank/distilroberta-spam-comments-detection"
 TOXIC_MODEL = "s-nlp/roberta_toxicity_classifier"
-SENTIMENT_MODEL = "nlptown/bert-base-multilingual-uncased-sentiment"
+SENTIMENT_MODEL = "nlptown/bert-base-multilingual-uncased-sentiment"
 NSFW_MODEL = "michellejieli/NSFW_text_classifier"
 
+# Load models
 spam = pipeline("text-classification", model=SPAM_MODEL)
-
 toxic = pipeline("text-classification", model=TOXIC_MODEL)
+sentiment = pipeline("text-classification", model=SENTIMENT_MODEL)
+nsfw = pipeline("text-classification", model=NSFW_MODEL)
 
-sentiment = pipeline("text-classification", model = SENTIMENT_MODEL)
-
-nsfw = pipeline("text-classification", model = NSFW_MODEL)
-
-
-app = FastAPI()
-
-@app.get("/")
-def root():
-    return {"status": "ok"}
+app = FastAPI(title="Plebzs AI Models API")
 
 class Query(BaseModel):
     text: str
 
-@app.post("/spam")
-def predict_spam(query: Query):
-    result = spam(query.text)[0]
-    return {"label": result["label"], "score": result["score"]}
+@app.get("/")
+def root():
+    return {"status": "ok", "message": "Plebzs AI Models API"}
+
+# Required by Plebzs boss
+@app.get("/moderation/ping")
+def moderation_ping():
+    return {
+        "status": "healthy",
+        "models": ["spam", "toxic", "sentiment", "nsfw"],
+        "timestamp": datetime.now().isoformat(),
+        "version": "1.0.0"
+    }
 
-@app.post("/toxic")
-def predict_toxic(query: Query):
+# Main endpoints - formatted for Plebzs
+@app.post("/toxicity")  # Changed name to match Plebzs expectation
+def predict_toxicity(query: Query):
     result = toxic(query.text)[0]
-    return {"label": result["label"], "score": result["score"]}
+
+    # Convert to 0-1 toxicity scale
+    toxicity_score = result["score"] if result["label"] == "TOXIC" else 1 - result["score"]
+
+    return {
+        "toxicity_score": round(toxicity_score, 3),
+        "confidence": round(result["score"], 3),
+        "raw_output": result
+    }
 
 @app.post("/sentiment")
 def predict_sentiment(query: Query):
     result = sentiment(query.text)[0]
-    return {"label": result["label"], "score": result["score"]}
+
+    # Convert star rating to -1 to 1 scale
+    label = result["label"]
+    if "1" in label or "2" in label:  # 1-2 stars = negative
+        sentiment_score = -0.7
+    elif "3" in label:  # 3 stars = neutral
+        sentiment_score = 0.0
+    else:  # 4-5 stars = positive
+        sentiment_score = 0.7
+
+    return {
+        "sentiment_score": round(sentiment_score, 3),
+        "confidence": round(result["score"], 3),
+        "raw_output": result
+    }
+
+# Bonus endpoints (not used by Plebzs yet, but good to have)
+@app.post("/spam")
+def predict_spam(query: Query):
+    result = spam(query.text)[0]
+    spam_score = result["score"] if result["label"] == "SPAM" else 1 - result["score"]
+
+    return {
+        "spam_score": round(spam_score, 3),
+        "confidence": round(result["score"], 3),
+        "raw_output": result
+    }
 
 @app.post("/nsfw")
 def predict_nsfw(query: Query):
     result = nsfw(query.text)[0]
-    return {"label": result["label"], "score": result["score"]}
+    nsfw_score = result["score"] if result["label"] == "NSFW" else 1 - result["score"]
+
+    return {
+        "nsfw_score": round(nsfw_score, 3),
+        "confidence": round(result["score"], 3),
+        "raw_output": result
+    }
 
+# Keep your detailed health check
 @app.get("/health")
 def health_check():
-
     status = {
         "server": "running",
         "models": {}
@@ -77,4 +120,4 @@ def health_check():
         "status": f"error: {str(e)}"
     }
 
-    return status
+    return status
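
To sanity-check the new response formats end to end, here is a minimal client sketch. The base URL is taken from the test scripts below; the sample texts, the timeout, and the check helper are placeholders, not part of the commit. Note how each handler folds the raw label into a single score: a NOT_TOXIC prediction with confidence 0.98 comes back as toxicity_score 0.02, while a TOXIC prediction with confidence 0.98 comes back as 0.98.

import requests

BASE_URL = "https://cjell-Demo.hf.space"  # Space URL taken from the test scripts

def check(endpoint: str, text: str) -> dict:
    # Every POST endpoint accepts the same Query body: {"text": ...}
    resp = requests.post(f"{BASE_URL}/{endpoint}", json={"text": text}, timeout=30)
    resp.raise_for_status()
    return resp.json()

# /toxicity returns {"toxicity_score": ..., "confidence": ..., "raw_output": ...}
print(check("toxicity", "you are awful"))

# /sentiment maps the model's 1-5 star label to a fixed -0.7 / 0.0 / 0.7 scale
print(check("sentiment", "great product, would buy again"))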
test_health.py CHANGED
@@ -2,6 +2,12 @@ import requests
 
 url = "https://cjell-Demo.hf.space"
 
-if __name__ == "__main__":
-    response = requests.get(f"{url}/health", timeout=10)
-    print(response.json())
+resp = requests.get(f"{url}/health", timeout=10)
+data = resp.json()
+
+print(f"\nServer status: {data['server'].upper()}")
+print("Model statuses:")
+
+for key, info in data["models"].items():
+    status = info["status"].upper()
+    print(f" - {key}: {info['model_name']} → {status}")
test_spam.py CHANGED
@@ -10,4 +10,8 @@ print("Status:", response.status_code)
 try:
     print("JSON:", response.json())
 except Exception:
-    print("Raw text:", response.text)
+    print("Raw text:", response.text)
+
+print("")
+
+print(response.text)
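
The top of test_spam.py falls outside this diff. Going by its hunk header and the /spam handler in app.py, the request it issues presumably looks like this sketch (the sample text is made up):

import requests

url = "https://cjell-Demo.hf.space"  # same Space URL as in test_health.py

response = requests.post(f"{url}/spam", json={"text": "Buy now!!! Limited offer"}, timeout=10)
print("Status:", response.status_code)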