File size: 1,823 Bytes
57fbf67
 
 
 
 
 
 
85c00b7
 
 
 
b603fd0
 
 
 
 
57fbf67
 
 
85c00b7
 
 
 
 
57fbf67
 
 
 
 
 
 
 
 
 
 
85c00b7
 
 
 
 
57fbf67
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
from pathlib import Path
from .config import ModelSource, HF_MODEL
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from datasets import load_dataset as hf_load_dataset
from datasets import load_from_disk


def preprocess(text:str)->str:
    """
    Normalizes a text before tokenization by masking user mentions and links.

    The text is split on single spaces. Every token that starts with '@' and
    is longer than one character becomes '@user'; every token that starts
    with 'http' becomes 'http'. All other tokens are kept unchanged and the
    tokens are joined back with single spaces.
    """
    def _mask(token):
        # A lone '@' is not a mention, so it is left untouched.
        if len(token) > 1 and token.startswith('@'):
            return '@user'
        if token.startswith('http'):
            return 'http'
        return token

    return " ".join(_mask(piece) for piece in text.split(" "))


def load_model_and_tokenizer(MODEL_SOURCE:str)->"tuple[AutoTokenizer, AutoModelForSequenceClassification]":
    """
    Loads a tokenizer and a sequence classification model for sentiment analysis.

    Args:
        MODEL_SOURCE: selects where to load from. When it equals ModelSource.HF,
            both objects are loaded with from_pretrained(HF_MODEL); any other
            value loads the locally saved tokenizer and model.

    Returns:
        A (tokenizer, model) tuple.

    Raises:
        AssertionError: if the local source is selected and the directory
            "models/saved_model" does not exist.
    """
    if MODEL_SOURCE == ModelSource.HF:   # use the latest model available in the HF hub
        tokenizer = AutoTokenizer.from_pretrained(HF_MODEL)
        model = AutoModelForSequenceClassification.from_pretrained(HF_MODEL)
        return tokenizer, model

    # use a locally fine tuned model; paths are relative to the working directory
    local_model_dir = "models/saved_model"
    local_tokenizer_dir = "models/saved_tokenizer"
    if not Path(local_model_dir).exists():
        # Raised explicitly instead of using `assert`, so the check is not
        # stripped when Python runs with -O; the exception type is unchanged.
        raise AssertionError("""No local model was found. Run 'python3 src/train_model.py' first""")
    tokenizer = AutoTokenizer.from_pretrained(local_tokenizer_dir)
    model = AutoModelForSequenceClassification.from_pretrained(local_model_dir)
    return tokenizer, model


def load_dataset(dataset_path:"str | Path"):
    """
    Loads the tweet_eval dataset for the sentiment analysis task. The dataset
    is either loaded from a local path or downloaded through the Hugging Face API.

    Args:
        dataset_path: location of a dataset previously saved to disk, given
            as a string or a pathlib.Path.

    Returns:
        The result of load_from_disk(dataset_path) when the path exists,
        otherwise the result of loading 'tweet_eval' with the 'sentiment'
        configuration.
    """
    # Wrap in Path so that plain strings work too: str has no .exists().
    if Path(dataset_path).exists():
        # The caller's original argument is forwarded unchanged.
        return load_from_disk(dataset_path)
    return hf_load_dataset('tweet_eval', 'sentiment')