Spaces:
Sleeping
Sleeping
| """ | |
| Example script for running inference with the Rose Beeper model. | |
| """ | |
| import torch | |
| from tokenizers import Tokenizer | |
| from huggingface_hub import hf_hub_download | |
| import os | |
| # Import the inference components (from the previous artifact) | |
| from beeper_inference import ( | |
| BeeperRoseGPT, | |
| BeeperIO, | |
| generate, | |
| get_default_config | |
| ) | |
| class BeeperInference: | |
| """Wrapper class for easy inference with the Rose Beeper model.""" | |
| def __init__(self, | |
| checkpoint_path: str = None, | |
| tokenizer_path: str = "beeper.tokenizer.json", | |
| device: str = None, | |
| hf_repo: str = "AbstractPhil/beeper-rose-v5"): | |
| """ | |
| Initialize the Beeper model for inference. | |
| Args: | |
| checkpoint_path: Path to local checkpoint file (.pt or .safetensors) | |
| tokenizer_path: Path to tokenizer file | |
| device: Device to run on ('cuda', 'cpu', or None for auto) | |
| hf_repo: HuggingFace repository to download from if no local checkpoint | |
| """ | |
| # Set device | |
| if device is None: | |
| self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| else: | |
| self.device = torch.device(device) | |
| print(f"Using device: {self.device}") | |
| # Load configuration | |
| self.config = get_default_config() | |
| # Initialize model | |
| self.model = BeeperRoseGPT(self.config).to(self.device) | |
| # Initialize pentachora banks | |
| cap_cfg = self.config.get("capoera", {}) | |
| # Using default sizes since we don't have the exact corpus info at inference | |
| self.model.ensure_pentachora( | |
| coarse_C=20, # Approximate number of datasets | |
| medium_C=int(cap_cfg.get("topic_bins", 512)), | |
| fine_C=int(cap_cfg.get("mood_bins", 7)), | |
| dim=self.config["dim"], | |
| device=self.device | |
| ) | |
| # Load weights | |
| self._load_weights(checkpoint_path, hf_repo) | |
| # Load tokenizer | |
| self._load_tokenizer(tokenizer_path, hf_repo) | |
| # Set to eval mode | |
| self.model.eval() | |
| def _load_weights(self, checkpoint_path: str, hf_repo: str): | |
| """Load model weights from local file or HuggingFace.""" | |
| loaded = False | |
| # Try local checkpoint first | |
| if checkpoint_path and os.path.exists(checkpoint_path): | |
| print(f"Loading weights from: {checkpoint_path}") | |
| missing, unexpected = BeeperIO.load_into_model( | |
| self.model, checkpoint_path, map_location=str(self.device), strict=False | |
| ) | |
| print(f"Loaded | missing={len(missing)} unexpected={len(unexpected)}") | |
| loaded = True | |
| # Try HuggingFace if no local checkpoint | |
| if not loaded and hf_repo: | |
| try: | |
| print(f"Downloading weights from HuggingFace: {hf_repo}") | |
| path = hf_hub_download(repo_id=hf_repo, filename="beeper_final.safetensors") | |
| missing, unexpected = BeeperIO.load_into_model( | |
| self.model, path, map_location=str(self.device), strict=False | |
| ) | |
| print(f"Loaded | missing={len(missing)} unexpected={len(unexpected)}") | |
| loaded = True | |
| except Exception as e: | |
| print(f"Failed to download from HuggingFace: {e}") | |
| if not loaded: | |
| print("WARNING: No weights loaded, using random initialization!") | |
| def _load_tokenizer(self, tokenizer_path: str, hf_repo: str): | |
| """Load tokenizer from local file or HuggingFace.""" | |
| if os.path.exists(tokenizer_path): | |
| print(f"Loading tokenizer from: {tokenizer_path}") | |
| self.tokenizer = Tokenizer.from_file(tokenizer_path) | |
| else: | |
| try: | |
| print(f"Downloading tokenizer from HuggingFace: {hf_repo}") | |
| path = hf_hub_download(repo_id=hf_repo, filename="tokenizer.json") | |
| self.tokenizer = Tokenizer.from_file(path) | |
| except Exception as e: | |
| raise RuntimeError(f"Failed to load tokenizer: {e}") | |
| def generate_text(self, | |
| prompt: str, | |
| max_new_tokens: int = 120, | |
| temperature: float = 0.9, | |
| top_k: int = 40, | |
| top_p: float = 0.9, | |
| repetition_penalty: float = 1.1, | |
| presence_penalty: float = 0.6, | |
| frequency_penalty: float = 0.0) -> str: | |
| """ | |
| Generate text from a prompt. | |
| Args: | |
| prompt: Input text to continue from | |
| max_new_tokens: Maximum tokens to generate | |
| temperature: Sampling temperature (0.1-2.0 typical) | |
| top_k: Top-k sampling (0 to disable) | |
| top_p: Nucleus sampling threshold (0.0-1.0) | |
| repetition_penalty: Penalty for repeated tokens | |
| presence_penalty: Penalty for tokens that have appeared | |
| frequency_penalty: Penalty based on token frequency | |
| Returns: | |
| Generated text string | |
| """ | |
| return generate( | |
| model=self.model, | |
| tok=self.tokenizer, | |
| cfg=self.config, | |
| prompt=prompt, | |
| max_new_tokens=max_new_tokens, | |
| temperature=temperature, | |
| top_k=top_k, | |
| top_p=top_p, | |
| repetition_penalty=repetition_penalty, | |
| presence_penalty=presence_penalty, | |
| frequency_penalty=frequency_penalty, | |
| device=self.device, | |
| detokenize=True | |
| ) | |
| def batch_generate(self, prompts: list, **kwargs) -> list: | |
| """Generate text for multiple prompts.""" | |
| results = [] | |
| for prompt in prompts: | |
| results.append(self.generate_text(prompt, **kwargs)) | |
| return results | |
| def main(): | |
| """Example usage of the Beeper inference class.""" | |
| # Initialize the model | |
| print("Initializing Rose Beeper model...") | |
| beeper = BeeperInference( | |
| checkpoint_path=None, # Will download from HF | |
| device=None # Auto-select GPU if available | |
| ) | |
| # Example prompts | |
| prompts = [ | |
| "The robot went to school and", | |
| "Once upon a time in a distant galaxy,", | |
| "The meaning of life is", | |
| "In the beginning, there was", | |
| "The scientist discovered that", | |
| ] | |
| print("\n" + "="*60) | |
| print("GENERATING SAMPLES") | |
| print("="*60 + "\n") | |
| for prompt in prompts: | |
| print(f"Prompt: {prompt}") | |
| print("-" * 40) | |
| # Generate with different settings | |
| # Standard generation | |
| output = beeper.generate_text( | |
| prompt=prompt, | |
| max_new_tokens=100, | |
| temperature=0.9, | |
| top_k=40, | |
| top_p=0.9 | |
| ) | |
| print(f"Output: {output}") | |
| print() | |
| # More creative generation | |
| creative_output = beeper.generate_text( | |
| prompt=prompt, | |
| max_new_tokens=50, | |
| temperature=1.2, | |
| top_k=50, | |
| top_p=0.95, | |
| repetition_penalty=1.2 | |
| ) | |
| print(f"Creative: {creative_output}") | |
| print("\n" + "="*60 + "\n") | |
| if __name__ == "__main__": | |
| main() |