ad-mod / model.py
venishpatidar's picture
Upload 7 files
be663b5 verified
from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer
import torch
class AdsMod:
"""
Model init method
base model used = bloom-560m
model_path: takes the path of saved model weights
"""
def __init__(self,model_path:str="./") -> None:
self.model = AutoModelForSequenceClassification.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.tokenize = lambda input: tokenizer(input, truncation=True, padding="max_length", max_length=256,return_tensors="pt").to(device)
self.model.to(device)
"""
predict
takes text as input and classifies the text
returns 0 or 1
"""
def predict(self,text:str) -> int:
input = self.tokenize(text)
with torch.no_grad():
output = self.model(**input)
predicted_class = torch.argmax(output.logits, dim=1).item()
return predicted_class
if __name__=="__main__":
model = AdsMod('venishpatidar/wa-ad-mod')
text = "Hi I am text classifier, will help you in deleteing housing ads message"
print("Ad" if model.predict(text) else "Safe")