# extend_context.py - Extend MAP-NEO Mini context window to 4096 tokens
from model_neo import NeoMiniConfig, NeoMini
import torch

def extend_model_context(checkpoint_path="checkpoints/checkpoint_step_149999.pt",
                         new_max_len=4096):
    """Extend the model's context window from 2048 to new_max_len tokens."""
    
    print(f"Extending context window to {new_max_len} tokens...")
    
    # Load original config and model
    config = NeoMiniConfig()
    config.max_seq_len = new_max_len  # Extend context window
    
    # Create new model with extended context
    extended_model = NeoMini(config)
    
    # Load original weights
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    original_state = checkpoint['model_state_dict']
    
    # Transfer weights (position embeddings will be interpolated)
    extended_state = extended_model.state_dict()
    
    for key in original_state:
        if key in extended_state:
            if 'pos' in key and extended_state[key].shape != original_state[key].shape:
                # Linearly interpolate position embeddings along the sequence
                # axis. F.interpolate with mode='linear' expects 3-D (N, C, L)
                # input and an integer size, so reshape the (old_len, dim)
                # table to (1, dim, old_len) and back.
                print(f"Interpolating position embeddings: {key}")
                old_pos_emb = original_state[key]  # (old_len, dim)
                new_pos_emb = torch.nn.functional.interpolate(
                    old_pos_emb.T.unsqueeze(0),    # (1, dim, old_len)
                    size=new_max_len,
                    mode='linear',
                    align_corners=False
                ).squeeze(0).T                     # (new_max_len, dim)
                extended_state[key] = new_pos_emb
            else:
                extended_state[key] = original_state[key]
    
    extended_model.load_state_dict(extended_state)
    
    # Save extended model
    extended_checkpoint = {
        'model_state_dict': extended_model.state_dict(),
        'config': config.to_dict()
    }
    
    output_path = "checkpoints/extended_context_model.pt"
    torch.save(extended_checkpoint, output_path)
    print(f"Extended model saved to {output_path}")
    
    return extended_model, config
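
# A minimal sanity check (a sketch, not part of the original conversion code):
# run one dummy forward pass at the full extended length. It assumes
# NeoMini.forward accepts a (batch, seq_len) tensor of token ids and that
# NeoMiniConfig exposes a vocab_size attribute; adjust if model_neo differs.
def verify_extended_context(model, config):
    """Smoke-test the extended model with a random full-length sequence."""
    model.eval()
    dummy_ids = torch.randint(0, config.vocab_size, (1, config.max_seq_len))
    with torch.no_grad():
        logits = model(dummy_ids)  # should not raise at the new max_seq_len
    print(f"Forward pass OK at seq_len={config.max_seq_len}, "
          f"logits shape: {tuple(logits.shape)}")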

if __name__ == "__main__":
    extended_model, config = extend_model_context()
    verify_extended_context(extended_model, config)
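
# Reloading the extended checkpoint later (a hedged sketch: the config is
# restored field by field via setattr because only to_dict() appears above,
# and a from_dict counterpart is not confirmed by model_neo):
#
#   checkpoint = torch.load("checkpoints/extended_context_model.pt",
#                           map_location="cpu")
#   config = NeoMiniConfig()
#   for k, v in checkpoint["config"].items():   # restore saved config fields
#       setattr(config, k, v)
#   model = NeoMini(config)
#   model.load_state_dict(checkpoint["model_state_dict"])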