File size: 7,932 Bytes
bf64b03
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
#!/usr/bin/env python3
"""

Inference script for Vortex models.

Supports both CUDA and MPS backends.

"""

import argparse
import sys
from pathlib import Path

import torch

from configs.vortex_7b_config import VORTEX_7B_CONFIG
from configs.vortex_13b_config import VORTEX_13B_CONFIG

from models.vortex_model import VortexModel
from tokenizer.vortex_tokenizer import VortexScienceTokenizer
from inference.cuda_optimize import optimize_for_cuda, profile_model
from inference.mps_optimize import optimize_for_mps, profile_model_mps


def parse_args(argv=None):
    """Parse the command-line options of the inference script.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is what
            the script entry point relies on.

    Returns:
        argparse.Namespace carrying ``model_path``, ``config``,
        ``tokenizer_path``, ``model_size``, ``device``, ``use_mps``,
        ``quantization``, ``flash_attention``, ``torch_compile``,
        ``prompt``, ``interactive``, ``max_new_tokens``, ``temperature``,
        ``top_p`` and ``profile``.
    """
    parser = argparse.ArgumentParser(description="Run inference with Vortex model")
    parser.add_argument("--model_path", type=str, required=True,
                        help="Path to trained model checkpoint")
    parser.add_argument("--config", type=str, default=None,
                        help="Path to model config (if not in checkpoint)")
    parser.add_argument("--tokenizer_path", type=str, default=None,
                        help="Path to tokenizer")
    parser.add_argument("--model_size", type=str, choices=["7b", "13b"], default="7b",
                        help="Model size for config")
    parser.add_argument("--device", type=str, default="cuda",
                        choices=["cuda", "mps", "cpu"],
                        help="Device to run on")
    parser.add_argument("--use_mps", action="store_true",
                        help="Use MPS backend (Apple Silicon)")
    # None is deliberately not listed in choices: a command-line value is
    # always a string and could never equal None. Omitting the flag still
    # yields the None default (argparse does not validate defaults
    # against choices).
    parser.add_argument("--quantization", type=str, choices=["int8", "int4"], default=None,
                        help="Apply quantization (CUDA only)")
    parser.add_argument("--flash_attention", action="store_true",
                        help="Use Flash Attention 2 (CUDA only)")
    parser.add_argument("--torch_compile", action="store_true",
                        help="Use torch.compile")
    parser.add_argument("--prompt", type=str, default=None,
                        help="Input prompt for generation")
    parser.add_argument("--interactive", action="store_true",
                        help="Run in interactive mode")
    parser.add_argument("--max_new_tokens", type=int, default=100,
                        help="Maximum new tokens to generate")
    parser.add_argument("--temperature", type=float, default=0.8,
                        help="Sampling temperature")
    parser.add_argument("--top_p", type=float, default=0.9,
                        help="Top-p sampling")
    parser.add_argument("--profile", action="store_true",
                        help="Profile performance")
    return parser.parse_args(argv)


