"""
Inference script for Kat-Gen1 model
"""
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import Optional, List


class KatGen1Inference:
    def __init__(self, model_name: str = "Katisim/Kat-Gen1", device: Optional[str] = None):
        """
        Initialize the Kat-Gen1 model for inference.

        Args:
            model_name: HuggingFace model identifier
            device: Device to run inference on (cuda/cpu)
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        # Half precision on GPU to reduce memory use; full precision on CPU.
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
        ).to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Many causal LM tokenizers ship without a pad token; reuse EOS so
        # padding and generation work out of the box.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def generate(
        self,
        prompt: str,
        max_length: int = 100,
        temperature: float = 0.7,
        top_p: float = 0.9,
        top_k: int = 50,
        num_return_sequences: int = 1,
        do_sample: bool = True,
    ) -> List[str]:
        """
        Generate text from a prompt.

        Args:
            prompt: Input text prompt
            max_length: Maximum total sequence length (prompt plus generated tokens)
            temperature: Sampling temperature
            top_p: Nucleus sampling parameter
            top_k: Top-k sampling parameter
            num_return_sequences: Number of sequences to generate
            do_sample: Whether to use sampling or greedy decoding

        Returns:
            List of generated text strings
        """
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_length=max_length,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                num_return_sequences=num_return_sequences,
                do_sample=do_sample,
                pad_token_id=self.tokenizer.pad_token_id,
            )
        generated_texts = [
            self.tokenizer.decode(output, skip_special_tokens=True)
            for output in outputs
        ]
        return generated_texts

    def batch_generate(
        self,
        prompts: List[str],
        max_length: int = 100,
        **kwargs
    ) -> List[str]:
        """
        Generate text for multiple prompts in batch.

        Args:
            prompts: List of input prompts
            max_length: Maximum total sequence length (prompt plus generated tokens)
            **kwargs: Additional generation parameters

        Returns:
            List of generated text strings
        """
        # Decoder-only models must be padded on the left for batched
        # generation; otherwise new tokens are appended after pad tokens.
        self.tokenizer.padding_side = "left"
        inputs = self.tokenizer(
            prompts,
            return_tensors="pt",
            padding=True,
            truncation=True,
        ).to(self.device)
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_length=max_length,
                pad_token_id=self.tokenizer.pad_token_id,
                **kwargs
            )
        generated_texts = [
            self.tokenizer.decode(output, skip_special_tokens=True)
            for output in outputs
        ]
        return generated_texts


def main():
    """Example usage of the inference script."""
    model = KatGen1Inference()
    prompt = "Once upon a time in a distant land,"
    generated = model.generate(
        prompt,
        max_length=150,
        temperature=0.8,
        num_return_sequences=1,
    )
    print("Generated text:")
    print(generated[0])
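
    # Minimal batch_generate sketch: the prompts and generation settings
    # below are illustrative assumptions, not tuned values.
    prompts = [
        "The stars above the desert",
        "She opened the ancient book and",
    ]
    batch_results = model.batch_generate(
        prompts,
        max_length=80,
        do_sample=True,
        temperature=0.8,
    )
    for i, text in enumerate(batch_results):
        print(f"\n--- Batch result {i + 1} ---")
        print(text)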


if __name__ == "__main__":
    main()