|
|
| import torch
|
| from transformers import Pipeline
|
|
|
|
|
class GPT124MTextGenerationPipeline(Pipeline):
    """Text-generation pipeline for a GPT-2 124M model.

    Wires prompt encoding, ``model.generate`` and decoding into the
    Hugging Face ``Pipeline`` interface (``_sanitize_parameters`` /
    ``preprocess`` / ``_forward`` / ``postprocess``).
    """

    def _sanitize_parameters(self, **kwargs):
        """Split caller kwargs into the three pipeline-stage dicts.

        Returns:
            tuple[dict, dict, dict]: kwargs for ``preprocess``,
            ``_forward`` and ``postprocess``, in that order.
        """
        # "device" is a tensor-placement concern, so it is routed to
        # preprocessing.  model.generate() has no `device` parameter:
        # forwarding it (as the old code did) raised TypeError.
        preprocess_kwargs = {"device": kwargs.get("device", "cpu")}
        forward_kwargs = {
            "max_length": kwargs.get("max_length", 50),
            "do_sample": kwargs.get("do_sample", True),
            "top_k": kwargs.get("top_k", 50),
            "top_p": kwargs.get("top_p", 0.95),
            "temperature": kwargs.get("temperature", 0.9),
        }
        postprocess_kwargs = {}

        return preprocess_kwargs, forward_kwargs, postprocess_kwargs

    def preprocess(self, prompt_text: str, device: str = "cpu", **preprocess_kwargs):
        """Encode *prompt_text* into a ``(1, seq_len)`` LongTensor on *device*.

        Args:
            prompt_text: non-empty prompt to tokenize.
            device: torch device for the input tensor (default ``"cpu"``).

        Returns:
            dict with key ``"input_ids"`` mapping to the batched tensor.

        Raises:
            ValueError: if *prompt_text* is not a non-empty string.
        """
        # Raise instead of assert: asserts are stripped under `python -O`
        # and must not be relied on for input validation.
        if not isinstance(prompt_text, str) or not prompt_text:
            raise ValueError("prompt_text must be a non-empty string")

        input_ids = self.tokenizer.encode(prompt_text)

        # Wrap in a batch dimension of 1 so the model sees (1, seq_len),
        # and honor the requested device (the old code ignored it).
        input_tensor = torch.tensor([input_ids], dtype=torch.long, device=device)

        return {"input_ids": input_tensor}

    def _forward(self, model_inputs, **forward_kwargs):
        """Run generation; returns a batched ``(batch, seq_len)`` id tensor."""
        return self.model.generate(**model_inputs, **forward_kwargs)

    def postprocess(self, model_output, **postprocess_kwargs):
        """Decode the generated token ids back into human-readable text.

        ``generate`` returns a batched tensor; ``tokenizer.decode`` expects
        a flat sequence of ids, so take row 0 (the single batch entry)
        before decoding.
        """
        return self.tokenizer.decode(model_output[0])
|
|
|