Trouter-Library committed
Commit 33fc213 · verified · 1 Parent(s): e711b26

Create inference.py

Files changed (1):
  1. inference.py +132 -0
inference.py ADDED
@@ -0,0 +1,132 @@
"""
Inference script for the Kat-Gen1 model.
"""

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import Optional, List


class KatGen1Inference:
    def __init__(self, model_name: str = "Katisim/Kat-Gen1", device: Optional[str] = None):
        """
        Initialize the Kat-Gen1 model for inference.

        Args:
            model_name: HuggingFace model identifier
            device: Device to run inference on (cuda/cpu)
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
        ).to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

        # Many causal LMs ship without a pad token; reuse EOS so that
        # padded batches and generate() calls do not fail.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def generate(
        self,
        prompt: str,
        max_length: int = 100,
        temperature: float = 0.7,
        top_p: float = 0.9,
        top_k: int = 50,
        num_return_sequences: int = 1,
        do_sample: bool = True
    ) -> List[str]:
        """
        Generate text from a prompt.

        Args:
            prompt: Input text prompt
            max_length: Maximum total length (prompt plus generated tokens)
            temperature: Sampling temperature
            top_p: Nucleus sampling parameter
            top_k: Top-k sampling parameter
            num_return_sequences: Number of sequences to generate
            do_sample: Whether to use sampling or greedy decoding

        Returns:
            List of generated text strings
        """
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_length=max_length,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                num_return_sequences=num_return_sequences,
                do_sample=do_sample,
                pad_token_id=self.tokenizer.pad_token_id
            )

        generated_texts = [
            self.tokenizer.decode(output, skip_special_tokens=True)
            for output in outputs
        ]

        return generated_texts

    def batch_generate(
        self,
        prompts: List[str],
        max_length: int = 100,
        **kwargs
    ) -> List[str]:
        """
        Generate text for multiple prompts in batch.

        Args:
            prompts: List of input prompts
            max_length: Maximum total length (prompt plus generated tokens)
            **kwargs: Additional generation parameters

        Returns:
            List of generated text strings
        """
        # Decoder-only models must be left-padded for batched generation;
        # with the default right padding the model would continue from pad
        # tokens instead of the end of each prompt.
        self.tokenizer.padding_side = "left"
        inputs = self.tokenizer(
            prompts,
            return_tensors="pt",
            padding=True,
            truncation=True
        ).to(self.device)

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_length=max_length,
                pad_token_id=self.tokenizer.pad_token_id,
                **kwargs
            )

        generated_texts = [
            self.tokenizer.decode(output, skip_special_tokens=True)
            for output in outputs
        ]

        return generated_texts


def main():
    """Example usage of the inference script."""
    model = KatGen1Inference()

    prompt = "Once upon a time in a distant land,"
    generated = model.generate(
        prompt,
        max_length=150,
        temperature=0.8,
        num_return_sequences=1
    )

    print("Generated text:")
    print(generated[0])


if __name__ == "__main__":
    main()
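
main() only exercises the single-prompt path. The sketch below shows a batched call as well; it assumes inference.py is importable from the working directory and that the Katisim/Kat-Gen1 checkpoint can be downloaded. The prompts are illustrative placeholders, not part of the committed file.

# Hypothetical usage sketch for batch_generate (not part of this commit).
# Assumes inference.py is on the import path and the model is available.
from inference import KatGen1Inference

model = KatGen1Inference()
prompts = [
    "The quick brown fox",
    "In the beginning,",
]
# Greedy decoding (do_sample=False) so the two outputs are repeatable.
results = model.batch_generate(prompts, max_length=80, do_sample=False)
for prompt, text in zip(prompts, results):
    print(f"--- {prompt!r} ---")
    print(text)

Because batch_generate forwards **kwargs straight to model.generate, any sampling parameter accepted by generate (temperature, top_p, num_return_sequences, and so on) can be passed through the same way.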