VoltIC commited on
Commit
a061c0a
Β·
verified Β·
1 Parent(s): e52093f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -26
app.py CHANGED
@@ -4,8 +4,7 @@ from fastapi.responses import HTMLResponse
4
  from pydantic import BaseModel
5
  import joblib, os, re, uvicorn
6
 
7
- # ---------------- App ----------------
8
- app = FastAPI(title="Hate Speech Detection API")
9
 
10
  app.add_middleware(
11
  CORSMiddleware,
@@ -14,24 +13,26 @@ app.add_middleware(
14
  allow_headers=["*"],
15
  )
16
 
17
- # ---------------- Load Local Model ----------------
18
  MODEL_PATH = "hate_speech_model.pkl"
19
  if not os.path.exists(MODEL_PATH):
20
  raise RuntimeError("hate_speech_model.pkl not found")
21
 
22
  model = joblib.load(MODEL_PATH)
23
 
24
- # ---------------- Schema ----------------
 
 
 
 
 
25
  class TextRequest(BaseModel):
26
  text: str
27
 
28
- # ---------------- Utils ----------------
29
  def clean_text(text: str) -> str:
30
  text = re.sub(r"http\S+", " URL ", text)
31
  text = re.sub(r"@\w+", " USER ", text)
32
  return text.lower().strip()
33
 
34
- # ---------------- UI ----------------
35
  @app.get("/", response_class=HTMLResponse)
36
  def home():
37
  return """
@@ -39,11 +40,17 @@ def home():
39
  <head><title>Hate Speech Detector</title></head>
40
  <body style="font-family:Arial">
41
  <h2>πŸ›‘ Hate Speech Detection</h2>
42
- <form onsubmit="event.preventDefault(); analyze();">
43
- <textarea id="text" rows="5" cols="60"
44
- placeholder="Enter text..."></textarea><br><br>
45
- <button type="submit">Analyze</button>
46
- </form>
 
 
 
 
 
 
47
  <p id="result"></p>
48
 
49
  <script>
@@ -56,46 +63,42 @@ def home():
56
  });
57
  const data = await res.json();
58
  document.getElementById("result").innerText =
59
- data.class_name || data.detail;
60
  }
61
  </script>
62
  </body>
63
  </html>
64
  """
65
 
66
- # ---------------- Health ----------------
67
  @app.get("/health")
68
  def health():
69
  return {"status": "ok", "model_loaded": True}
70
 
71
- # ---------------- API ----------------
72
  @app.post("/analyze")
73
  def analyze(req: TextRequest):
74
  if len(req.text.strip()) < 10:
75
- raise HTTPException(400, "Text too short")
76
 
77
  cleaned = clean_text(req.text)
78
- pred = int(model.predict([cleaned])[0])
79
 
80
- # Handle Pipeline or Classifier
81
- classes = (
82
- model.classes_.tolist()
83
- if hasattr(model, "classes_")
84
- else model.named_steps["classifier"].classes_.tolist()
85
- )
86
 
87
  result = {
88
- "predicted_class": pred,
89
- "class_name": classes[pred]
90
  }
91
 
92
  if hasattr(model, "predict_proba"):
93
  probs = model.predict_proba([cleaned])[0]
94
- result["confidence"] = round(float(probs[pred]) * 100, 2)
95
 
96
  return result
97
 
98
- # ---------------- Runner ----------------
99
  if __name__ == "__main__":
100
  uvicorn.run(
101
  app,
 
4
  from pydantic import BaseModel
5
  import joblib, os, re, uvicorn
6
 
7
+ app = FastAPI(title="Hate Speech Detection")
 
8
 
9
  app.add_middleware(
10
  CORSMiddleware,
 
13
  allow_headers=["*"],
14
  )
15
 
 
16
  MODEL_PATH = "hate_speech_model.pkl"
17
  if not os.path.exists(MODEL_PATH):
18
  raise RuntimeError("hate_speech_model.pkl not found")
19
 
20
  model = joblib.load(MODEL_PATH)
21
 
22
+ CLASS_MAPPING = {
23
+ 0: "Hate Speech",
24
+ 1: "Offensive Language",
25
+ 2: "Neither"
26
+ }
27
+
28
  class TextRequest(BaseModel):
29
  text: str
30
 
 
31
  def clean_text(text: str) -> str:
32
  text = re.sub(r"http\S+", " URL ", text)
33
  text = re.sub(r"@\w+", " USER ", text)
34
  return text.lower().strip()
35
 
 
36
  @app.get("/", response_class=HTMLResponse)
37
  def home():
38
  return """
 
40
  <head><title>Hate Speech Detector</title></head>
41
  <body style="font-family:Arial">
42
  <h2>πŸ›‘ Hate Speech Detection</h2>
43
+ <p>Class Mapping:</p>
44
+ <ul>
45
+ <li>0 β†’ Hate Speech</li>
46
+ <li>1 β†’ Offensive Language</li>
47
+ <li>2 β†’ Neither</li>
48
+ </ul>
49
+
50
+ <textarea id="text" rows="5" cols="60"
51
+ placeholder="Enter text..."></textarea><br><br>
52
+ <button onclick="analyze()">Analyze</button>
53
+
54
  <p id="result"></p>
55
 
56
  <script>
 
63
  });
64
  const data = await res.json();
65
  document.getElementById("result").innerText =
66
+ "Output: " + data.predicted_class + " (" + data.class_name + ")";
67
  }
68
  </script>
69
  </body>
70
  </html>
71
  """
72
 
 
73
  @app.get("/health")
74
  def health():
75
  return {"status": "ok", "model_loaded": True}
76
 
 
77
  @app.post("/analyze")
78
  def analyze(req: TextRequest):
79
  if len(req.text.strip()) < 10:
80
+ raise HTTPException(400, "Text must be at least 10 characters long")
81
 
82
  cleaned = clean_text(req.text)
 
83
 
84
+ try:
85
+ prediction = int(model.predict([cleaned])[0])
86
+ except Exception as e:
87
+ raise HTTPException(500, str(e))
88
+
89
+ class_name = CLASS_MAPPING.get(prediction, "Unknown")
90
 
91
  result = {
92
+ "predicted_class": prediction,
93
+ "class_name": class_name
94
  }
95
 
96
  if hasattr(model, "predict_proba"):
97
  probs = model.predict_proba([cleaned])[0]
98
+ result["confidence"] = round(float(probs[prediction]) * 100, 2)
99
 
100
  return result
101
 
 
102
  if __name__ == "__main__":
103
  uvicorn.run(
104
  app,