"""Inference script for Vortex models.

Supports both CUDA and MPS backends.
"""
|
| |
|
| | import argparse
|
| | import sys
|
| | from pathlib import Path
|
| |
|
| | import torch
|
| |
|
| | from configs.vortex_7b_config import VORTEX_7B_CONFIG
|
| | from configs.vortex_13b_config import VORTEX_13B_CONFIG
|
| |
|
| | from models.vortex_model import VortexModel
|
| | from tokenizer.vortex_tokenizer import VortexScienceTokenizer
|
| | from inference.cuda_optimize import optimize_for_cuda, profile_model
|
| | from inference.mps_optimize import optimize_for_mps, profile_model_mps
|
| |
|
| |
|
def parse_args():
    """Build and parse the command-line interface for Vortex inference."""
    parser = argparse.ArgumentParser(description="Run inference with Vortex model")
    add = parser.add_argument

    # Model / checkpoint locations
    add("--model_path", type=str, required=True,
        help="Path to trained model checkpoint")
    add("--config", type=str, default=None,
        help="Path to model config (if not in checkpoint)")
    add("--tokenizer_path", type=str, default=None,
        help="Path to tokenizer")
    add("--model_size", type=str, choices=["7b", "13b"], default="7b",
        help="Model size for config")

    # Backend selection
    add("--device", type=str, default="cuda", choices=["cuda", "mps", "cpu"],
        help="Device to run on")
    add("--use_mps", action="store_true",
        help="Use MPS backend (Apple Silicon)")

    # Optional optimizations
    add("--quantization", type=str, choices=[None, "int8", "int4"], default=None,
        help="Apply quantization (CUDA only)")
    add("--flash_attention", action="store_true",
        help="Use Flash Attention 2 (CUDA only)")
    add("--torch_compile", action="store_true",
        help="Use torch.compile")

    # Generation settings
    add("--prompt", type=str, default=None,
        help="Input prompt for generation")
    add("--interactive", action="store_true",
        help="Run in interactive mode")
    add("--max_new_tokens", type=int, default=100,
        help="Maximum new tokens to generate")
    add("--temperature", type=float, default=0.8,
        help="Sampling temperature")
    add("--top_p", type=float, default=0.9,
        help="Top-p sampling")
    add("--profile", action="store_true",
        help="Profile performance")

    return parser.parse_args()
|
| |
|
| |
|
def load_model(args):
    """Load the Vortex model, apply backend-specific optimizations, and move it to the target device.

    Args:
        args: Parsed CLI namespace (uses model_path, config, model_size,
            device, use_mps, flash_attention, torch_compile, quantization).

    Returns:
        (model, config): the optimized model in eval mode, and its config.
    """
    # Resolve the config: an explicit config file wins over the size presets.
    from configuration_vortex import VortexConfig
    if args.config:
        config = VortexConfig.from_pretrained(args.config)
    else:
        config_dict = VORTEX_7B_CONFIG if args.model_size == "7b" else VORTEX_13B_CONFIG
        config = VortexConfig(**config_dict)

    print("Creating model...")
    model = VortexModel(config.to_dict())

    print(f"Loading checkpoint from {args.model_path}")
    # NOTE(review): weights_only=False deserializes arbitrary pickled objects
    # and can execute code on load -- only load trusted checkpoints.
    checkpoint = torch.load(args.model_path, map_location="cpu", weights_only=False)
    if "model_state_dict" in checkpoint:
        model.load_state_dict(checkpoint["model_state_dict"])
    else:
        model.load_state_dict(checkpoint)
    print("Model loaded")

    # Fix: --use_mps now implies the MPS device. Previously the model could be
    # optimized for MPS and then moved to the default CUDA device.
    device_name = "mps" if (args.use_mps or args.device == "mps") else args.device
    device = torch.device(device_name)

    if device_name == "mps":
        print("Optimizing for MPS...")
        model = optimize_for_mps(model, config.to_dict(), use_sdpa=True)
    elif device_name == "cuda":
        print("Optimizing for CUDA...")
        model = optimize_for_cuda(
            model,
            config.to_dict(),
            use_flash_attention=args.flash_attention,
            use_torch_compile=args.torch_compile,
            quantization=args.quantization,
        )
    # Fix: plain CPU runs previously fell into the CUDA-optimization branch;
    # they now skip backend-specific optimization entirely.

    model = model.to(device)
    model.eval()

    return model, config
