"""
Inference script for Kat-Gen1 model
"""

from typing import List, Optional

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


class KatGen1Inference:
    """Thin wrapper around a Hugging Face causal LM + tokenizer for text generation."""

    def __init__(self, model_name: str = "Katisim/Kat-Gen1", device: Optional[str] = None):
        """
        Initialize the Kat-Gen1 model for inference.

        Args:
            model_name: HuggingFace model identifier
            device: Device to run inference on (e.g. "cuda", "cuda:0", "cpu").
                Defaults to "cuda" when available, otherwise "cpu".
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        # float16 on any CUDA device (including indexed ones like "cuda:1");
        # float32 everywhere else.
        use_half = torch.device(self.device).type == "cuda"
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16 if use_half else torch.float32,
        ).to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Some causal-LM tokenizers ship without a pad token; reuse EOS so that
        # batched padding and `pad_token_id` in generate() are defined.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        # A decoder-only model continues from the last position of each row, so
        # batched prompts must be left-padded; right padding would insert pad
        # tokens between a short prompt and its continuation.
        self.tokenizer.padding_side = "left"

    @staticmethod
    def _length_kwargs(max_length: int, max_new_tokens: Optional[int]) -> dict:
        """Return the length-limit kwarg for generate(): max_new_tokens wins when given."""
        if max_new_tokens is not None:
            return {"max_new_tokens": max_new_tokens}
        return {"max_length": max_length}

    def _decode(self, outputs) -> List[str]:
        """Decode a batch of generated token-id sequences, dropping special tokens."""
        return self.tokenizer.batch_decode(outputs, skip_special_tokens=True)

    def generate(
        self,
        prompt: str,
        max_length: int = 100,
        temperature: float = 0.7,
        top_p: float = 0.9,
        top_k: int = 50,
        num_return_sequences: int = 1,
        do_sample: bool = True,
        max_new_tokens: Optional[int] = None,
    ) -> List[str]:
        """
        Generate text from a prompt.

        Args:
            prompt: Input text prompt
            max_length: Maximum total length in tokens, prompt included.
                Ignored when ``max_new_tokens`` is given.
            temperature: Sampling temperature (used only when ``do_sample``)
            top_p: Nucleus sampling parameter (used only when ``do_sample``)
            top_k: Top-k sampling parameter (used only when ``do_sample``)
            num_return_sequences: Number of sequences to generate
            do_sample: Whether to use sampling or greedy decoding
            max_new_tokens: Optional cap on the number of newly generated
                tokens, independent of the prompt length.

        Returns:
            List of decoded sequences (one per returned sequence). For a
            causal LM each sequence includes the prompt text.
        """
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)

        gen_kwargs = self._length_kwargs(max_length, max_new_tokens)
        if do_sample:
            # Sampling knobs have no effect under greedy decoding and only
            # trigger transformers warnings there, so pass them only when sampling.
            gen_kwargs.update(temperature=temperature, top_p=top_p, top_k=top_k)

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                num_return_sequences=num_return_sequences,
                do_sample=do_sample,
                pad_token_id=self.tokenizer.pad_token_id,
                **gen_kwargs,
            )
        return self._decode(outputs)

    def batch_generate(
        self,
        prompts: List[str],
        max_length: int = 100,
        **kwargs,
    ) -> List[str]:
        """
        Generate text for multiple prompts in batch.

        Args:
            prompts: List of input prompts
            max_length: Maximum total length in tokens, prompt included.
                Ignored when ``max_new_tokens`` is passed in ``kwargs``.
            **kwargs: Additional generation parameters forwarded to
                ``model.generate``

        Returns:
            List of decoded sequences; empty when ``prompts`` is empty.
        """
        if not prompts:
            return []

        inputs = self.tokenizer(
            prompts,
            return_tensors="pt",
            padding=True,
            truncation=True,
        ).to(self.device)

        length_kwargs = self._length_kwargs(max_length, kwargs.pop("max_new_tokens", None))

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                pad_token_id=self.tokenizer.pad_token_id,
                **length_kwargs,
                **kwargs,
            )
        return self._decode(outputs)


def main():
    """Example usage of the inference script."""
    model = KatGen1Inference()

    prompt = "Once upon a time in a distant land,"
    generated = model.generate(
        prompt,
        max_length=150,
        temperature=0.8,
        num_return_sequences=1,
    )

    print("Generated text:")
    print(generated[0])


if __name__ == "__main__":
    main()