import torch
import transformers
from transformers import Pipeline

try:
    import orbitals.scope_guard
    import orbitals.scope_guard.modeling
    import orbitals.scope_guard.prompting
    import orbitals.types
except ModuleNotFoundError as exc:
    raise ImportError(
        "orbitals.scope_guard module not found. Please install it: `pip install orbitals`"
    ) from exc


class ScopeGuardPipeline(Pipeline):
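    """Hugging Face Pipeline for the orbitals scope_guard task.

    Formats a (conversation, AI service description) pair into a chat prompt via
    orbitals.scope_guard.prompting, runs generation with a causal language model,
    and returns the newly generated text.
    """
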
    def __init__(
        self,
        model,
        tokenizer=None,
        skip_evidences: bool = False,
        max_new_tokens: int = 1024,
        do_sample: bool = False,
        **kwargs,
    ):
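        """
        Args:
            model: Model name/path or an already loaded causal language model.
            tokenizer: Tokenizer name/path or instance; loaded from `model` if omitted.
            skip_evidences: Default for the `skip_evidences` preprocessing flag,
                forwarded to orbitals.scope_guard.prompting.prepare_messages.
            max_new_tokens: Generation budget passed to `model.generate`.
            do_sample: Whether to sample instead of using greedy decoding.
        """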
        if tokenizer is None and isinstance(model, str):
            tokenizer = transformers.AutoTokenizer.from_pretrained(model)
        elif isinstance(tokenizer, str):
            tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer)

        if isinstance(model, str):
            model = transformers.AutoModelForCausalLM.from_pretrained(
                model, dtype="auto", device_map="auto"
            )

        # Set left padding for decoder-only models (required for batched generation)
        if tokenizer is not None:
            tokenizer.padding_side = "left"
            # Ensure pad token is set (use eos_token if pad_token doesn't exist)
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token

        self.skip_evidences = skip_evidences
        self.max_new_tokens = max_new_tokens
        self.do_sample = do_sample

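        # Set these attributes before Pipeline.__init__, which calls _sanitize_parameters().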
        super().__init__(model, tokenizer, **kwargs)

    def _sanitize_parameters(
        self,
        **kwargs,
    ):
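        """Split call-time kwargs into (preprocess, forward, postprocess) kwargs,
        as required by the Pipeline API; only `skip_evidences` reaches preprocess().
        """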
        preprocess_kwargs = {}
        if "skip_evidences" in kwargs or self.skip_evidences:
            preprocess_kwargs["skip_evidences"] = kwargs.get(
                "skip_evidences", self.skip_evidences
            )

        return (
            preprocess_kwargs,
            {},
            {},
        )

    def preprocess(
        self,
        inputs: tuple[
            orbitals.scope_guard.modeling.ScopeGuardInput,
            str | orbitals.types.AIServiceDescription,
        ],
        skip_evidences: bool = False,
    ):
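        """Render the chat prompt for one (conversation, AI service description) pair."""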
        conversation, ai_service_description = inputs

        model_messages = orbitals.scope_guard.prompting.prepare_messages(
            conversation,
            ai_service_description,
            skip_evidences,
        )

        text = self.tokenizer.apply_chat_template(
            model_messages,
            tokenize=False,  # keep the prompt as text; tokenization (with padding) happens in _forward
            add_generation_prompt=True,
            enable_thinking=False,
        )

        return {"text": text}

    def _forward(self, model_inputs):
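        """Tokenize the rendered prompt(s) with padding and run `model.generate`."""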
        tokenized = self.tokenizer(
            model_inputs["text"],
            return_tensors="pt",
            padding=True,
            truncation=True,
        ).to(self.device)

        with torch.inference_mode():
            outputs = self.model.generate(
                **tokenized,
                max_new_tokens=self.max_new_tokens,
                do_sample=self.do_sample,
            )
        return {
            "output_ids": outputs,
            "input_ids": tokenized["input_ids"],
        }

    def postprocess(self, model_outputs):
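        """Decode only the newly generated tokens for each sequence in the batch."""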
        output_ids = model_outputs["output_ids"]
        input_ids = model_outputs["input_ids"]

        # Decode each output in the batch
        results = []
        for i in range(output_ids.shape[0]):
            # Skip the input tokens to get only the generated text
            generated_ids = output_ids[i][input_ids.shape[1] :]
            generated_output = self.tokenizer.decode(
                generated_ids,
                skip_special_tokens=True,
            )
            results.append({"generated_text": generated_output})

        return results
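

# Usage sketch (illustrative only; the checkpoint name and the ScopeGuardInput
# construction are assumptions, not part of this module):
#
#     pipe = ScopeGuardPipeline("org/scope-guard-model")  # any causal LM checkpoint
#     conversation = orbitals.scope_guard.modeling.ScopeGuardInput(...)  # see orbitals docs
#     service = "A support assistant that only answers billing questions."
#     result = pipe((conversation, service))
#     print(result[0]["generated_text"])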