| """ | |
| Inference script for Kat-Gen1 model | |
| """ | |
| import torch | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| from typing import Optional, List | |


class KatGen1Inference:
    def __init__(self, model_name: str = "Katisim/Kat-Gen1", device: Optional[str] = None):
        """
        Initialize the Kat-Gen1 model for inference.

        Args:
            model_name: Hugging Face model identifier.
            device: Device to run inference on ("cuda" or "cpu"); auto-detected if None.
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        # Half precision on GPU halves memory use; CPUs do float32 inference faster.
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
        ).to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Many causal-LM tokenizers ship without a pad token; reuse EOS so that
        # padding and generation do not error out.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def generate(
        self,
        prompt: str,
        max_length: int = 100,
        temperature: float = 0.7,
        top_p: float = 0.9,
        top_k: int = 50,
        num_return_sequences: int = 1,
        do_sample: bool = True,
    ) -> List[str]:
        """
        Generate text from a prompt.

        Args:
            prompt: Input text prompt.
            max_length: Maximum total sequence length (prompt plus generated tokens).
            temperature: Sampling temperature; higher values give more varied output.
            top_p: Nucleus sampling threshold.
            top_k: Top-k sampling cutoff.
            num_return_sequences: Number of sequences to generate.
            do_sample: If True, sample; if False, use greedy decoding.

        Returns:
            List of generated text strings (each includes the prompt).
        """
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_length=max_length,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                num_return_sequences=num_return_sequences,
                do_sample=do_sample,
                pad_token_id=self.tokenizer.pad_token_id,
            )
        return [
            self.tokenizer.decode(output, skip_special_tokens=True)
            for output in outputs
        ]

    def batch_generate(
        self,
        prompts: List[str],
        max_length: int = 100,
        **kwargs,
    ) -> List[str]:
        """
        Generate text for multiple prompts in one batch.

        Args:
            prompts: List of input prompts.
            max_length: Maximum total sequence length (prompt plus generated tokens).
            **kwargs: Additional arguments forwarded to model.generate().

        Returns:
            List of generated text strings.
        """
        # Decoder-only models must be left-padded for batched generation;
        # with right padding, continuations would be appended after pad tokens.
        self.tokenizer.padding_side = "left"
        inputs = self.tokenizer(
            prompts,
            return_tensors="pt",
            padding=True,
            truncation=True,
        ).to(self.device)
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_length=max_length,
                pad_token_id=self.tokenizer.pad_token_id,
                **kwargs,
            )
        return [
            self.tokenizer.decode(output, skip_special_tokens=True)
            for output in outputs
        ]


def main():
    """Example usage of the inference script."""
    model = KatGen1Inference()
    prompt = "Once upon a time in a distant land,"
    generated = model.generate(
        prompt,
        max_length=150,
        temperature=0.8,
        num_return_sequences=1,
    )
    print("Generated text:")
    print(generated[0])
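
    # Illustrative batch usage (a minimal sketch; these prompts are
    # placeholders, not part of the original script): batch_generate pads
    # the prompts together and decodes one completion per prompt.
    batch_prompts = [
        "The quick brown fox",
        "In the beginning there was",
    ]
    batch_generated = model.batch_generate(batch_prompts, max_length=80)
    for text in batch_generated:
        print("-" * 40)
        print(text)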


if __name__ == "__main__":
    main()