File size: 1,256 Bytes
e83e0cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41317cd
e83e0cf
 
 
 
 
 
 
 
 
41317cd
e83e0cf
41317cd
e83e0cf
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
from model.config import config as cfg
from transformers import pipeline
from typing import Dict


# Zero shot classifier
class ZeroShotClassifier:

    def __init__(
        self,
        model_name: str =  cfg.model_name, 
        pipeline_name: str =  cfg.pipeline_name,
        ):
        self.model_name = model_name
        self.pipeline_name = pipeline_name
        self.labels = cfg.labels


    def get_pipeline(self):
        if self.pipeline_name == "zero-shot-classification":
            return pipeline(self.pipeline_name, model=self.model_name)

        else:
            raise ValueError("Invalid pipeline name")


    # predicts scores for each label
    def Predict(self,text: str)-> Dict[str, float]:
            model = self.get_pipeline()
            sentiments = model(text, cfg.labels, multi_label=True)
            print(sentiments)
            a = sentiments['scores'][sentiments['labels'].index('social')]
            b = sentiments['scores'][sentiments['labels'].index('environmental')]
            c = sentiments['scores'][sentiments['labels'].index('governance')]

            result = {
                'social':a,
                'environmental':b,
                'governance':c
            }

            return result