parth parekh
committed on
Commit
·
3cfd7e3
1
Parent(s):
645ea59
added new model from xxparthparekhxx/ContactShieldAI
Browse files
- app.py +6 -25
- contact_sharing_epoch_1.pth +3 -0
- predictor.py +103 -0
app.py
CHANGED
|
@@ -1,37 +1,17 @@
|
|
| 1 |
from fastapi import FastAPI, HTTPException
|
| 2 |
from pydantic import BaseModel
|
| 3 |
import torch
|
| 4 |
-
from transformers import RobertaTokenizer, RobertaForSequenceClassification
|
| 5 |
from torch.nn.functional import softmax
|
| 6 |
import re
|
|
|
|
| 7 |
|
| 8 |
# FastAPI application object. docs_url="/" serves the interactive API docs
# at the root path instead of the default location.
app = FastAPI(
    title="Contact Information Detection API",
    description="API for detecting contact information in text",
    version="1.0.0",
    docs_url="/"
)
|
| 14 |
|
| 15 |
-
class ContactDetector:
    """Scores text for contact information with a RoBERTa sequence classifier.

    The tokenizer and a two-label ``roberta-base`` classification model are
    loaded once at construction time (cached under ``/app/model_cache``) and
    kept in evaluation mode on the GPU when torch reports one, else the CPU.
    """

    def __init__(self):
        model_name = 'roberta-base'
        cache_dir = "/app/model_cache"
        self.tokenizer = RobertaTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
        self.model = RobertaForSequenceClassification.from_pretrained(
            model_name, num_labels=2, cache_dir=cache_dir
        )
        # Prefer CUDA whenever torch can see a device.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)
        self.model.eval()  # inference only

    def detect_contact_info(self, text):
        """Return the softmax probability of label index 1 for *text* as a float."""
        encoded = self.tokenizer(text, return_tensors='pt', truncation=True, padding=True)
        encoded = encoded.to(self.device)
        with torch.no_grad():
            logits = self.model(**encoded).logits
            class_probs = softmax(logits, dim=1)
        return class_probs[0][1].item()  # Probability of contact info

    def is_contact_info(self, text, threshold=0.45):
        """Return True when the contact probability is strictly above *threshold*."""
        score = self.detect_contact_info(text)
        return score > threshold
|
| 33 |
-
|
| 34 |
-
detector = ContactDetector()
|
| 35 |
|
| 36 |
# Request payload model: a single string field holding the text to analyse.
class TextInput(BaseModel):
    text: str
