samkeet committed on
Commit
bec6f13
·
verified ·
1 Parent(s): 1c9c1e1

Upload GPT124MTextGenerationPipeline

Browse files
Files changed (2) hide show
  1. config.json +15 -0
  2. pipeline_gpt.py +59 -0
config.json CHANGED
@@ -7,6 +7,21 @@
7
  "AutoModelForCausalLM": "modeling_gpt.GPTModelForTextGeneration"
8
  },
9
  "block_size": 1024,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  "model_type": "custom_gpt",
11
  "n_embd": 768,
12
  "n_head": 12,
 
7
  "AutoModelForCausalLM": "modeling_gpt.GPTModelForTextGeneration"
8
  },
9
  "block_size": 1024,
10
+ "custom_pipelines": {
11
+ "text-generation": {
12
+ "default": {
13
+ "model": {
14
+ "pt": "samkeet/GPT_124M-Instruct"
15
+ }
16
+ },
17
+ "impl": "pipeline_gpt.GPT124MTextGenerationPipeline",
18
+ "pt": [
19
+ "AutoModelForCausalLM"
20
+ ],
21
+ "tf": [],
22
+ "type": "text"
23
+ }
24
+ },
25
  "model_type": "custom_gpt",
26
  "n_embd": 768,
27
  "n_head": 12,
pipeline_gpt.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Importing necessary libraries
2
+ import torch
3
+ from transformers import Pipeline
4
+ from .modeling_gpt import GPTModelForTextGeneration
5
+
6
+
7
class GPT124MTextGenerationPipeline(Pipeline):
    """Custom text-generation pipeline for the GPT-124M model.

    Wires prompt encoding, the model's ``generate`` call, and token
    decoding into the standard ``transformers.Pipeline``
    preprocess / _forward / postprocess flow.
    """

    def _sanitize_parameters(self, **kwargs):
        """Split caller kwargs into per-stage dictionaries.

        Returns:
            Tuple of (preprocess_kwargs, forward_kwargs, postprocess_kwargs).
            All generation settings go to the forward step; the preprocess
            and postprocess stages currently take no options.
        """
        preprocess_kwargs = {}
        forward_kwargs = {
            "max_length": kwargs.get("max_length", 50),
            "do_sample": kwargs.get("do_sample", True),
            "top_k": kwargs.get("top_k", 50),
            "top_p": kwargs.get("top_p", 0.95),
            "temperature": kwargs.get("temperature", 0.9),
            # NOTE(review): forwarded verbatim into generate() by _forward;
            # assumes the custom GPTModelForTextGeneration.generate accepts
            # a `device` keyword — confirm against modeling_gpt.
            "device": kwargs.get("device", self.device.type),
        }
        postprocess_kwargs = {}

        return preprocess_kwargs, forward_kwargs, postprocess_kwargs

    def preprocess(self, prompt_text: str, **preprocess_kwargs):
        """Encode ``prompt_text`` into a batched LongTensor of token ids.

        Args:
            prompt_text: Non-empty prompt string to tokenize.

        Returns:
            Dict with key ``"input_ids"`` mapping to a [1, seq_len] tensor
            on the pipeline's device.

        Raises:
            ValueError: If ``prompt_text`` is not a non-empty string.
        """
        # Explicit raise instead of `assert`: asserts are stripped when
        # Python runs with -O, silently disabling input validation.
        if not isinstance(prompt_text, str) or not prompt_text:
            raise ValueError("prompt_text must be a non-empty string")

        # Encode the input text to a list of token ids.
        input_ids = self.tokenizer.encode(prompt_text)

        # Add a batch dimension of 1 and move the tensor to the pipeline's
        # device so it matches the model when running on GPU.
        input_tensor = torch.tensor([input_ids], dtype=torch.long).to(self.device)

        return {"input_ids": input_tensor}

    def _forward(self, model_inputs, **forward_kwargs):
        """Run the model's generate() with the sanitized generation settings."""
        return self.model.generate(**model_inputs, **forward_kwargs)

    def postprocess(self, model_output, **postprocess_kwargs):
        """Decode generated token ids back into human-readable text.

        NOTE(review): decodes ``model_output`` as-is; if the custom model's
        generate() returns a batched [1, seq_len] tensor this would need
        ``model_output[0]`` — confirm the return shape in modeling_gpt.
        """
        return self.tokenizer.decode(model_output)