import torch
import transformers
from transformers import Pipeline

try:
    import orbitals.scope_guard
    import orbitals.scope_guard.modeling
    import orbitals.scope_guard.prompting
    import orbitals.types
except ModuleNotFoundError:
    raise ImportError(
        "orbitals.scope_guard module not found. Please install it: `pip install orbitals`"
    )


class ScopeGuardPipeline(Pipeline):
    """Generation pipeline that renders orbitals scope-guard prompts and decodes the model's reply."""

    def __init__(
        self,
        model,
        tokenizer=None,
        skip_evidences: bool = False,
        max_new_tokens: int = 1024,
        do_sample: bool = False,
        **kwargs,
    ):
        # Accept either loaded objects or checkpoint names for the model and tokenizer.
        if tokenizer is None and isinstance(model, str):
            tokenizer = transformers.AutoTokenizer.from_pretrained(model)
        elif isinstance(tokenizer, str):
            tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer)
        if isinstance(model, str):
            model = transformers.AutoModelForCausalLM.from_pretrained(
                model, dtype="auto", device_map="auto"
            )
        if tokenizer is not None:
            # Use left padding for decoder-only models (required for batched generation).
            tokenizer.padding_side = "left"
            # Ensure a pad token is set (fall back to eos_token if none exists).
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token
        self.skip_evidences = skip_evidences
        self.max_new_tokens = max_new_tokens
        self.do_sample = do_sample
        super().__init__(model, tokenizer, **kwargs)

    def _sanitize_parameters(
        self,
        **kwargs,
    ):
        # Route `skip_evidences` to preprocess; forward and postprocess take no extra kwargs.
        preprocess_kwargs = {}
        if "skip_evidences" in kwargs or self.skip_evidences:
            preprocess_kwargs["skip_evidences"] = kwargs.get(
                "skip_evidences", self.skip_evidences
            )
        return (
            preprocess_kwargs,
            {},
            {},
        )

    def preprocess(
        self,
        inputs: tuple[
            orbitals.scope_guard.modeling.ScopeGuardInput,
            str | orbitals.types.AIServiceDescription,
        ],
        skip_evidences: bool = False,
    ):
        # Build the scope-guard messages and render them with the model's chat template.
        conversation, ai_service_description = inputs
        model_messages = orbitals.scope_guard.prompting.prepare_messages(
            conversation,
            ai_service_description,
            skip_evidences,
        )
        text = self.tokenizer.apply_chat_template(
            model_messages,
            tokenize=False,  # defer tokenization to _forward so the batch can be padded together
            add_generation_prompt=True,
            enable_thinking=False,
        )
        return {"text": text}

    def _forward(self, model_inputs):
        # Tokenize the rendered prompts as a padded batch and generate.
        tokenized = self.tokenizer(
            model_inputs["text"],
            return_tensors="pt",
            padding=True,
            truncation=True,
        ).to(self.device)
        with torch.inference_mode():
            outputs = self.model.generate(
                **tokenized,
                max_new_tokens=self.max_new_tokens,
                do_sample=self.do_sample,
            )
        return {
            "output_ids": outputs,
            "input_ids": tokenized["input_ids"],
        }

    def postprocess(self, model_outputs):
        output_ids = model_outputs["output_ids"]
        input_ids = model_outputs["input_ids"]
        # Decode each output in the batch
        results = []
        for i in range(output_ids.shape[0]):
            # Skip the input tokens to get only the generated text
            generated_ids = output_ids[i][input_ids.shape[1] :]
            generated_output = self.tokenizer.decode(
                generated_ids,
                skip_special_tokens=True,
            )
            results.append({"generated_text": generated_output})
        return results
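

# Illustrative usage sketch (not part of the original module): the checkpoint name is a
# placeholder, and constructing the ScopeGuardInput conversation is left to the orbitals
# package, so this only shows the expected call shape of the pipeline.
if __name__ == "__main__":
    pipe = ScopeGuardPipeline(
        "your-org/scope-guard-checkpoint",  # placeholder model id (assumption)
        max_new_tokens=512,
    )
    conversation = ...  # an orbitals.scope_guard.modeling.ScopeGuardInput instance
    service_description = "An assistant that only answers billing questions."
    # The pipeline consumes (conversation, AI service description) tuples.
    print(pipe((conversation, service_description)))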