# ESG_Analysis / model / zero_shot_classification.py
# Mawube's picture
# Update model/zero_shot_classification.py
# 41317cd
from model.config import config as cfg
from transformers import pipeline
from typing import Dict
# Zero shot classifier
class ZeroShotClassifier:
    """Score a text against the configured labels with a zero-shot pipeline.

    The model and pipeline names default to the values in ``cfg``. The
    candidate labels are read from ``cfg.labels`` when the instance is
    created and stored on ``self.labels``.
    """

    # Labels reported by ``Predict``, in the order they appear in the result.
    # Each of them must be among the candidate labels in ``self.labels``.
    _REPORTED_LABELS = ("social", "environmental", "governance")

    def __init__(
        self,
        model_name: str = cfg.model_name,
        pipeline_name: str = cfg.pipeline_name,
    ):
        """Store the model/pipeline names and the candidate labels.

        Args:
            model_name: Model identifier handed to ``transformers.pipeline``.
            pipeline_name: Pipeline task name; only
                ``"zero-shot-classification"`` is accepted by
                ``get_pipeline``.
        """
        self.model_name = model_name
        self.pipeline_name = pipeline_name
        self.labels = cfg.labels
        # The pipeline is built lazily on first use and then reused, so the
        # model is not reloaded on every prediction.
        self._pipeline = None
        self._pipeline_key = None

    def get_pipeline(self):
        """Return the zero-shot classification pipeline, building it once.

        The pipeline is cached on the instance and rebuilt only when
        ``pipeline_name`` or ``model_name`` has changed since it was built.

        Returns:
            The object returned by ``transformers.pipeline``.

        Raises:
            ValueError: If ``pipeline_name`` is not
                ``"zero-shot-classification"``.
        """
        if self.pipeline_name != "zero-shot-classification":
            raise ValueError("Invalid pipeline name")

        cache_key = (self.pipeline_name, self.model_name)
        # getattr() keeps this working for instances created without
        # running __init__ (e.g. via __new__).
        cached = getattr(self, "_pipeline", None)
        if cached is None or getattr(self, "_pipeline_key", None) != cache_key:
            cached = pipeline(self.pipeline_name, model=self.model_name)
            self._pipeline = cached
            self._pipeline_key = cache_key
        return cached

    # predicts scores for each label
    def Predict(self, text: str) -> Dict[str, float]:
        """Return the score of each ESG label for ``text``.

        The pipeline is called with ``self.labels`` as candidate labels and
        ``multi_label=True``. Its output is read as a mapping with parallel
        ``"labels"`` and ``"scores"`` sequences.

        Args:
            text: The text to classify.

        Returns:
            A dict with the keys ``"social"``, ``"environmental"`` and
            ``"governance"`` (in that order) mapped to their scores.

        Raises:
            ValueError: If the pipeline name is invalid, or if one of the
                reported labels is absent from the pipeline output.
        """
        classifier = self.get_pipeline()
        outcome = classifier(text, self.labels, multi_label=True)

        # The pipeline returns labels and scores as parallel sequences; pair
        # them up once instead of searching the label list per lookup.
        score_by_label = dict(zip(outcome["labels"], outcome["scores"]))

        missing = [
            label
            for label in self._REPORTED_LABELS
            if label not in score_by_label
        ]
        if missing:
            raise ValueError(
                f"Labels missing from classifier output: {missing}"
            )

        return {label: score_by_label[label] for label in self._REPORTED_LABELS}