# Author: forrestbao — commit 4e7ea3a ("init push"), 1.61 kB
# (Hugging Face file-viewer chrome removed so this file parses as Python.)
# %%
from typing import List, Literal
from pydantic import BaseModel
from IPython.display import display, Markdown
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers import pipeline
# %%
class HHEMOutput(BaseModel):
    """One prediction of the HHEM hallucination classifier.

    `score` is the raw 'consistent'-class probability (kept as a float so
    downstream evaluation can build an ROC curve); `label` is the
    thresholded verdict: 1 = consistent, 0 = hallucinated.
    """
    score: float # we need score for ROC curve
    label: Literal[0,1]  # 1 = consistent, 0 = hallucinated (see predict())
# %%
# Prompt wrapper fed to the HHEM classifier: premise + hypothesis pair.
# NOTE(review): the leading "<pad>" looks like a T5 decoder-start artifact
# expected by the checkpoint — confirm against the model card before changing.
PROMPT_TEMPLATE = "<pad> Determine if the hypothesis is true given the premise?\n\nPremise: {text1}\n\nHypothesis: {text2}"
CHECKPOINT = "vectara/hallucination_evaluation_model"  # HHEM fine-tuned weights
FOUNDATION = "google/flan-t5-small"  # base model whose tokenizer HHEM reuses
# Module-level side effects: both calls fetch weights/tokenizer from the
# Hugging Face Hub on first use (network + disk cache).
tokenizer = AutoTokenizer.from_pretrained(FOUNDATION)
classifier = pipeline("text-classification", model=CHECKPOINT, tokenizer=tokenizer, trust_remote_code=True)
def predict(premise: str, hypothesis: str, threshold: float = 0.5) -> Markdown:
    """Judge whether `hypothesis` is consistent with `premise` using HHEM.

    Args:
        premise: Reference text assumed to be true.
        hypothesis: Statement to check against the premise.
        threshold: Cutoff on the 'consistent' probability; scores below it
            are labeled hallucinated (0). Defaults to 0.5, matching the
            previous hard-coded behavior.

    Returns:
        An IPython ``Markdown`` object summarizing the verdict and score.
    """
    texts_prompted: List[str] = [PROMPT_TEMPLATE.format(text1=premise, text2=hypothesis)]

    # top_k=None returns scores for every label:
    # List[List[Dict[str, float]]] — one inner list per input text.
    full_scores = classifier(texts_prompted, top_k=None)

    # Keep only the probability assigned to the 'consistent' label.
    simple_scores = [
        score_dict["score"]
        for scores_per_input in full_scores
        for score_dict in scores_per_input
        if score_dict["label"] == "consistent"
    ]

    # Threshold the single input's score directly (no need to build a full
    # prediction list for a one-element batch).
    output = HHEMOutput(
        score=simple_scores[0],
        label=0 if simple_scores[0] < threshold else 1,
    )
    verdict = "consistent" if output.label == 1 else "hallucinated"

    # NOTE(review): single newlines do not render as line breaks in Markdown,
    # so these three lines may display as one — confirm intended rendering.
    output_string = f"""
**Premise**: {premise}
**Hypothesis**: {hypothesis}
**HHEM's judgement is**: {verdict} **with the score**: {output.score}
"""
    return Markdown(output_string)