File size: 1,849 Bytes
88b3233
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
# %%
# ----------------------------------------------------------
# Custom Hugging Face pipeline for the quizbowl “tossup” split; it only wraps existing pretrained models (nothing is trained here)
# Task id  :  quizbowl-tossup
# Expected input keys : question_text
# Must return        : answer, confidence, buzz
# ----------------------------------------------------------

import torch
import torch.nn.functional as F
from transformers import (
    AutoModelForCausalLM,
    Pipeline,
    TFAutoModelForCausalLM,
)
from transformers.pipelines import PIPELINE_REGISTRY


class QBTossupPipeline(Pipeline):
    """Answer a quizbowl tossup from a causal LM's next-token prediction.

    The pipeline prompts the model with the question text revealed so far,
    followed by ``"Answer:"``. It then takes the most likely next token as the
    answer and that token's probability as the confidence. It buzzes when the
    confidence exceeds a threshold.

    Call-time parameters:
        threshold (float, optional): confidence strictly above which
            ``buzz`` is ``True``. Defaults to ``0.5``.

    Returns (per input):
        dict with keys ``"answer"`` (str), ``"confidence"`` (float) and
        ``"buzz"`` (bool).

    NOTE(review): only the single most likely next token is decoded, so a
    multi-token answer is reduced to its first sub-word piece.
    """

    def _sanitize_parameters(self, threshold=None, **kwargs):
        """Split call-time kwargs into preprocess/forward/postprocess params.

        Only ``threshold`` is recognised and it is forwarded to
        ``postprocess``. Any other keyword is ignored, as before.
        """
        postprocess_params = {}
        if threshold is not None:
            postprocess_params["threshold"] = float(threshold)
        return {}, {}, postprocess_params

    def preprocess(self, inputs):
        """Build the answer prompt and tokenize it as PyTorch tensors.

        Args:
            inputs: a dict with a ``"question_text"`` entry (the documented
                input format), or the question text itself as a ``str``.

        Returns:
            The tokenizer output for the single prompt, with
            ``return_tensors="pt"``.
        """
        question = inputs if isinstance(inputs, str) else inputs["question_text"]
        prompt = (
            "Answer the quiz question revealed so far:\n"
            f"Question: {question}\nAnswer:"
        )
        # NOTE(review): truncation removes tokens from the tokenizer's
        # truncation side (right by default), which would drop the trailing
        # "Answer:" cue for over-long prompts -- confirm question lengths stay
        # well under the model's max length.
        return self.tokenizer(prompt, return_tensors="pt", truncation=True)

    def _forward(self, model_inputs):
        """Run the causal LM on the tokenized prompt without tracking gradients."""
        with torch.no_grad():
            return self.model(**model_inputs)

    def postprocess(self, model_outputs, threshold=0.5):
        """Turn the last-position logits into an answer/confidence/buzz dict.

        Args:
            model_outputs: the model output from ``_forward``; its ``logits``
                are indexed as (batch, position, vocab).
            threshold: confidence strictly above which ``buzz`` is ``True``.
        """
        # Distribution over the token that would follow "Answer:": first (only)
        # batch row, last prompt position. Upcast to float32 because logits may
        # be fp16/bf16 and are on CPU here, where half-precision softmax is
        # imprecise or unsupported on some torch versions.
        last_logits = model_outputs.logits[0, -1].float()
        probs = F.softmax(last_logits, dim=-1)
        top_id = int(torch.argmax(probs))
        answer = self.tokenizer.decode(top_id, skip_special_tokens=True).strip()
        confidence = probs[top_id].item()
        buzz = confidence > threshold
        return {"answer": answer, "confidence": confidence, "buzz": buzz}


# Register the custom task so ``transformers.pipeline("quizbowl-tossup")`` can
# build a QBTossupPipeline. The pipeline is PyTorch-only: it tokenizes with
# ``return_tensors="pt"`` and runs the model under ``torch.no_grad``. For that
# reason no TensorFlow model class is registered. A TF model would receive
# torch tensors and fail.
# NOTE(review): the default checkpoint name suggests a tiny randomly
# initialised test model -- confirm whether a real checkpoint is intended.
PIPELINE_REGISTRY.register_pipeline(
    "quizbowl-tossup",
    pipeline_class=QBTossupPipeline,
    pt_model=AutoModelForCausalLM,
    default={
        "pt": ("yujiepan/llama-3-tiny-random", "main"),
    },
    type="text",
)