"""
Inference script for Kat-Gen1 model
"""
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import Optional, List


class KatGen1Inference:
    def __init__(self, model_name: str = "Katisim/Kat-Gen1", device: Optional[str] = None):
        """
        Initialize the Kat-Gen1 model for inference.

        Args:
            model_name: HuggingFace model identifier
            device: Device to run inference on (cuda/cpu)
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
        ).to(self.device)
        self.model.eval()  # disable dropout; this script only runs inference
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        # Decoder-only models must be left-padded for batched generation;
        # right padding would insert pad tokens between a prompt and its
        # continuation and corrupt the outputs of batch_generate().
        self.tokenizer.padding_side = "left"

    def generate(
        self,
        prompt: str,
        max_length: int = 100,
        temperature: float = 0.7,
        top_p: float = 0.9,
        top_k: int = 50,
        num_return_sequences: int = 1,
        do_sample: bool = True
    ) -> List[str]:
        """
        Generate text from a prompt.

        Args:
            prompt: Input text prompt
            max_length: Maximum total sequence length (prompt + generated tokens)
            temperature: Sampling temperature
            top_p: Nucleus sampling parameter
            top_k: Top-k sampling parameter
            num_return_sequences: Number of sequences to generate
            do_sample: Whether to use sampling or greedy decoding

        Returns:
            List of generated text strings (each includes the prompt)
        """
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_length=max_length,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                num_return_sequences=num_return_sequences,
                do_sample=do_sample,
                pad_token_id=self.tokenizer.pad_token_id
            )
        return [
            self.tokenizer.decode(output, skip_special_tokens=True)
            for output in outputs
        ]
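
    # A minimal streaming variant, sketched here as an optional extra;
    # stream_generate is a name introduced for illustration, not part of the
    # original script. It uses transformers' TextIteratorStreamer, which
    # requires generate() to run in a background thread while decoded chunks
    # are consumed from the streamer in the foreground.
    def stream_generate(self, prompt: str, max_new_tokens: int = 100, **kwargs):
        """Yield generated text chunks as they are produced."""
        from threading import Thread
        from transformers import TextIteratorStreamer

        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        streamer = TextIteratorStreamer(
            self.tokenizer, skip_prompt=True, skip_special_tokens=True
        )
        generation_kwargs = dict(
            **inputs,
            streamer=streamer,
            max_new_tokens=max_new_tokens,
            pad_token_id=self.tokenizer.pad_token_id,
            **kwargs,
        )
        thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
        thread.start()
        for chunk in streamer:  # blocks until the next decoded chunk is ready
            yield chunk
        thread.join()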

    def batch_generate(
        self,
        prompts: List[str],
        max_length: int = 100,
        **kwargs
    ) -> List[str]:
        """
        Generate text for multiple prompts in a single batch.

        Args:
            prompts: List of input prompts
            max_length: Maximum total sequence length (prompt + generated tokens)
            **kwargs: Additional generation parameters (temperature, top_p, ...)

        Returns:
            List of generated text strings, one per prompt
        """
        # Left padding (set in __init__) keeps each prompt flush against its
        # generated continuation; skip_special_tokens drops the pad tokens
        # when decoding.
        inputs = self.tokenizer(
            prompts,
            return_tensors="pt",
            padding=True,
            truncation=True
        ).to(self.device)
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_length=max_length,
                pad_token_id=self.tokenizer.pad_token_id,
                **kwargs
            )
        return [
            self.tokenizer.decode(output, skip_special_tokens=True)
            for output in outputs
        ]
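
    # A hedged chat-style wrapper: it assumes the Kat-Gen1 tokenizer ships a
    # chat template, which is not verified here. The chat() name and the
    # messages format (a list of {"role": ..., "content": ...} dicts) are
    # illustrative; if the tokenizer has no template, apply_chat_template
    # raises and callers should fall back to plain generate().
    def chat(self, messages: List[dict], **kwargs) -> str:
        """Generate a single reply for a list of chat messages."""
        prompt = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        return self.generate(prompt, **kwargs)[0]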


def main():
    """Example usage of the inference script."""
    model = KatGen1Inference()
    prompt = "Once upon a time in a distant land,"
    generated = model.generate(
        prompt,
        max_length=150,
        temperature=0.8,
        num_return_sequences=1
    )
    print("Generated text:")
    print(generated[0])
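
    # Batch usage example, added for illustration: it exercises the
    # batch_generate method defined above with made-up prompts. do_sample=True
    # is passed explicitly so the temperature setting actually takes effect.
    prompts = [
        "The secret to a good story is",
        "In the year 2150, humanity finally",
    ]
    batch_results = model.batch_generate(
        prompts, max_length=80, temperature=0.8, do_sample=True
    )
    for i, text in enumerate(batch_results):
        print(f"\nBatch result {i + 1}:")
        print(text)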

if __name__ == "__main__":
    main()