|
| |
|
| |
|
def load_tokenizer(args):
    """Load the Vortex tokenizer, falling back to a dummy tokenizer when none exists.

    Looks at --tokenizer_path first; otherwise searches for
    ``vortex_tokenizer.json`` next to the model checkpoint.

    Args:
        args: Parsed CLI namespace (uses tokenizer_path, model_path).

    Returns:
        A tokenizer exposing __call__, decode, pad_token_id and eos_token_id.
    """
    tokenizer_path = args.tokenizer_path
    if tokenizer_path:
        tokenizer_dir = Path(tokenizer_path).parent
    else:
        # No explicit path given: look next to the checkpoint.
        tokenizer_dir = Path(args.model_path).parent
        tokenizer_path = tokenizer_dir / "vortex_tokenizer.json"

    if tokenizer_path and Path(tokenizer_path).exists():
        from tokenization_vortex import VortexTokenizer
        # Fix: previously referenced `model_dir`, which was undefined (NameError)
        # whenever --tokenizer_path was passed explicitly.
        tokenizer = VortexTokenizer.from_pretrained(str(tokenizer_dir))
    else:
        print("Warning: No tokenizer found, using dummy tokenizer")

        class DummyTokenizer:
            """Minimal stand-in so downstream generation code still runs."""

            # Fix: generate_text reads these attributes; the original dummy
            # lacked them and raised AttributeError. eos is a sentinel that a
            # sampled token id will never match, so generation runs to length.
            pad_token_id = 0
            eos_token_id = -1

            def __call__(self, text, **kwargs):
                return {"input_ids": torch.tensor([[1, 2, 3]])}

            def decode(self, ids, **kwargs):
                return "dummy"

        tokenizer = DummyTokenizer()

    return tokenizer
|
| |
|
| |
|
def generate_text(model, tokenizer, prompt, args):
    """Generate a text continuation of ``prompt``.

    Uses ``model.generate()`` when the model provides it; otherwise falls back
    to simple temperature sampling one token at a time.

    Args:
        model: Model on its target device; must expose ``config.max_seq_len``.
        tokenizer: Tokenizer with __call__/decode plus pad/eos token ids.
        prompt: Input text to continue.
        args: Namespace with max_new_tokens, temperature, top_p.

    Returns:
        The decoded text (prompt tokens included).
    """
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        padding=False,
        truncation=True,
        # Leave room for the new tokens; never request a non-positive length.
        max_length=max(1, model.config.max_seq_len - args.max_new_tokens),
    )
    input_ids = inputs["input_ids"].to(next(model.parameters()).device)

    with torch.no_grad():
        if hasattr(model, "generate"):
            output_ids = model.generate(
                input_ids,
                max_new_tokens=args.max_new_tokens,
                temperature=args.temperature,
                top_p=args.top_p,
                do_sample=True,
                pad_token_id=tokenizer.pad_token_id,
            )
        else:
            # Fallback: sample token-by-token with temperature scaling only
            # (top_p is intentionally not applied on this path).
            for _ in range(args.max_new_tokens):
                outputs = model(input_ids)
                next_token_logits = outputs["logits"][:, -1, :]
                probs = torch.softmax(next_token_logits / args.temperature, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
                input_ids = torch.cat([input_ids, next_token], dim=-1)

                # Stop early once the model emits the end-of-sequence token.
                if next_token.item() == tokenizer.eos_token_id:
                    break
            # Fix: this path never defined `output_ids`, so the decode below
            # raised NameError. The accumulated ids are the generation result.
            output_ids = input_ids

    return tokenizer.decode(output_ids[0].tolist(), skip_special_tokens=True)
|
| |
|
| |
|
def main():
    """CLI entry point: load model and tokenizer, then profile or generate."""
    args = parse_args()

    model, config = load_model(args)
    tokenizer = load_tokenizer(args)

    print(f"Model loaded on {next(model.parameters()).device}")
    print(f"Model parameters: {model.get_num_params():,}")

    # Profiling mode short-circuits generation entirely.
    if args.profile:
        print("Profiling...")
        dummy_input = torch.randint(0, config.vocab_size, (1, 128)).to(next(model.parameters()).device)
        if args.use_mps or args.device == "mps":
            stats = profile_model_mps(model, dummy_input)
        else:
            stats = profile_model(model, dummy_input)
        print("Profile results:")
        for k, v in stats.items():
            print(f"  {k}: {v:.4f}")
        return

    if args.interactive:
        print("Interactive mode. Type 'quit' to exit.")
        while True:
            try:
                prompt = input("\nPrompt: ")
            except EOFError:
                # Fix: Ctrl-D / closed stdin previously crashed the REPL with
                # an unhandled EOFError; treat it as a clean exit.
                break
            if prompt.lower() == "quit":
                break
            response = generate_text(model, tokenizer, prompt, args)
            print(f"\nResponse: {response}")
    elif args.prompt:
        response = generate_text(model, tokenizer, args.prompt, args)
        print(f"Response: {response}")
    else:
        print("No prompt provided. Use --prompt or --interactive.")


if __name__ == "__main__":
    main()
|
| |
|