def load_model(args):
    """Build the Vortex model, load its checkpoint and move it to the device.

    Args:
        args: Parsed command-line namespace. Reads ``config``,
            ``model_size``, ``model_path``, ``device``, ``use_mps``,
            ``flash_attention``, ``torch_compile`` and ``quantization``.

    Returns:
        Tuple ``(model, config)``: the model in eval mode on the selected
        device, and the ``VortexConfig`` it was built from.
    """
    # Imported lazily, as in the original code, so importing this module
    # does not require configuration_vortex to be importable.
    from configuration_vortex import VortexConfig

    # Load config: explicit path wins, otherwise the built-in preset for
    # the requested size.
    if args.config:
        config = VortexConfig.from_pretrained(args.config)
    else:
        config_dict = VORTEX_7B_CONFIG if args.model_size == "7b" else VORTEX_13B_CONFIG
        config = VortexConfig(**config_dict)

    # Create model
    print("Creating model...")
    model = VortexModel(config.to_dict())

    # Load checkpoint
    print(f"Loading checkpoint from {args.model_path}")
    # SECURITY: weights_only=False lets torch.load unpickle arbitrary
    # objects, which can execute code. Only load checkpoints from trusted
    # sources; switch to weights_only=True if the checkpoints are known
    # to contain plain tensors only.
    checkpoint = torch.load(args.model_path, map_location="cpu", weights_only=False)
    # A checkpoint is either a wrapper dict holding the state dict under
    # "model_state_dict", or the bare state dict itself.
    if "model_state_dict" in checkpoint:
        model.load_state_dict(checkpoint["model_state_dict"])
    else:
        model.load_state_dict(checkpoint)
    print("Model loaded")

    # --use_mps must select the MPS device too; previously the model was
    # optimized for MPS but then moved to args.device (default "cuda").
    on_mps = args.use_mps or args.device == "mps"
    device = torch.device("mps" if on_mps else args.device)

    # Apply optimizations
    if on_mps:
        print("Optimizing for MPS...")
        model = optimize_for_mps(model, config.to_dict(), use_sdpa=True)
    else:
        # NOTE(review): --device cpu also takes this path; confirm that
        # optimize_for_cuda is safe to call for a CPU-bound model.
        print("Optimizing for CUDA...")
        model = optimize_for_cuda(
            model,
            config.to_dict(),
            use_flash_attention=args.flash_attention,
            use_torch_compile=args.torch_compile,
            quantization=args.quantization,
        )

    model = model.to(device)
    model.eval()

    return model, config


def load_tokenizer(args):
    """Load the Vortex tokenizer, falling back to a dummy one.

    Args:
        args: Parsed command-line namespace. Reads ``tokenizer_path`` and,
            when that is not set, ``model_path`` (the tokenizer file
            ``vortex_tokenizer.json`` is then looked up next to the
            checkpoint).

    Returns:
        A ``VortexTokenizer`` when the tokenizer path exists, otherwise a
        ``DummyTokenizer`` placeholder that always yields the ids
        ``[[1, 2, 3]]`` and decodes to ``"dummy"``.
    """
    if args.tokenizer_path:
        tokenizer_path = Path(args.tokenizer_path)
        # from_pretrained is given a directory below; accept either the
        # directory itself or a file inside it. (Previously this branch
        # crashed because the directory variable was never assigned.)
        tokenizer_dir = tokenizer_path if tokenizer_path.is_dir() else tokenizer_path.parent
    else:
        # Try to find in model directory
        tokenizer_dir = Path(args.model_path).parent
        tokenizer_path = tokenizer_dir / "vortex_tokenizer.json"

    if tokenizer_path.exists():
        from tokenization_vortex import VortexTokenizer
        return VortexTokenizer.from_pretrained(str(tokenizer_dir))

    print("Warning: No tokenizer found, using dummy tokenizer")

    class DummyTokenizer:
        # Generation code reads these attributes; None means "no such
        # special token" so the placeholder never triggers EOS handling.
        pad_token_id = None
        eos_token_id = None

        def __call__(self, text, **kwargs):
            return {"input_ids": torch.tensor([[1, 2, 3]])}

        def decode(self, ids, **kwargs):
            return "dummy"

    return DummyTokenizer()


def _sample_next_token(logits, temperature, top_p):
    """Sample one token id per row of ``logits`` (temperature + top-p)."""
    if temperature <= 0:
        # Dividing by a zero temperature would produce inf/nan; treat it
        # as greedy decoding instead.
        return torch.argmax(logits, dim=-1, keepdim=True)

    probs = torch.softmax(logits / temperature, dim=-1)
    if 0.0 < top_p < 1.0:
        sorted_probs, sorted_idx = torch.sort(probs, dim=-1, descending=True)
        cumulative = torch.cumsum(sorted_probs, dim=-1)
        # Drop a token when the mass of the tokens ranked before it already
        # reaches top_p. The top-ranked token has zero preceding mass, so
        # at least one candidate always survives.
        drop = (cumulative - sorted_probs) >= top_p
        sorted_probs = sorted_probs.masked_fill(drop, 0.0)
        probs = torch.zeros_like(probs).scatter(-1, sorted_idx, sorted_probs)
    # multinomial treats its input as unnormalized weights, so the
    # filtered probabilities need no renormalization.
    return torch.multinomial(probs, num_samples=1)


