File size: 3,940 Bytes
1c70d34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import torch
import argparse
from __init__ import QuillanSOTA, Config
import os

def load_model(checkpoint_path: str, device: str = 'cuda'):
    """Build a QuillanSOTA model and optionally restore weights from a checkpoint.

    Args:
        checkpoint_path: Path to a ``.pt`` file containing a plain state_dict.
        device: Target device string ('cuda' or 'cpu').

    Returns:
        The model moved to ``device`` and set to eval mode. If the checkpoint
        file does not exist, the randomly initialized model is returned
        (a warning is printed).
    """
    print(f"Loading model from {checkpoint_path}...")
    # Initialize model structure
    config = Config()
    model = QuillanSOTA(config)

    if os.path.exists(checkpoint_path):
        # weights_only=True restricts unpickling to tensors/primitive
        # containers, preventing arbitrary code execution from an untrusted
        # checkpoint file. Safe here because we load a plain state_dict.
        # (Requires torch >= 1.13 — TODO confirm project's torch version.)
        state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True)
        model.load_state_dict(state_dict)
        print("Checkpoint loaded successfully.")
    else:
        print(f"Warning: Checkpoint {checkpoint_path} not found. Using random weights.")

    model.to(device)
    model.eval()
    return model

def generate_text(model, prompt_text, max_new_tokens=50, temperature=0.7, device='cuda'):
    """Simple autoregressive text generation for demonstration.

    Samples up to ``max_new_tokens`` token ids from the model's text logits.
    All modality inputs are mocks: ``prompt_text`` is not tokenized (a random
    token sequence stands in for it) and random image/audio/video tensors are
    fed alongside it.

    BUGFIX: the original implementation ignored ``max_new_tokens`` entirely
    and always returned a single token; it also divided by ``temperature``
    unconditionally, producing NaNs for temperature <= 0.

    Args:
        model: Multimodal model invoked as ``model(text, img, aud, vid)``;
            expected to return a dict with a 'text' (or 'logits') entry, or a
            raw logits tensor — TODO confirm against QuillanSOTA.forward.
        prompt_text: Prompt string (printed only; see mock note above).
        max_new_tokens: Number of tokens to sample (values < 1 behave as 1).
        temperature: Softmax temperature; values <= 0 use greedy argmax.
        device: Device on which the mock input tensors are created.

    Returns:
        List of sampled token ids; ``[42]`` as a mock fallback when the model
        yields no usable text logits on the first step.
    """
    print(f"Generating text with prompt: '{prompt_text}'")

    # Mock inputs with correct dimensions based on model expectations
    # The model uses patch_size=16, so we need dimensions that work with this
    batch_size = 1
    seq_len = 10

    # Calculate correct image dimensions
    patch_size = 16  # From Config
    # Model expects 65536 patches: sqrt(65536) = 256, so we need 256*16 = 4096x4096 image
    grid_size = 256
    img_size = grid_size * patch_size  # 4096x4096

    text_input = torch.randint(0, 1000, (batch_size, seq_len)).to(device)
    img_input = torch.randn(batch_size, 3, img_size, img_size).to(device)  # 4096x4096 (256x256 patches = 65536)
    aud_input = torch.randn(batch_size, 1, 1024).to(device)     # Audio length
    vid_input = torch.randn(batch_size, 3, 16, img_size, img_size).to(device)  # Video: 16 frames, 4096x4096

    generated = []
    with torch.no_grad():
        # Loop honors max_new_tokens; each sampled token is fed back so the
        # next step conditions on it. NOTE(review): one full forward pass per
        # token — expensive with these mock input sizes, but correct.
        for _ in range(max(1, max_new_tokens)):
            outputs = model(text_input, img_input, aud_input, vid_input)

            # The model returns a dictionary with different modality outputs
            if isinstance(outputs, dict):
                text_logits = outputs.get('text', outputs.get('logits', None))
            else:
                text_logits = outputs

            if not isinstance(text_logits, torch.Tensor):
                # Fallback: mock token, matching the original behavior when
                # no usable logits come back.
                return generated if generated else [42]

            if text_logits.dim() != 3:
                # Unexpected logits shape: take one greedy token and stop
                # (original behavior for this branch).
                generated.append(text_logits.argmax().item())
                break

            # Last position's logits for batch element 0
            last_logits = text_logits[0, -1, :]

            if temperature > 0:
                # Temperature-scaled sampling
                probs = torch.softmax(last_logits / temperature, dim=-1)
                next_token = torch.multinomial(probs, 1)
            else:
                # Greedy decoding avoids division by zero
                next_token = last_logits.argmax(keepdim=True)

            generated.append(next_token.item())
            # Append the sampled token so the next step conditions on it.
            text_input = torch.cat(
                [text_input, next_token.view(1, 1).to(text_input.dtype)], dim=1
            )

    return generated

if __name__ == "__main__":
    # CLI entry point: parse flags, restore the model, run a demo generation.
    arg_parser = argparse.ArgumentParser(description="Quillan Inference")
    arg_parser.add_argument("--checkpoint", type=str, default="checkpoints/quillan_final.pt", help="Path to model checkpoint")
    arg_parser.add_argument("--prompt", type=str, default="Hello, world!", help="Input prompt") # In real usage, need tokenizer
    arg_parser.add_argument("--max_tokens", type=int, default=50, help="Max new tokens to generate")
    arg_parser.add_argument("--temperature", type=float, default=0.7, help="Sampling temperature")
    cli_args = arg_parser.parse_args()

    # Prefer the GPU when one is available
    run_device = 'cuda' if torch.cuda.is_available() else 'cpu'
    net = load_model(cli_args.checkpoint, run_device)

    # Generate text
    token_ids = generate_text(net, cli_args.prompt, cli_args.max_tokens, cli_args.temperature, run_device)
    print(f"Generated token IDs: {token_ids}")
    print("Note: This is a demonstration with mock multimodal inputs.")
    print("For actual text generation, you'll need to implement proper tokenization and handle the multimodal outputs.")