#!/usr/bin/env python3
"""
Inference script for Vortex models.
Supports both CUDA and MPS backends.
"""
import argparse
from pathlib import Path

import torch

from configs.vortex_7b_config import VORTEX_7B_CONFIG
from configs.vortex_13b_config import VORTEX_13B_CONFIG
from models.vortex_model import VortexModel
from inference.cuda_optimize import optimize_for_cuda, profile_model
from inference.mps_optimize import optimize_for_mps, profile_model_mps


def parse_args():
    parser = argparse.ArgumentParser(description="Run inference with Vortex model")
    parser.add_argument("--model_path", type=str, required=True,
                        help="Path to trained model checkpoint")
    parser.add_argument("--config", type=str, default=None,
                        help="Path to model config (if not in checkpoint)")
    parser.add_argument("--tokenizer_path", type=str, default=None,
                        help="Path to tokenizer")
    parser.add_argument("--model_size", type=str, choices=["7b", "13b"], default="7b",
                        help="Model size for config")
    parser.add_argument("--device", type=str, default="cuda",
                        choices=["cuda", "mps", "cpu"], help="Device to run on")
    parser.add_argument("--use_mps", action="store_true",
                        help="Use MPS backend (Apple Silicon)")
    parser.add_argument("--quantization", type=str, choices=["int8", "int4"],
                        default=None, help="Apply quantization (CUDA only)")
    parser.add_argument("--flash_attention", action="store_true",
                        help="Use Flash Attention 2 (CUDA only)")
    parser.add_argument("--torch_compile", action="store_true",
                        help="Use torch.compile")
    parser.add_argument("--prompt", type=str, default=None,
                        help="Input prompt for generation")
    parser.add_argument("--interactive", action="store_true",
                        help="Run in interactive mode")
    parser.add_argument("--max_new_tokens", type=int, default=100,
                        help="Maximum new tokens to generate")
    parser.add_argument("--temperature", type=float, default=0.8,
                        help="Sampling temperature")
    parser.add_argument("--top_p", type=float, default=0.9,
                        help="Top-p sampling")
    parser.add_argument("--profile", action="store_true",
                        help="Profile performance")
    return parser.parse_args()


def load_model(args):
    """Load model with appropriate optimizations."""
    from configuration_vortex import VortexConfig

    # Load config
    if args.config:
        config = VortexConfig.from_pretrained(args.config)
    else:
        # Use the default config for the requested size
        config_dict = VORTEX_7B_CONFIG if args.model_size == "7b" else VORTEX_13B_CONFIG
        config = VortexConfig(**config_dict)

    # Create model
    print("Creating model...")
    model = VortexModel(config.to_dict())

    # Load checkpoint
    print(f"Loading checkpoint from {args.model_path}")
    checkpoint = torch.load(args.model_path, map_location="cpu", weights_only=False)
    if "model_state_dict" in checkpoint:
        model.load_state_dict(checkpoint["model_state_dict"])
    else:
        model.load_state_dict(checkpoint)
    print("Model loaded")

    # Apply optimizations (--use_mps implies the MPS device)
    use_mps = args.use_mps or args.device == "mps"
    device = torch.device("mps" if use_mps else args.device)
    if use_mps:
        print("Optimizing for MPS...")
        model = optimize_for_mps(model, config.to_dict(), use_sdpa=True)
    elif args.device == "cuda":
        print("Optimizing for CUDA...")
        model = optimize_for_cuda(
            model,
            config.to_dict(),
            use_flash_attention=args.flash_attention,
            use_torch_compile=args.torch_compile,
            quantization=args.quantization,
        )

    model = model.to(device)
    model.eval()
    return model, config


def load_tokenizer(args):
    """Load tokenizer."""
    tokenizer_path = args.tokenizer_path
    if not tokenizer_path:
        # Try to find the tokenizer next to the checkpoint
        model_dir = Path(args.model_path).parent
        tokenizer_path = model_dir / "vortex_tokenizer.json"

    if tokenizer_path and Path(tokenizer_path).exists():
        from tokenization_vortex import VortexTokenizer
        tokenizer = VortexTokenizer.from_pretrained(str(Path(tokenizer_path).parent))
    else:
        print("Warning: No tokenizer found, using dummy tokenizer")

        class DummyTokenizer:
            # Placeholder IDs so generation code that reads these attributes
            # does not crash; real values come from the trained tokenizer.
            pad_token_id = 0
            eos_token_id = 2

            def __call__(self, text, **kwargs):
                return {"input_ids": torch.tensor([[1, 2, 3]])}

            def decode(self, ids, **kwargs):
                return "dummy"

        tokenizer = DummyTokenizer()
    return tokenizer


def generate_text(model, tokenizer, prompt, args):
    """Generate text from a prompt."""
    # Tokenize, leaving room for the new tokens within the context window
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        padding=False,
        truncation=True,
        max_length=model.config.max_seq_len - args.max_new_tokens,
    )
    input_ids = inputs["input_ids"].to(next(model.parameters()).device)

    # Generate
    with torch.no_grad():
        if hasattr(model, "generate"):
            output_ids = model.generate(
                input_ids,
                max_new_tokens=args.max_new_tokens,
                temperature=args.temperature,
                top_p=args.top_p,
                do_sample=True,
                pad_token_id=tokenizer.pad_token_id,
            )
        else:
            # Manual fallback: temperature-scaled multinomial sampling, one
            # token at a time (top_p is not applied on this path)
            for _ in range(args.max_new_tokens):
                outputs = model(input_ids)
                next_token_logits = outputs["logits"][:, -1, :]
                next_token = torch.multinomial(
                    torch.softmax(next_token_logits / args.temperature, dim=-1),
                    num_samples=1,
                )
                input_ids = torch.cat([input_ids, next_token], dim=-1)
                # Stop at end-of-sequence (batch size 1 assumed here)
                if next_token.item() == tokenizer.eos_token_id:
                    break
            # The manual loop accumulates the full sequence in input_ids
            output_ids = input_ids

    # Decode
    generated = tokenizer.decode(output_ids[0].tolist(), skip_special_tokens=True)
    return generated


def main():
    args = parse_args()

    # Load model and tokenizer
    model, config = load_model(args)
    tokenizer = load_tokenizer(args)
    print(f"Model loaded on {next(model.parameters()).device}")
    print(f"Model parameters: {model.get_num_params():,}")

    # Profile if requested
    if args.profile:
        print("Profiling...")
        dummy_input = torch.randint(0, config.vocab_size, (1, 128)).to(
            next(model.parameters()).device
        )
        if args.use_mps or args.device == "mps":
            stats = profile_model_mps(model, dummy_input)
        else:
            stats = profile_model(model, dummy_input)
        print("Profile results:")
        for k, v in stats.items():
            print(f"  {k}: {v:.4f}")
        return

    # Interactive mode
    if args.interactive:
        print("Interactive mode. Type 'quit' to exit.")
        while True:
            prompt = input("\nPrompt: ")
            if prompt.lower() == "quit":
                break
            response = generate_text(model, tokenizer, prompt, args)
            print(f"\nResponse: {response}")
    elif args.prompt:
        response = generate_text(model, tokenizer, args.prompt, args)
        print(f"Response: {response}")
    else:
        print("No prompt provided. Use --prompt or --interactive.")


if __name__ == "__main__":
    main()
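
# ---------------------------------------------------------------------------
# Example invocations (a sketch: the script name run_inference.py and the
# checkpoint paths below are hypothetical; substitute your own):
#
#   Single prompt on CUDA with Flash Attention 2 and int8 quantization:
#     python run_inference.py --model_path checkpoints/vortex_7b.pt \
#         --flash_attention --quantization int8 \
#         --prompt "Summarize the Vortex architecture."
#
#   Interactive session on Apple Silicon (MPS):
#     python run_inference.py --model_path checkpoints/vortex_7b.pt \
#         --device mps --use_mps --interactive
#
#   Latency profile of the 13B config:
#     python run_inference.py --model_path checkpoints/vortex_13b.pt \
#         --model_size 13b --profile
# ---------------------------------------------------------------------------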