ST-THOMAS-OF-AQUINAS commited on
Commit
4910b5a
·
verified ·
1 Parent(s): a3ff60f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -15
app.py CHANGED
@@ -11,39 +11,37 @@ import os
11
  HF_CACHE_DIR = os.getenv("HF_HOME", "/tmp/hf_cache")
12
 
13
  # -----------------------------
14
- # Load model from Hugging Face
15
  # -----------------------------
16
- model_id = "ST-THOMAS-OF-AQUINAS/SCAM"
17
 
18
  tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=HF_CACHE_DIR)
19
  model = AutoModelForSequenceClassification.from_pretrained(model_id, cache_dir=HF_CACHE_DIR)
20
  model.eval()
21
 
22
- label_map = {0: "MAXWELL KURIA", 1: "Keliv Kuria"}
23
-
24
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
25
  model.to(device)
26
 
27
  # -----------------------------
28
  # Helper function
29
  # -----------------------------
30
- def predict_author(text: str):
31
  inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
32
  inputs = {k: v.to(device) for k, v in inputs.items()}
33
 
34
  with torch.no_grad():
35
  outputs = model(**inputs)
36
- probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
37
- pred = torch.argmax(probs, dim=1).item()
38
- confidence = probs[0][pred].item()
39
 
40
- predicted_author = label_map.get(pred, "unknown")
41
- return predicted_author, round(confidence * 100, 2)
 
42
 
43
  # -----------------------------
44
  # FastAPI app
45
  # -----------------------------
46
- app = FastAPI(title="Scam Detector API with Twilio")
47
 
48
  # Health-check route
49
  @app.get("/")
@@ -53,8 +51,8 @@ async def health_check():
53
  # Simple GET test
54
  @app.get("/predict")
55
  async def get_predict(text: str):
56
- author, confidence = predict_author(text)
57
- return {"prediction": author, "confidence": confidence}
58
 
59
  # -----------------------------
60
  # Twilio WhatsApp POST
@@ -64,8 +62,8 @@ async def whatsapp_reply(Body: str = Form(...)):
64
  resp = MessagingResponse()
65
 
66
  if Body.strip():
67
- author, confidence = predict_author(Body)
68
- reply = f"Prediction: {author}\nConfidence: {confidence}%"
69
  else:
70
  reply = "⚠️ No text detected."
71
 
 
11
  HF_CACHE_DIR = os.getenv("HF_HOME", "/tmp/hf_cache")
12
 
13
  # -----------------------------
14
+ # Load regression model from Hugging Face
15
  # -----------------------------
16
+ model_id = "ST-THOMAS-OF-AQUINAS/impersonation-bart"
17
 
18
  tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=HF_CACHE_DIR)
19
  model = AutoModelForSequenceClassification.from_pretrained(model_id, cache_dir=HF_CACHE_DIR)
20
  model.eval()
21
 
 
 
22
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
23
  model.to(device)
24
 
25
  # -----------------------------
26
  # Helper function
27
  # -----------------------------
28
+ def predict_score(text: str):
29
  inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
30
  inputs = {k: v.to(device) for k, v in inputs.items()}
31
 
32
  with torch.no_grad():
33
  outputs = model(**inputs)
34
+ # For regression, logits is shape [batch, 1]
35
+ score = outputs.logits.squeeze().item()
 
36
 
37
+ # Clamp between 0 and 1 (just in case)
38
+ score = max(0.0, min(1.0, score))
39
+ return round(score, 3)
40
 
41
  # -----------------------------
42
  # FastAPI app
43
  # -----------------------------
44
+ app = FastAPI(title="Impersonation Detector API with Twilio")
45
 
46
  # Health-check route
47
  @app.get("/")
 
51
  # Simple GET test
52
  @app.get("/predict")
53
  async def get_predict(text: str):
54
+ score = predict_score(text)
55
+ return {"impersonation_score": score}
56
 
57
  # -----------------------------
58
  # Twilio WhatsApp POST
 
62
  resp = MessagingResponse()
63
 
64
  if Body.strip():
65
+ score = predict_score(Body)
66
+ reply = f"Impersonation Score: {score}\n(0.0 = genuine, 1.0 = impersonation)"
67
  else:
68
  reply = "⚠️ No text detected."
69