sigmoidneuron123 commited on
Commit
2fd8fb4
·
verified ·
1 Parent(s): 05e84f3

MLX-compatibility

Browse files
Files changed (2) hide show
  1. AppleAI-converter.py +72 -0
  2. chessy_model_mlx.npz +3 -0
AppleAI-converter.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import mlx as mx
2
+ import mlx.nn as mx_nn
3
+ import torch
4
+ import torch.nn as nn
5
+ import numpy as np
6
+
7
+ device = torch.device('mps')
8
+
9
+ CONFIG = {
10
+ "model_path": "chessy_model.pth",
11
+ "backup_model_path": "chessy_modelt-1.pth",
12
+ }
13
+
14
+ class NN1(nn.Module):
15
+ def __init__(self):
16
+ super().__init__()
17
+ self.embedding = nn.Embedding(13, 64)
18
+ self.attention = nn.MultiheadAttention(embed_dim=64, num_heads=16)
19
+ self.neu = 512
20
+ self.neurons = nn.Sequential(
21
+ nn.Linear(4096, self.neu),
22
+ nn.ReLU(),
23
+ nn.Linear(self.neu, self.neu),
24
+ nn.ReLU(),
25
+ nn.Linear(self.neu, self.neu),
26
+ nn.ReLU(),
27
+ nn.Linear(self.neu, self.neu),
28
+ nn.ReLU(),
29
+ nn.Linear(self.neu, self.neu),
30
+ nn.ReLU(),
31
+ nn.Linear(self.neu, self.neu),
32
+ nn.ReLU(),
33
+ nn.Linear(self.neu, self.neu),
34
+ nn.ReLU(),
35
+ nn.Linear(self.neu, self.neu),
36
+ nn.ReLU(),
37
+ nn.Linear(self.neu, self.neu),
38
+ nn.ReLU(),
39
+ nn.Linear(self.neu, self.neu),
40
+ nn.ReLU(),
41
+ nn.Linear(self.neu, self.neu),
42
+ nn.ReLU(),
43
+ nn.Linear(self.neu, self.neu),
44
+ nn.ReLU(),
45
+ nn.Linear(self.neu, self.neu),
46
+ nn.ReLU(),
47
+ nn.Linear(self.neu, 64),
48
+ nn.ReLU(),
49
+ nn.Linear(64, 4)
50
+ )
51
+
52
+ def forward(self, x):
53
+ x = self.embedding(x)
54
+ x = x.permute(1, 0, 2)
55
+ attn_output, _ = self.attention(x, x, x)
56
+ x = attn_output.permute(1, 0, 2).contiguous()
57
+ x = x.view(x.size(0), -1)
58
+ x = self.neurons(x)
59
+ return x
60
+
61
+ model = NN1().to(device)
62
+ try:
63
+ model.load_state_dict(torch.load(CONFIG['model_path'], map_location=device))
64
+ print(f"Loaded model from {CONFIG['model_path']}")
65
+ except FileNotFoundError:
66
+ try:
67
+ model.load_state_dict(torch.load(CONFIG["backup_model_path"], map_location=device))
68
+ print(f"Loaded backup model from {CONFIG['backup_model_path']}")
69
+ except FileNotFoundError:
70
+ print("No model file found, starting from scratch.")
71
+ weights = {k: v.detach().cpu().numpy() for k, v in model.state_dict().items()}
72
+ np.savez("chessy_model_mlx.npz", **weights)
chessy_model_mlx.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9b1c4066b547e843021ac6ada1da8dbf1191cc35e3d51a4ae46da65b58efb5e2
3
+ size 21209702