def generate_text(model, tokenizer, prompt, args):
    """Generate a continuation of ``prompt`` and return it decoded.

    Uses ``model.generate`` when the model provides it; otherwise runs a
    token-by-token sampling loop that stops after ``args.max_new_tokens``
    tokens or when the tokenizer's EOS id is sampled.

    Args:
        model: Callable model exposing ``config.max_seq_len`` and
            ``parameters()``; in the manual path a forward call must
            return a mapping with a ``"logits"`` entry.
        tokenizer: Callable tokenizer returning ``{"input_ids": tensor}``
            and providing ``decode``.
        prompt: Text to continue.
        args: Namespace with ``max_new_tokens``, ``temperature`` and
            ``top_p``.

    Returns:
        The decoded text of the full output sequence (prompt ids followed
        by the generated ids).
    """
    # Leave room for the generated tokens, but never ask the tokenizer
    # for a non-positive length.
    max_prompt_len = max(1, model.config.max_seq_len - args.max_new_tokens)

    # Tokenize
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        padding=False,
        truncation=True,
        max_length=max_prompt_len,
    )
    input_ids = inputs["input_ids"].to(next(model.parameters()).device)

    # Generate
    with torch.no_grad():
        if hasattr(model, "generate"):
            output_ids = model.generate(
                input_ids,
                max_new_tokens=args.max_new_tokens,
                temperature=args.temperature,
                top_p=args.top_p,
                do_sample=True,
                pad_token_id=getattr(tokenizer, "pad_token_id", None),
            )
        else:
            # Manual generation. output_ids is bound here so the decode
            # step below works in this branch too (it used to be defined
            # only in the model.generate branch).
            eos_token_id = getattr(tokenizer, "eos_token_id", None)
            output_ids = input_ids
            for _ in range(args.max_new_tokens):
                outputs = model(output_ids)
                next_token = _sample_next_token(
                    outputs["logits"][:, -1, :],
                    args.temperature,
                    args.top_p,
                )
                output_ids = torch.cat([output_ids, next_token], dim=-1)

                # Check for EOS (single-sequence generation: .item()
                # requires exactly one sampled token).
                if eos_token_id is not None and next_token.item() == eos_token_id:
                    break

    # Decode
    return tokenizer.decode(output_ids[0].tolist(), skip_special_tokens=True)


def main():
    """Script entry point: load the model, then profile or generate.

    Parses the command line, loads model and tokenizer, and then runs
    exactly one of: a profiling pass (``--profile``), an interactive
    prompt loop (``--interactive``), or a single generation for
    ``--prompt``. Prints a usage hint when none of these is requested.
    """
    args = parse_args()

    # Load model and tokenizer
    model, config = load_model(args)
    tokenizer = load_tokenizer(args)

    device = next(model.parameters()).device
    print(f"Model loaded on {device}")
    print(f"Model parameters: {model.get_num_params():,}")

    # Profile if requested
    if args.profile:
        print("Profiling...")
        dummy_input = torch.randint(0, config.vocab_size, (1, 128)).to(device)
        if args.use_mps or args.device == "mps":
            stats = profile_model_mps(model, dummy_input)
        else:
            stats = profile_model(model, dummy_input)
        print("Profile results:")
        for k, v in stats.items():
            # Only numbers accept the fixed-point format; print anything
            # else (strings, nested values) as-is instead of raising.
            if isinstance(v, (int, float)) and not isinstance(v, bool):
                print(f"  {k}: {v:.4f}")
            else:
                print(f"  {k}: {v}")
        return

    # Interactive mode
    if args.interactive:
        print("Interactive mode. Type 'quit' to exit.")
        while True:
            try:
                prompt = input("\nPrompt: ")
            except (EOFError, KeyboardInterrupt):
                # Ctrl-D / Ctrl-C ends the session cleanly rather than
                # with a traceback.
                print()
                break
            command = prompt.strip()
            if command.lower() == "quit":
                break
            if not command:
                # Nothing to generate from; ask again.
                continue
            response = generate_text(model, tokenizer, prompt, args)
            print(f"\nResponse: {response}")
    elif args.prompt:
        response = generate_text(model, tokenizer, args.prompt, args)
        print(f"Response: {response}")
    else:
        print("No prompt provided. Use --prompt or --interactive.")

if __name__ == "__main__":
    main()