|
|
@@ -65,9 +45,10 @@ async def detect_contact(input: TextInput):
|
|
| 65 |
"method": "regex"
|
| 66 |
}
|
| 67 |
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
|
|
|
| 71 |
return {
|
| 72 |
"text": input.text,
|
| 73 |
"contact_probability": probability,
|
|
|
|
| 1 |
from fastapi import FastAPI, HTTPException
|
| 2 |
from pydantic import BaseModel
|
| 3 |
import torch
|
|
|
|
| 4 |
from torch.nn.functional import softmax
|
| 5 |
import re
|
| 6 |
+
from .predictor import predict
|
| 7 |
|
| 8 |
# FastAPI application object. docs_url="/" serves the interactive API docs
# at the root path instead of the default location.
app = FastAPI(
    title="Contact Information Detection API",
    description="API for detecting contact information in text great thanks to xxparthparekhxx/ContactShieldAI for the model",
    version="1.0.0",
    docs_url="/"
)
|
| 14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
|
| 16 |
# Request payload model: a single string field holding the text to analyse.
class TextInput(BaseModel):
    text: str
|
|
|
|
| 45 |
"method": "regex"
|
| 46 |
}
|
| 47 |
|
| 48 |
+
# If no regex patterns match, use the model
|
| 49 |
+
probabilities = predict(input.text)
|
| 50 |
+
probability = probabilities[1] # Probability of containing contact info
|
| 51 |
+
is_contact = probability > 0.5 # You can adjust this threshold as needed
|
| 52 |
return {
|
| 53 |
"text": input.text,
|
| 54 |
"contact_probability": probability,
|
contact_sharing_epoch_1.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:bdb70e711c212856ce3df95b82afbae57b8fc34243b3f541ecd65963fa81fd92
|
| 3 |
+
size 813497259
|
predictor.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from torchtext.vocab import build_vocab_from_iterator, GloVe
|
| 5 |
+
from torchtext.data.utils import get_tokenizer
|
| 6 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 7 |
+
|
| 8 |
+
class ContactSharingClassifier(nn.Module):
    """Embedding -> bidirectional LSTM -> parallel Conv1d banks -> two-layer head.

    One ``Conv1d`` is built per entry in ``filter_sizes``. Each convolution is
    max-pooled over the sequence axis, the pooled vectors are concatenated,
    and the result goes through dropout, layer normalisation and two linear
    layers that produce ``output_dim`` raw logits.
    """

    def __init__(self, vocab_size, embed_dim, num_filters, filter_sizes, lstm_hidden_dim, output_dim, dropout, pad_idx):
        super().__init__()
        # The LSTM is bidirectional, so it emits 2 * hidden features per token.
        conv_in_channels = lstm_hidden_dim * 2
        pooled_width = len(filter_sizes) * num_filters
        head_width = pooled_width // 2

        # Submodules are created in a fixed order so state_dict keys and
        # random initialisation stay reproducible.
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_idx)
        self.lstm = nn.LSTM(embed_dim, lstm_hidden_dim, bidirectional=True, batch_first=True)
        conv_banks = []
        for width in filter_sizes:
            conv_banks.append(
                nn.Conv1d(in_channels=conv_in_channels, out_channels=num_filters, kernel_size=width)
            )
        self.convs = nn.ModuleList(conv_banks)
        self.fc1 = nn.Linear(pooled_width, head_width)
        self.fc2 = nn.Linear(head_width, output_dim)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(pooled_width)

    def forward(self, text):
        """Map a batch-first tensor of token indices to ``output_dim`` logits per row."""
        sequence_features, _ = self.lstm(self.embedding(text))
        # Conv1d wants (batch, channels, length); the LSTM gave (batch, length, channels).
        channels_first = sequence_features.permute(0, 2, 1)

        pooled = []
        for conv in self.convs:
            activated = F.relu(conv(channels_first))
            # Global max over whatever length this kernel width left behind.
            pooled.append(F.max_pool1d(activated, activated.shape[2]).squeeze(2))

        features = self.layer_norm(self.dropout(torch.cat(pooled, dim=1)))
        hidden = self.dropout(F.relu(self.fc1(features)))
        return self.fc2(hidden)
|
| 33 |
+
|
| 34 |
+
# Initialize tokenizer and vocabulary
# torchtext tokenizer backed by spaCy's "en_core_web_sm" pipeline.
tokenizer = get_tokenizer("spacy", language="en_core_web_sm")
# NOTE(review): 'vocab.pth' is resolved against the current working directory
# and is unpickled by torch.load, so only a trusted file should be loaded here.
# The file list of this commit does not include vocab.pth -- confirm it is
# deployed next to the checkpoint.
vocab = torch.load('vocab.pth') # Assuming you've saved the vocabulary
|
| 37 |
+
|
| 38 |
+
# Define text pipeline
|
| 39 |
+
def text_pipeline(x):
    """Tokenize *x* and return the vocabulary entry for each token, in order."""
    token_ids = []
    for token in tokenizer(x):
        token_ids.append(vocab[token])
    return token_ids
|
| 41 |
+
|
| 42 |
+
# Model parameters
# NOTE(review): these presumably mirror the configuration the checkpoint was
# trained with; load_state_dict below will reject mismatched shapes -- confirm
# against the training code.
VOCAB_SIZE = len(vocab)
EMBED_DIM = 600
NUM_FILTERS = 600
FILTER_SIZES = [3, 4, 5, 6, 7, 8, 9, 10]  # one Conv1d is built per kernel width
LSTM_HIDDEN_DIM = 768
OUTPUT_DIM = 2
DROPOUT = 0.5
PAD_IDX = vocab["<pad>"]  # also passed to the embedding as padding_idx

# Load the model

