from typing import Dict, List
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch


class EndpointHandler:
    def __init__(self, path=""):
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"

        # The tokenizer comes from the base checkpoint; the fine-tuned
        # classification head is pulled from the Hub and moved to the
        # selected device.
        self.tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
        self.sentiment_model = AutoModelForSequenceClassification.from_pretrained(
            "Christian2903/amazon-review-sentiment-analysis"
        ).to(self.device)
        self.sentiment_model.eval()  # inference only: disable dropout etc.

    def __call__(self, data: Dict[str, List[str]]) -> Dict[str, List[int]]:
        """
        Args:
            data: request payload with a "reviews" key holding a list of
                review strings.
        Return:
            A dict that will be serialized and returned.
        """

        reviews = data.pop("reviews", data)
        # Guard against a single raw string so the batching below iterates
        # over reviews rather than characters.
        if isinstance(reviews, str):
            reviews = [reviews]

        # Score the reviews in fixed-size batches to keep per-forward-pass
        # memory bounded.
        batch_size = 32
        predictions = []
        with torch.no_grad():
            for i in range(0, len(reviews), batch_size):
                batch = reviews[i:i + batch_size]
                inputs = self.tokenizer(
                    batch,
                    return_tensors="pt",
                    truncation=True,
                    padding="max_length",
                    max_length=256,
                ).to(self.device)
                outputs = self.sentiment_model(**inputs)
                # The model is used as a regressor: one raw score per review.
                # Squeeze the trailing logit dimension to get plain floats.
                logits = outputs.logits.squeeze(-1)
                predictions.extend(logits.cpu().tolist())

        # Round each raw score to the nearest integer and clamp it to the
        # 1-5 star range.
        predicted_scores = [max(min(int(score + 0.5), 5), 1) for score in predictions]

        response = {
            'scores': predicted_scores
        }

        return response
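

# A minimal local smoke test (an illustrative assumption: on Inference
# Endpoints the serving runtime constructs EndpointHandler and invokes it
# with the deserialized request body, so this block never runs in production).
if __name__ == "__main__":
    handler = EndpointHandler()
    result = handler({"reviews": [
        "Great product, exactly as described!",
        "Stopped working after two days.",
    ]})
    print(result)  # e.g. {"scores": [5, 1]}; actual values depend on the model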