File size: 4,974 Bytes
0e2fe46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import csv
import os
import time
from datetime import datetime
from typing import Optional

import fasttext
import torch
from fastapi import FastAPI, HTTPException, Request, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from pydantic import BaseModel
from transformers import BertTokenizer, BertForSequenceClassification

# FastAPI application instance; routes are registered on it via decorators below.
app = FastAPI()

# Allowed SMPP, SMSC or any External IP addresses
# Only these client hosts may call the API (enforced by verify_ip_address).
ALLOWED_IPS = {"127.0.0.1", "localhost", "10.0.0.1"}

# Add CORSMiddleware to allow cross-origin requests
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# maximally permissive -- the IP allow-list above is the only access control.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# Load BERT model
# Fine-tuned sequence-classification weights from a local training directory;
# the tokenizer is the stock 'bert-base-uncased' vocabulary (assumes the model
# was fine-tuned with that same tokenizer -- TODO confirm).
bert_model_path = "../BERT/training/bert_sms_spam_phishing_model"
bert_model = BertForSequenceClassification.from_pretrained(bert_model_path)
bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert_model.eval()  # inference-only: disables dropout / training-mode layers

# Load FastText model
# Pre-trained supervised fastText classifier loaded from a local binary.
fasttext_model_path = "../FastText/training/ots_sms_model_v1.1.bin"
fasttext_model = fasttext.load_model(fasttext_model_path)

class SMS(BaseModel):
    """Request body for /predict/: the message text and which classifier to run."""
    text: str
    model: str  # "bert" or "fasttext"

class Feedback(BaseModel):
    """Request body for /feedback-loop/: a user's verdict on one classification."""
    content: str  # the SMS text the feedback refers to
    feedback: str  # free-text comment from the user
    thumbs_up: bool
    thumbs_down: bool
    user_id: Optional[str] = None  # optional -- anonymous feedback is accepted
    model: str  # "bert" or "fasttext"

def preprocess_text(text, tokenizer, max_len=128):
    """Encode *text* into model-ready tensors via the given tokenizer.

    Pads/truncates to exactly *max_len* tokens, adds the special [CLS]/[SEP]
    markers, and returns PyTorch tensors with an attention mask.
    """
    encode_options = dict(
        add_special_tokens=True,
        max_length=max_len,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt',
    )
    return tokenizer.encode_plus(text, **encode_options)
def write_feedback(feedback_data, model_name):
    """Append one feedback row to the per-model CSV file.

    The file is created on first use; a header row is written only when the
    file is empty (detected via the append-mode file position).
    """
    csv_path = f"feedback_{model_name}.csv"
    header = ["Timestamp", "UserID", "Content", "Feedback", "Thumbs Up", "Thumbs Down"]
    with open(csv_path, mode="a", newline="", encoding="utf-8") as csv_file:
        writer = csv.writer(csv_file)
        # In append mode tell() == 0 means the file is brand new (or empty).
        if csv_file.tell() == 0:
            writer.writerow(header)
        writer.writerow(feedback_data)
        
def verify_ip_address(request: Request):
    """FastAPI dependency: allow only clients whose host is in ALLOWED_IPS.

    Returns the client host on success; raises HTTP 403 otherwise.
    """
    caller = request.client.host
    if caller in ALLOWED_IPS:
        return caller
    raise HTTPException(status_code=403, detail="Access denied")


# Route to predict SMS using specified model - supported "bert" , "fasttext"

@app.post("/predict/", dependencies=[Depends(verify_ip_address)])
async def predict_sms(sms: SMS):
    """Classify an SMS as ham/spam/phishing with the requested model.

    Body: SMS (text plus model selector, "bert" or "fasttext").
    Returns the predicted label, its probability, wall-clock processing time,
    and model metadata.  Raises 400 for empty text or an unknown model name;
    403 comes from the IP allow-list dependency.
    """
    start_time = time.time()

    if not sms.text:
        raise HTTPException(status_code=400, detail="Text is empty")

    if sms.model == "bert":
        inputs = preprocess_text(sms.text, bert_tokenizer)
        with torch.no_grad():  # inference only: skip autograd bookkeeping
            outputs = bert_model(**inputs)
            prediction = torch.argmax(outputs.logits, dim=1).item()
        # NOTE(review): assumes the fine-tuned head has exactly 3 classes in
        # this index order -- confirm against the training label encoding.
        label_map = {0: 'ham', 1: 'spam', 2: 'phishing'}
        label = label_map[prediction]
        # Softmax over the single-row logits; .max() is the winning class's prob.
        probability = torch.nn.functional.softmax(outputs.logits, dim=1).max().item()
        model_info = {"Model_Name": "OTS_bert", "Model_Version": "1.1.4"}
    elif sms.model == "fasttext":
        # NOTE(review): fastText's predict() rejects text containing newlines --
        # presumably callers send single-line SMS bodies; verify upstream.
        label, probability = fasttext_model.predict(sms.text, k=1)  # Ensure k=1 for single label prediction
        label = label[0].replace('__label__', '')
        probability = probability[0]  # Extract the probability value
        model_info = {
            "Model_Name": "OTS_fasttext",
            "Model_Version": "1.1.4",
            "Model_Author": "TelecomsXChange (TCXC)",
            "Last_Training": "2023-12-21"
        }
    else:
        raise HTTPException(status_code=400, detail="Invalid model type")

    end_time = time.time()
    # The literal keys below override any duplicates spread in from model_info.
    return {
        "label": label,
        "probability": probability,
        "processing_time": end_time - start_time,
        **model_info,
        "Model_Author": "TelecomsXChange (TCXC)",
        "Last_Training": "2023-12-21"  # Update accordingly
    }

# Feedback loop and download feedback 

@app.post("/feedback-loop/", dependencies=[Depends(verify_ip_address)])
async def feedback_loop(feedback: Feedback):
    """Persist one user-feedback record to the per-model CSV file.

    Raises 400 when the model selector is neither "bert" nor "fasttext".
    """
    # Guard clause: reject unknown models before touching the filesystem.
    if feedback.model not in ("bert", "fasttext"):
        raise HTTPException(status_code=400, detail="Invalid model type")

    row = [
        datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        feedback.user_id,
        feedback.content,
        feedback.feedback,
        'Yes' if feedback.thumbs_up else 'No',
        'Yes' if feedback.thumbs_down else 'No',
    ]
    write_feedback(row, feedback.model)
    return {"message": "Feedback received"}


@app.get("/download-feedback/{model_name}", dependencies=[Depends(verify_ip_address)])
async def download_feedback(model_name: str):
    """Download the accumulated feedback CSV for the given model.

    Raises 400 for an unknown model name and 404 when no feedback file
    exists yet (previously this fell through to FileResponse and surfaced
    as an unhandled 500 on a missing file).
    """
    if model_name not in ("bert", "fasttext"):
        raise HTTPException(status_code=400, detail="Invalid model name")
    file_path = f"feedback_{model_name}.csv"
    # The file is only created once feedback has been submitted; return an
    # explicit 404 instead of letting FileResponse error out on a missing path.
    if not os.path.exists(file_path):
        raise HTTPException(status_code=404, detail="No feedback recorded yet")
    return FileResponse(file_path, media_type='text/csv', filename=file_path)

if __name__ == "__main__":
    # Local/dev entry point; in production run via an external ASGI server.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8001)