# Utilities for loading the sentiment model, tokenizer, and tweet_eval dataset.
from pathlib import Path
from .config import ModelSource, HF_MODEL
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from datasets import load_dataset as hf_load_dataset
from datasets import load_from_disk
def preprocess(text:str)->str:
    """
    Normalize a tweet so it is ready for tokenization.

    Tokens starting with '@' (user mentions) are replaced by '@user' and
    tokens starting with 'http' (URLs) are replaced by 'http', as expected
    by Twitter-trained sentiment models.
    """
    normalized = []
    for token in text.split(" "):
        if token.startswith('@') and len(token) > 1:
            normalized.append('@user')
        elif token.startswith('http'):
            normalized.append('http')
        else:
            normalized.append(token)
    return " ".join(normalized)
def load_model_and_tokenizer(MODEL_SOURCE:str)->tuple[AutoTokenizer,AutoModelForSequenceClassification]:
    """
    Load a tokenizer and sentiment-analysis model.

    Parameters
    ----------
    MODEL_SOURCE : str
        ``ModelSource.HF`` downloads the latest model from the Hugging
        Face hub; any other value loads a locally fine-tuned model from
        ``models/saved_model`` / ``models/saved_tokenizer``.

    Returns
    -------
    tuple
        ``(tokenizer, model)``.

    Raises
    ------
    FileNotFoundError
        If a local model is requested but none has been saved yet.
    """
    if MODEL_SOURCE == ModelSource.HF:  # use the latest model available in the HF hub
        tokenizer = AutoTokenizer.from_pretrained(HF_MODEL)
        model = AutoModelForSequenceClassification.from_pretrained(HF_MODEL)
    else:  # use a locally fine-tuned model
        local_model_path = Path("models/saved_model")
        local_tokenizer_path = Path("models/saved_tokenizer")
        # Raise instead of assert: asserts are stripped under `python -O`,
        # which would turn a clear error into an opaque from_pretrained failure.
        if not local_model_path.exists():
            raise FileNotFoundError(
                "No local model was found. Run 'python3 src/train_model.py' first"
            )
        tokenizer = AutoTokenizer.from_pretrained(local_tokenizer_path)
        model = AutoModelForSequenceClassification.from_pretrained(local_model_path)
    return tokenizer, model
def load_dataset(dataset_path:str):
    """
    Load the tweet_eval dataset for the sentiment-analysis task.

    The dataset is read from disk when ``dataset_path`` exists; otherwise
    it is downloaded through the Hugging Face API.

    Parameters
    ----------
    dataset_path : str
        Path to a dataset previously saved with ``save_to_disk``.
    """
    # Bug fix: dataset_path is a str, which has no .exists() method —
    # the original `dataset_path.exists()` raised AttributeError. Wrap in Path.
    if Path(dataset_path).exists():
        return load_from_disk(dataset_path)
    # fall back to downloading the dataset from the HF hub
    return hf_load_dataset('tweet_eval', 'sentiment')