File size: 4,320 Bytes
d420a64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6a79360
d420a64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import torch
import transformers
from transformers import Pipeline

try:
    import orbitals.claim_extractor
    import orbitals.claim_extractor.modeling
    import orbitals.claim_extractor.prompting
    import orbitals.types
except ModuleNotFoundError:
    raise ImportError(
        "orbitals.claim_extractor module not found. Please install it: `pip install orbitals`"
    )


class ClaimExtractionPipeline(Pipeline):
    """Generation pipeline driven by the ``orbitals`` claim-extractor prompt.

    Each input is a ``(conversation, ai_service_description)`` pair. The pair
    is turned into chat messages by
    ``orbitals.claim_extractor.prompting.prepare_messages``, rendered with the
    tokenizer's chat template, and completed with ``model.generate``. The
    pipeline returns the decoded completion as-is; it does not parse the
    generated text.

    Generation settings given to the constructor act as defaults. Any of the
    names in ``_GENERATION_PARAMS`` (and ``skip_evidences``) may also be passed
    when calling the pipeline to override the default for that call only.
    """

    # Keyword arguments forwarded to ``model.generate``. Each one has a
    # same-named instance attribute that stores its default value.
    _GENERATION_PARAMS = (
        "max_new_tokens",
        "do_sample",
        "temperature",
        "repetition_penalty",
        "top_p",
        "top_k",
        "min_p",
    )
    # Subset of the above that is only meaningful when sampling is enabled.
    _SAMPLING_ONLY_PARAMS = ("temperature", "top_p", "top_k", "min_p")

    def __init__(
        self,
        model,
        tokenizer=None,
        skip_evidences: bool = True,
        max_new_tokens: int = 20_000,
        do_sample: bool = True,
        temperature: float = 0.7,
        repetition_penalty: float = 1.0,
        top_p: float = 0.8,
        top_k: int = 20,
        min_p: float = 0.0,
        **kwargs,
    ):
        """Build the pipeline, loading the model/tokenizer when given by name.

        Args:
            model: A model instance, or a string passed to
                ``AutoModelForCausalLM.from_pretrained``.
            tokenizer: A tokenizer instance, a string passed to
                ``AutoTokenizer.from_pretrained``, or ``None``. When ``None``
                and ``model`` is a string, the tokenizer is loaded from the
                same identifier. The tokenizer is modified in place (left
                padding, pad-token fallback).
            skip_evidences: Default for the flag forwarded to
                ``prepare_messages``.
            max_new_tokens: Default ``generate`` argument.
            do_sample: Default ``generate`` argument.
            temperature: Default ``generate`` argument (sampling only).
            repetition_penalty: Default ``generate`` argument.
            top_p: Default ``generate`` argument (sampling only).
            top_k: Default ``generate`` argument (sampling only).
            min_p: Default ``generate`` argument (sampling only).
            **kwargs: Passed through to ``Pipeline.__init__``.
        """
        # Resolve the tokenizer first: once ``model`` has been replaced by a
        # loaded instance its identifier string is no longer available.
        if tokenizer is None and isinstance(model, str):
            tokenizer = transformers.AutoTokenizer.from_pretrained(model)
        elif isinstance(tokenizer, str):
            tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer)

        if isinstance(model, str):
            model = transformers.AutoModelForCausalLM.from_pretrained(
                model, dtype="auto", device_map="auto"
            )

        # Set left padding for decoder-only models (required for batched generation)
        if tokenizer is not None:
            tokenizer.padding_side = "left"
            # Ensure pad token is set (use eos_token if pad_token doesn't exist)
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token

        # These attributes must exist before ``super().__init__`` runs, because
        # the base constructor calls ``_sanitize_parameters``, which reads
        # ``self.skip_evidences``.
        self.skip_evidences = skip_evidences
        self.max_new_tokens = max_new_tokens
        self.do_sample = do_sample
        self.temperature = temperature
        self.repetition_penalty = repetition_penalty
        self.top_p = top_p
        self.top_k = top_k
        self.min_p = min_p

        super().__init__(model, tokenizer, **kwargs)

    def _sanitize_parameters(
        self,
        **kwargs,
    ):
        """Split keyword arguments into preprocess / forward / postprocess dicts.

        ``skip_evidences`` goes to ``preprocess`` (falling back to the
        constructor value). Any generation setting named in
        ``_GENERATION_PARAMS`` that is present in ``kwargs`` goes to
        ``_forward`` as a per-call override. ``postprocess`` takes no
        parameters. Other keys are ignored.
        """
        preprocess_kwargs = {
            "skip_evidences": kwargs.get("skip_evidences", self.skip_evidences)
        }
        # Only include keys that were explicitly supplied, so constructor
        # defaults stay in effect for everything else.
        forward_kwargs = {
            name: kwargs[name] for name in self._GENERATION_PARAMS if name in kwargs
        }

        return (
            preprocess_kwargs,
            forward_kwargs,
            {},
        )

    def preprocess(
        self,
        inputs: tuple[
            orbitals.claim_extractor.modeling.ClaimExtractorInput,
            str | orbitals.types.AIServiceDescription | None,
        ],
        skip_evidences: bool = True,
    ):
        """Render one ``(conversation, ai_service_description)`` pair to a prompt.

        Args:
            inputs: Two-element tuple unpacked as the conversation and the
                (optional) AI service description.
            skip_evidences: Forwarded unchanged to ``prepare_messages``.

        Returns:
            ``{"text": prompt}`` where ``prompt`` is the chat-template output.
        """
        conversation, ai_service_description = inputs

        model_messages = orbitals.claim_extractor.prompting.prepare_messages(
            conversation,
            ai_service_description,
            skip_evidences=skip_evidences,
        )

        text = self.tokenizer.apply_chat_template(
            model_messages,
            # Tokenization is deferred to ``_forward`` so that a whole batch
            # of prompts can be padded together.
            tokenize=False,
            add_generation_prompt=True,
            # Extra keyword handed to the chat template; presumably turns off
            # the template's reasoning mode where supported — confirm against
            # the model's template.
            enable_thinking=False,
        )

        return {"text": text}

    def _forward(self, model_inputs, **generate_overrides):
        """Tokenize the prompt(s) and run ``model.generate``.

        Args:
            model_inputs: Mapping with a ``"text"`` entry holding the prompt
                (or prompts) produced by ``preprocess``.
            **generate_overrides: Per-call values for any name in
                ``_GENERATION_PARAMS``; they take precedence over the values
                stored by the constructor.

        Returns:
            ``{"output_ids": ..., "input_ids": ...}``: the sequences returned
            by ``generate`` and the tokenized prompt ids.
        """
        # NOTE(review): truncation=True without max_length relies on the
        # tokenizer's configured maximum; if the tokenizer truncates on the
        # right, an over-long prompt would lose its trailing generation
        # prompt — confirm this is acceptable.
        encoded = self.tokenizer(
            model_inputs["text"],
            return_tensors="pt",
            padding=True,
            truncation=True,
        ).to(self.device)

        generate_kwargs = {name: getattr(self, name) for name in self._GENERATION_PARAMS}
        generate_kwargs.update(generate_overrides)
        if not generate_kwargs["do_sample"]:
            # Sampling-only knobs have no effect on non-sampling decoding and
            # merely trigger generation-config warnings, so leave them out.
            for name in self._SAMPLING_ONLY_PARAMS:
                generate_kwargs.pop(name, None)

        with torch.inference_mode():
            outputs = self.model.generate(**encoded, **generate_kwargs)
        return {
            "output_ids": outputs,
            "input_ids": encoded["input_ids"],
        }

    def postprocess(self, model_outputs):
        """Decode only the newly generated tokens of each sequence.

        Args:
            model_outputs: Mapping with ``"output_ids"`` (prompt followed by
                completion, one row per sequence) and ``"input_ids"`` (the
                tokenized prompts).

        Returns:
            A list with one ``{"generated_text": str}`` dict per row of
            ``output_ids``.
        """
        # Prompts are left-padded to a common width, so every completion
        # starts at the same column: the width of ``input_ids``.
        prompt_length = model_outputs["input_ids"].shape[1]

        return [
            {
                "generated_text": self.tokenizer.decode(
                    sequence[prompt_length:],
                    skip_special_tokens=True,
                )
            }
            for sequence in model_outputs["output_ids"]
        ]