"""HuggingFace ``transformers`` pipeline for the ScopeGuard models from ``orbitals``."""

import torch
import transformers
from transformers import Pipeline

try:
    import orbitals.scope_guard
    import orbitals.scope_guard.modeling
    import orbitals.scope_guard.prompting
    import orbitals.types
except ModuleNotFoundError as exc:
    raise ImportError(
        "orbitals.scope_guard module not found. Please install it: `pip install orbitals`"
    ) from exc


class ScopeGuardPipeline(Pipeline):
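    """Pipeline that runs a ScopeGuard model over a (conversation, AI service
    description) pair.

    ``preprocess`` renders the inputs into a chat-templated prompt,
    ``_forward`` runs generation, and ``postprocess`` decodes only the newly
    generated tokens.
    """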

    def __init__(
        self,
        model,
        tokenizer=None,
        skip_evidences: bool = False,
        max_new_tokens: int = 1024,
        do_sample: bool = False,
        **kwargs,
    ):
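        """``model`` and ``tokenizer`` accept hub ids / local paths as well as
        already-loaded objects. ``skip_evidences`` is forwarded to prompt
        preparation; ``max_new_tokens`` and ``do_sample`` are forwarded to
        ``generate``.
        """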
        # Resolve hub ids / local paths into loaded objects.
        if tokenizer is None and isinstance(model, str):
            tokenizer = transformers.AutoTokenizer.from_pretrained(model)
        elif isinstance(tokenizer, str):
            tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer)

        if isinstance(model, str):
            model = transformers.AutoModelForCausalLM.from_pretrained(
                model, dtype="auto", device_map="auto"
            )

        if tokenizer is not None:
            # Decoder-only models are batched with left padding so that
            # generation starts right after the last prompt token of every row.
            tokenizer.padding_side = "left"

            # Many causal-LM tokenizers ship without a pad token; fall back to
            # EOS so padded batches can be built.
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token

        # These must be set before Pipeline.__init__, which immediately calls
        # _sanitize_parameters and therefore reads these attributes.
        self.skip_evidences = skip_evidences
        self.max_new_tokens = max_new_tokens
        self.do_sample = do_sample

        super().__init__(model, tokenizer, **kwargs)

    def _sanitize_parameters(self, **kwargs):
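        """Split call-time kwargs into the three dicts the base ``Pipeline``
        expects: preprocess, forward, and postprocess parameters.
        """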
        preprocess_kwargs = {}
        if "skip_evidences" in kwargs or self.skip_evidences:
            preprocess_kwargs["skip_evidences"] = kwargs.get(
                "skip_evidences", self.skip_evidences
            )
        return preprocess_kwargs, {}, {}

    def preprocess(
        self,
        inputs: tuple[
            orbitals.scope_guard.modeling.ScopeGuardInput,
            str | orbitals.types.AIServiceDescription,
        ],
        skip_evidences: bool = False,
    ):
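        """Render one (conversation, AI service description) pair into the
        chat-templated prompt string the model expects.
        """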
        conversation, ai_service_description = inputs

        model_messages = orbitals.scope_guard.prompting.prepare_messages(
            conversation,
            ai_service_description,
            skip_evidences,
        )

        # Tokenization is deferred to _forward; enable_thinking=False disables
        # the optional reasoning block in chat templates that support the flag.
        text = self.tokenizer.apply_chat_template(
            model_messages,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=False,
        )

        return {"text": text}

    def _forward(self, model_inputs):
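        """Tokenize the rendered prompt and run ``generate``."""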
        tokenized = self.tokenizer(
            model_inputs["text"],
            return_tensors="pt",
            padding=True,
            truncation=True,
        ).to(self.device)

        with torch.inference_mode():
            outputs = self.model.generate(
                **tokenized,
                max_new_tokens=self.max_new_tokens,
                do_sample=self.do_sample,
            )

        # Keep the prompt ids so postprocess can strip them from the output.
        return {
            "output_ids": outputs,
            "input_ids": tokenized["input_ids"],
        }

    def postprocess(self, model_outputs):
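        """Decode only the newly generated tokens for each row in the batch."""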
        output_ids = model_outputs["output_ids"]
        input_ids = model_outputs["input_ids"]

        results = []
        for i in range(output_ids.shape[0]):
            # generate() returns prompt + completion; slice off the prompt.
            generated_ids = output_ids[i][input_ids.shape[1] :]
            generated_output = self.tokenizer.decode(
                generated_ids,
                skip_special_tokens=True,
            )
            results.append({"generated_text": generated_output})

        return results
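

# A minimal usage sketch, assuming an installed `orbitals` package and an
# available ScopeGuard checkpoint. The model id and the ScopeGuardInput
# construction below are illustrative placeholders, not values from this file:
#
#     pipe = ScopeGuardPipeline("orbitals/scope-guard")  # hypothetical model id
#     conversation = orbitals.scope_guard.modeling.ScopeGuardInput(...)
#     service = "A customer-support assistant for an online bookstore."
#     output = pipe((conversation, service))
#     print(output[0]["generated_text"])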