import time
import csv
import os
from datetime import datetime
from typing import Optional

import fasttext
import torch
from fastapi import FastAPI, HTTPException, Request, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from pydantic import BaseModel
from transformers import BertTokenizer, BertForSequenceClassification
app = FastAPI()
# Allowed SMPP, SMSC or any External IP addresses
ALLOWED_IPS = {"127.0.0.1", "localhost", "10.0.0.1"}
# Add CORSMiddleware to allow cross-origin requests
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Load BERT model
bert_model_path = "../BERT/training/bert_sms_spam_phishing_model"
bert_model = BertForSequenceClassification.from_pretrained(bert_model_path)
bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert_model.eval()
# Load FastText model
fasttext_model_path = "../FastText/training/ots_sms_model_v1.1.bin"
fasttext_model = fasttext.load_model(fasttext_model_path)
class SMS(BaseModel):
text: str
model: str # "bert" or "fasttext"
class Feedback(BaseModel):
content: str
feedback: str
thumbs_up: bool
thumbs_down: bool
user_id: Optional[str] = None
model: str # "bert" or "fasttext"
def preprocess_text(text, tokenizer, max_len=128):
return tokenizer.encode_plus(
text, add_special_tokens=True, max_length=max_len,
padding='max_length', return_attention_mask=True,
return_tensors='pt', truncation=True
)
def write_feedback(feedback_data, model_name):
file_name = f"feedback_{model_name}.csv"
with open(file_name, mode="a", newline="", encoding="utf-8") as file:
writer = csv.writer(file)
if file.tell() == 0:
writer.writerow(["Timestamp", "UserID", "Content", "Feedback", "Thumbs Up", "Thumbs Down"])
writer.writerow(feedback_data)
def verify_ip_address(request: Request):
client_host = request.client.host
if client_host not in ALLOWED_IPS:
raise HTTPException(status_code=403, detail="Access denied")
return client_host
# Route to predict SMS using specified model - supported "bert" , "fasttext"
@app.post("/predict/", dependencies=[Depends(verify_ip_address)])
async def predict_sms(sms: SMS):
start_time = time.time()
if not sms.text:
raise HTTPException(status_code=400, detail="Text is empty")
if sms.model == "bert":
inputs = preprocess_text(sms.text, bert_tokenizer)
with torch.no_grad():
outputs = bert_model(**inputs)
prediction = torch.argmax(outputs.logits, dim=1).item()
label_map = {0: 'ham', 1: 'spam', 2: 'phishing'}
label = label_map[prediction]
probability = torch.nn.functional.softmax(outputs.logits, dim=1).max().item()
model_info = {"Model_Name": "OTS_bert", "Model_Version": "1.1.4"}
elif sms.model == "fasttext":
label, probability = fasttext_model.predict(sms.text, k=1) # Ensure k=1 for single label prediction
label = label[0].replace('__label__', '')
probability = probability[0] # Extract the probability value
model_info = {
"Model_Name": "OTS_fasttext",
"Model_Version": "1.1.4",
"Model_Author": "TelecomsXChange (TCXC)",
"Last_Training": "2023-12-21"
}
else:
raise HTTPException(status_code=400, detail="Invalid model type")
end_time = time.time()
return {
"label": label,
"probability": probability,
"processing_time": end_time - start_time,
**model_info,
"Model_Author": "TelecomsXChange (TCXC)",
"Last_Training": "2023-12-21" # Update accordingly
}
# Feedback loop and download feedback
@app.post("/feedback-loop/", dependencies=[Depends(verify_ip_address)])
async def feedback_loop(feedback: Feedback):
thumbs_up = 'Yes' if feedback.thumbs_up else 'No'
thumbs_down = 'Yes' if feedback.thumbs_down else 'No'
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
feedback_data = [timestamp, feedback.user_id, feedback.content, feedback.feedback, thumbs_up, thumbs_down]
if feedback.model in ["bert", "fasttext"]:
write_feedback(feedback_data, feedback.model)
else:
raise HTTPException(status_code=400, detail="Invalid model type")
return {"message": "Feedback received"}
@app.get("/download-feedback/{model_name}", dependencies=[Depends(verify_ip_address)])
async def download_feedback(model_name: str):
if model_name in ["bert", "fasttext"]:
file_path = f"feedback_{model_name}.csv"
else:
raise HTTPException(status_code=400, detail="Invalid model name")
return FileResponse(file_path, media_type='text/csv', filename=file_path)
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8001)
|