from pathlib import Path

from datasets import load_dataset as hf_load_dataset
from datasets import load_from_disk
from transformers import AutoTokenizer, AutoModelForSequenceClassification

from .config import ModelSource, HF_MODEL
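
# Note: the sibling .config module is not shown here; based on how it is used
# below, it is assumed (hypothetically) to expose a ModelSource enum with at
# least an HF member, plus an HF_MODEL hub repo id, e.g.:
#   class ModelSource(str, Enum):
#       HF = "hf"
#       LOCAL = "local"
#   HF_MODEL = "cardiffnlp/twitter-roberta-base-sentiment-latest"  # hypothetical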

def preprocess(text: str) -> str:
    """
    Prepares a text for tokenization by replacing user mentions with
    '@user' and URLs with 'http'
    """
    new_text = []
    for t in text.split(" "):
        t = '@user' if t.startswith('@') and len(t) > 1 else t
        t = 'http' if t.startswith('http') else t
        new_text.append(t)
    return " ".join(new_text)

def load_model_and_tokenizer(model_source: ModelSource) -> tuple[AutoTokenizer, AutoModelForSequenceClassification]:
    """
    Loads a tokenizer and sentiment analysis model. These can be either loaded
    from local storage or downloaded from the Hugging Face hub
    """
    if model_source == ModelSource.HF:  # use the latest model available in the HF hub
        tokenizer = AutoTokenizer.from_pretrained(HF_MODEL)
        model = AutoModelForSequenceClassification.from_pretrained(HF_MODEL)
    else:  # use a locally fine-tuned model
        local_model_path = Path("models/saved_model")
        assert local_model_path.exists(), "No local model was found. Run 'python3 src/train_model.py' first"
        tokenizer = AutoTokenizer.from_pretrained("models/saved_tokenizer")
        model = AutoModelForSequenceClassification.from_pretrained(local_model_path)
    return tokenizer, model
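
# Illustrative usage (hedged): loading from the hub only needs config.HF_MODEL;
# the local branch assumes 'models/saved_model' and 'models/saved_tokenizer'
# were produced by src/train_model.py.
#   tokenizer, model = load_model_and_tokenizer(ModelSource.HF)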

def load_dataset(dataset_path: Path):
    """
    Loads the tweet_eval dataset for the sentiment analysis task. The dataset
    is either loaded from local storage or downloaded through the Hugging Face hub
    """
    if dataset_path.exists():
        dataset = load_from_disk(dataset_path)
    else:
        dataset = hf_load_dataset('tweet_eval', 'sentiment')
    return dataset
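
# Illustrative usage (hedged): the cache path below is hypothetical.
#   dataset = load_dataset(Path("data/tweet_eval_sentiment"))
# On a cache miss, the 'sentiment' configuration of tweet_eval is downloaded
# from the Hugging Face hub instead.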