# Build the network, restore the weights onto the selected device and switch
# to evaluation mode. The checkpoint path is relative to the working directory.
model = ContactSharingClassifier(VOCAB_SIZE, EMBED_DIM, NUM_FILTERS, FILTER_SIZES, LSTM_HIDDEN_DIM, OUTPUT_DIM, DROPOUT, PAD_IDX)
model.load_state_dict(torch.load('contact_sharing_epoch_1.pth', map_location=device))
model.to(device)
model.eval()
|
| 58 |
+
|
| 59 |
+
# Test sentences
# Manual smoke-test inputs: the first group describes contact details in
# obfuscated wording, the second group is ordinary conversation.
test_sentences = [
    "You can reach me at my electronic mail address, it's my first name dot last name at that popular search engine company's mail service.",
    "Call me on my cellular device, the digits are the same as the year the Declaration of Independence was signed, followed by my birth year, twice.",
    "Visit my online presence at triple w dot my full name without spaces or punctuation dot com.",
    "Send a message to username 'not_my_real_name' on that instant messaging platform that starts with 'disc' and ends with 'ord'.",
    "My contact info is hidden in this sentence: Eight Six Seven Five Three Oh Nine.",
    "Find me on the professional networking site, just search for my name plus 'software engineer in San Francisco'.",
    "My handle on the bird-themed social media platform is at symbol followed by 'definitely_not_my_email_address'.",
    "You know that video sharing site? My channel is there, just add 'cool_coder_' before my full name, all lowercase.",
    "I'm listed in the phone book under 'Smith, John' but replace 'Smith' with my actual last name and 'John' with my first name.",
    "My contact details are encrypted: Rot13('zl.rznvy@tznvy.pbz')",

    # New non-contact sharing examples
    "The weather today is absolutely beautiful, perfect for a picnic in the park.",
    "I'm really excited about the new sci-fi movie coming out next month.",
    "Did you hear about the latest advancements in artificial intelligence? It's fascinating!",
    "I'm planning to go hiking this weekend in the nearby mountains.",
    "The recipe calls for two cups of flour and a pinch of salt.",
    "The annual tech conference will be held virtually this year due to ongoing health concerns.",
    "I've been learning to play the guitar for the past six months. It's challenging but rewarding.",
    "The local farmer's market has the freshest produce every Saturday morning.",
    "Did you catch the game last night? It was an incredible comeback in the final quarter!",
    "Lets do '42069' tonight it will be really fun what do you say ?"
]
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def predict(text):
    """Return the model's class probabilities for a single piece of text.

    The text is converted to vocabulary indices with ``text_pipeline`` and,
    when shorter than the widest convolution kernel, right-padded so every
    ``Conv1d`` in the model has at least one valid window.

    Args:
        text: The raw string to classify.

    Returns:
        A list of ``OUTPUT_DIM`` softmax probabilities. Callers in app.py
        read index 1 as the probability that the text contains contact info.
    """
    token_ids = text_pipeline(text)
    min_length = max(FILTER_SIZES)
    with torch.no_grad():
        # Explicit dtype: an empty token list would otherwise be inferred as a
        # float tensor, which cannot be used for the embedding lookup.
        inputs = torch.tensor([token_ids], dtype=torch.long)
        shortfall = min_length - inputs.size(1)
        if shortfall > 0:
            # Pad with the vocabulary's pad index (the embedding's padding_idx)
            # rather than a literal 0, which may be a different token.
            padding = torch.full((1, shortfall), PAD_IDX, dtype=torch.long)
            inputs = torch.cat([inputs, padding], dim=1)
        inputs = inputs.to(device)
        outputs = model(inputs)
        probabilities = F.softmax(outputs, dim=1)
    return probabilities.squeeze().tolist()
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
# Test the sentences
# Guarded so the smoke test only runs when this file is executed directly;
# importing `predict` from another module no longer triggers inference and
# console output at import time.
if __name__ == "__main__":
    for i, sentence in enumerate(test_sentences, 1):
        prediction = predict(sentence)
        # predict() returns a list of class probabilities, so comparing it to
        # the integer 1 was always False. Use the contact-class probability
        # with the same 0.5 cut-off the API applies.
        result = "Contains contact info" if prediction[1] > 0.5 else "No contact info"
        print(f"Sentence {i}: {result}")
        print(f"Text: {sentence}\n")
|