samclane committed
Commit efd4cc5 · verified · 1 parent: a4906a5

Upload username_transformer.py

Files changed (1): username_transformer.py (+539 lines)

username_transformer.py ADDED
@@ -0,0 +1,539 @@
# -*- coding: utf-8 -*-
"""Username_Transformer

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1iae8ZzCuKYOPmMyTibAh7hVzwjbrW4Pe
"""

# Commented out IPython magic to ensure Python compatibility.
# Install PyTorch
# %pip install torch torchvision torchaudio

# Install other dependencies
# %pip install numpy pandas nltk elevenlabs requests

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import numpy as np
import nltk
import re
from collections import Counter
from tqdm import tqdm
import requests
from nltk.corpus import cmudict
import os
import pandas as pd

# Let cuDNN benchmark available kernels and pick the fastest ones
import torch.backends.cudnn as cudnn
cudnn.benchmark = True

nltk.download('cmudict')

cmu_dict = cmudict.dict()

url = "https://raw.githubusercontent.com/danielmiessler/SecLists/master/Usernames/xato-net-10-million-usernames.txt"

try:
    response = requests.get(url)
    response.raise_for_status()  # Raise an exception for bad status codes

    usernames = response.text.splitlines()
    print(f"Downloaded {len(usernames)} usernames.")

except requests.exceptions.RequestException as e:
    print(f"Error downloading usernames: {e}")
    usernames = []

def normalize_username(username):
    # Convert to lowercase
    username = username.lower()
    # Replace numbers with words
    num_to_word = {
        '0': ' zero ', '1': ' one ', '2': ' two ', '3': ' three ',
        '4': ' four ', '5': ' five ', '6': ' six ', '7': ' seven ',
        '8': ' eight ', '9': ' nine '
    }
    for num, word in num_to_word.items():
        username = username.replace(num, word)
    # Replace special characters with spaces
    username = re.sub(r'[\W_]+', ' ', username)
    # Remove extra spaces
    username = re.sub(r'\s+', ' ', username).strip()
    return username

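# Quick sanity check (illustrative, following the rules above):
#   normalize_username("Cool_Guy42") -> "cool guy four two"
# (lowercased, digits spelled out, "_" and punctuation collapsed to spaces)
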
def get_phonemes(word):
    phonemes_list = cmu_dict.get(word)
    if phonemes_list:
        return phonemes_list[0]  # Use the first pronunciation
    else:
        return None  # Words without a CMUdict entry are skipped downstream

def username_to_phonemes(username):
    normalized = normalize_username(username)
    words = normalized.split()
    phonemes = []
    for word in words:
        phoneme = get_phonemes(word)
        if phoneme:
            phonemes.extend(phoneme)
        # else:
        #     print(f"Warning: Unable to find phonemes for word: {word}")
    return phonemes

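# Example (exact output depends on the CMUdict version; the first listed
# pronunciation is used):
#   username_to_phonemes("cat1")  # "cat one" -> ['K', 'AE1', 'T', 'W', 'AH1', 'N']
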
input_sequences = []
target_sequences = []

for username in usernames:
    input_seq = list(normalize_username(username))
    target_seq = username_to_phonemes(username)
    if target_seq:
        input_sequences.append(input_seq)
        target_sequences.append(target_seq)

# Character Vocabulary
char_counter = Counter([char for seq in input_sequences for char in seq])
char_list = ['<pad>'] + sorted(char_counter.keys())
char_vocab = {char: idx for idx, char in enumerate(char_list)}

# Phoneme Vocabulary
phoneme_counter = Counter([phoneme for seq in target_sequences for phoneme in seq])
phoneme_list = ['<pad>', '<sos>', '<eos>'] + sorted(phoneme_counter.keys())
phoneme_vocab = {phoneme: idx for idx, phoneme in enumerate(phoneme_list)}

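# Resulting vocab layout (sizes depend on the downloaded data):
#   char_vocab:    '<pad>' = 0, then the sorted characters seen in the inputs
#   phoneme_vocab: '<pad>' = 0, '<sos>' = 1, '<eos>' = 2, then sorted ARPAbet phonemes
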
def encode_sequence(seq, vocab, max_len, add_special_tokens=False):
    encoded = [vocab.get(token, vocab['<pad>']) for token in seq]
    if add_special_tokens:
        encoded = [vocab['<sos>']] + encoded + [vocab['<eos>']]
    # Trim or pad the sequence to max_len
    encoded = encoded[:max_len] + [vocab['<pad>']] * max(0, max_len - len(encoded))
    return encoded

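# Worked example with a toy vocab (hypothetical indices, just to show the mechanics):
#   vocab = {'<pad>': 0, '<sos>': 1, '<eos>': 2, 'a': 3, 'b': 4}
#   encode_sequence(['a', 'b'], vocab, 6, add_special_tokens=True)
#   -> [1, 3, 4, 2, 0, 0]   # <sos> a b <eos> <pad> <pad>
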
max_input_len = max(len(seq) for seq in input_sequences)
max_target_len = max(len(seq) for seq in target_sequences) + 2  # For <sos> and <eos>

encoded_inputs = [encode_sequence(seq, char_vocab, max_input_len) for seq in input_sequences]
encoded_targets = [encode_sequence(seq, phoneme_vocab, max_target_len, True) for seq in target_sequences]

class UsernameDataset(Dataset):
    def __init__(self, inputs, targets):
        self.inputs = torch.tensor(inputs, dtype=torch.long)
        self.targets = torch.tensor(targets, dtype=torch.long)

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        return self.inputs[idx], self.targets[idx]

dataset = UsernameDataset(encoded_inputs, encoded_targets)
data_loader = DataLoader(dataset, batch_size=512, shuffle=True)

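# Each batch from data_loader is a pair of LongTensors:
#   src: (batch_size, max_input_len), trg: (batch_size, max_target_len)
# A quick peek (sketch):
#   src, trg = next(iter(data_loader))
#   print(src.shape, trg.shape)
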
# Function to decode sequences
def decode_sequence(encoded_seq, vocab):
    idx_to_token = {idx: token for token, idx in vocab.items()}
    decoded_seq = [idx_to_token.get(idx, '<unk>') for idx in encoded_seq]
    return decoded_seq

# Create lists to store decoded usernames and pronunciations
usernames = []
pronunciations = []

# Iterate through the dataset and decode sequences, dropping padding and
# special tokens so the CSV contains readable strings
special_tokens = {'<pad>', '<sos>', '<eos>'}
for input_seq, target_seq in dataset:
    username = ''.join(tok for tok in decode_sequence(input_seq.tolist(), char_vocab)
                       if tok not in special_tokens)
    pronunciation = ' '.join(tok for tok in decode_sequence(target_seq.tolist(), phoneme_vocab)
                             if tok not in special_tokens)
    usernames.append(username)
    pronunciations.append(pronunciation)

# Create a Pandas DataFrame
df = pd.DataFrame({'username': usernames, 'pronunciation': pronunciations})

# Export to CSV
df.to_csv('username_pronunciation.csv', index=False)

class Encoder(nn.Module):
    def __init__(self, input_dim, emb_dim, hid_dim):
        super().__init__()
        self.embedding = nn.Embedding(input_dim, emb_dim, padding_idx=char_vocab['<pad>'])
        self.gru = nn.GRU(emb_dim, hid_dim, batch_first=True)

    def forward(self, src):
        embedded = self.embedding(src)
        outputs, hidden = self.gru(embedded)
        return outputs, hidden

class Attention(nn.Module):
    def __init__(self, hid_dim):
        super().__init__()
        self.attn = nn.Linear(hid_dim * 2, hid_dim)
        self.v = nn.Linear(hid_dim, 1, bias=False)

    def forward(self, hidden, encoder_outputs):
        src_len = encoder_outputs.shape[1]
        hidden = hidden.repeat(1, src_len, 1)
        energy = torch.tanh(self.attn(torch.cat((hidden, encoder_outputs), dim=2)))
        attention = self.v(energy).squeeze(2)
        return torch.softmax(attention, dim=1)

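# Shape sketch for the attention above (B = batch, S = src_len, H = HID_DIM):
#   hidden arrives as (B, 1, H) and is repeated to (B, S, H);
#   concatenated with encoder_outputs (B, S, H) -> (B, S, 2H) -> attn -> (B, S, H);
#   v projects to (B, S, 1), squeeze gives (B, S), softmax normalizes over S.
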
class Decoder(nn.Module):
    def __init__(self, output_dim, emb_dim, hid_dim, attention):
        super().__init__()
        self.output_dim = output_dim
        self.attention = attention
        self.embedding = nn.Embedding(output_dim, emb_dim, padding_idx=phoneme_vocab['<pad>'])
        self.gru = nn.GRU(emb_dim + hid_dim, hid_dim, batch_first=True)
        self.fc_out = nn.Linear(hid_dim * 2, output_dim)

    def forward(self, input, hidden, encoder_outputs):
        input = input.unsqueeze(1)
        embedded = self.embedding(input)
        a = self.attention(hidden.permute(1, 0, 2), encoder_outputs)
        a = a.unsqueeze(1)
        weighted = torch.bmm(a, encoder_outputs)
        rnn_input = torch.cat((embedded, weighted), dim=2)
        output, hidden = self.gru(rnn_input, hidden)
        output = torch.cat((output.squeeze(1), weighted.squeeze(1)), dim=1)
        prediction = self.fc_out(output)
        return prediction, hidden

class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, device):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def forward(self, src, trg, teacher_forcing_ratio=0.5):
        batch_size = src.shape[0]
        trg_len = trg.shape[1]
        trg_vocab_size = self.decoder.output_dim

        outputs = torch.zeros(batch_size, trg_len, trg_vocab_size).to(self.device)
        encoder_outputs, hidden = self.encoder(src)
        input = trg[:, 0]

        for t in range(1, trg_len):
            output, hidden = self.decoder(input, hidden, encoder_outputs)
            outputs[:, t] = output
            top1 = output.argmax(1)
            teacher_force = np.random.random() < teacher_forcing_ratio
            input = trg[:, t] if teacher_force else top1
        return outputs

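# Teacher forcing: with probability teacher_forcing_ratio the decoder is fed
# the ground-truth phoneme at step t instead of its own last prediction, which
# stabilizes early training. At a ratio of 0.0 this reduces to the greedy,
# free-running decoding used in predict() below.
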
def get_latest_checkpoint(directory):
    # Get a list of all files in the directory
    files = os.listdir(directory)

    # Filter the list to only include g2p{n}.pth files
    checkpoint_files = [f for f in files if re.match(r'g2p\d+\.pth', f)]

    # Extract the numbers from the filenames
    checkpoint_numbers = [int(re.search(r'g2p(\d+)\.pth', f).group(1)) for f in checkpoint_files]
    print(checkpoint_numbers)

    # Sort the files by their numbers
    sorted_files = sorted(zip(checkpoint_numbers, checkpoint_files))

    # Get the latest file (last element in the sorted list)
    if sorted_files:
        latest_file = sorted_files[-1][1]
        latest_checkpoint_path = os.path.join(directory, latest_file)
        return latest_checkpoint_path
    else:
        return None

def get_next_version(directory):
    files = os.listdir(directory)

    # Filter the list to only include g2p{n}.pth files
    checkpoint_files = [f for f in files if re.match(r'g2p\d+\.pth', f)]

    # Extract the numbers from the filenames
    checkpoint_numbers = [int(re.search(r'g2p(\d+)\.pth', f).group(1)) for f in checkpoint_files]
    print(checkpoint_numbers)

    # Sort the files by their numbers
    sorted_files = sorted(zip(checkpoint_numbers, checkpoint_files))
    if sorted_files:
        latest_version = sorted_files[-1][0]
        print(f"Latest version: {sorted_files[-1]}")
        return latest_version + 1
    else:
        return 1  # Start with version 1 if no checkpoints exist

def save_checkpoint(model, directory, version):
    filename = f"g2p{version}.pth"
    filepath = os.path.join(directory, filename)
    torch.save(model.state_dict(), filepath)
    print(f"Model saved to {filepath}")

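# Versioning sketch: with g2p1.pth and g2p3.pth on disk,
# get_latest_checkpoint(d) returns ".../g2p3.pth", get_next_version(d) returns 4,
# and save_checkpoint(model, d, 4) writes g2p4.pth.
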
# Get the latest checkpoint file path
directory = '/content/drive/MyDrive/AI/username_g2p/'
latest_checkpoint_file = get_latest_checkpoint(directory)

if latest_checkpoint_file:
    print(f"Latest checkpoint file: {latest_checkpoint_file}")
else:
    print("No checkpoint files found.")

print(get_next_version(directory))

INPUT_DIM = len(char_vocab)
OUTPUT_DIM = len(phoneme_vocab)
ENC_EMB_DIM = 64
DEC_EMB_DIM = 64
HID_DIM = 128

attn = Attention(HID_DIM)
enc = Encoder(INPUT_DIM, ENC_EMB_DIM, HID_DIM)
dec = Decoder(OUTPUT_DIM, DEC_EMB_DIM, HID_DIM, attn)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Seq2Seq(enc, dec, device).to(device)
optimizer = torch.optim.Adam(model.parameters())
criterion = nn.CrossEntropyLoss(ignore_index=phoneme_vocab['<pad>'])

# Path to your checkpoint file
checkpoint_file = latest_checkpoint_file if latest_checkpoint_file else 'g2p1.pth'

# Check if the checkpoint file exists
if os.path.exists(checkpoint_file):
    # Load the checkpoint, mapping tensors to the current device in case it
    # was saved on a different one
    print(f"Loading checkpoint from {checkpoint_file}")
    model.load_state_dict(torch.load(checkpoint_file, map_location=device))
else:
    print("Checkpoint file not found. Using default initialization.")

print(device)

# Verify input sequences
max_input_idx = max([max(seq) for seq in encoded_inputs])
print(f'Max input index: {max_input_idx}, Input vocab size: {INPUT_DIM}')

# Verify target sequences
max_target_idx = max([max(seq) for seq in encoded_targets])
print(f'Max target index: {max_target_idx}, Output vocab size: {OUTPUT_DIM}')

def train(model, loader, optimizer, criterion, clip):
    model.train()
    epoch_loss = 0

    for src, trg in tqdm(loader, desc="Training Batches"):
        src, trg = src.to(device), trg.to(device)
        optimizer.zero_grad()
        output = model(src, trg)
        output_dim = output.shape[-1]
        output = output[:, 1:].reshape(-1, output_dim)
        trg = trg[:, 1:].reshape(-1)
        loss = criterion(output, trg)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        epoch_loss += loss.item()

    return epoch_loss / len(loader)

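# Note: position 0 (the <sos> slot) is sliced off before the loss is computed,
# since the decoder loop in Seq2Seq.forward never writes outputs[:, 0];
# <pad> targets are also ignored via the criterion's ignore_index.
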
N_EPOCHS = 1
CLIP = 1

for epoch in range(N_EPOCHS):
    loss = train(model, data_loader, optimizer, criterion, CLIP)
    print(f'Epoch: {epoch+1}, Loss: {loss:.4f}')

# Get the next version number
next_version = get_next_version(directory)

# Save the model with the new version number
save_checkpoint(model, directory, next_version)

def predict(model, username):
    model.eval()
    with torch.no_grad():
        normalized = normalize_username(username)
        input_seq = encode_sequence(list(normalized), char_vocab, max_input_len)
        src = torch.tensor([input_seq], dtype=torch.long).to(device)
        encoder_outputs, hidden = model.encoder(src)
        input_token = torch.tensor([phoneme_vocab['<sos>']], dtype=torch.long).to(device)
        outputs = []

        for _ in range(max_target_len):
            output, hidden = model.decoder(input_token, hidden, encoder_outputs)
            top1 = output.argmax(1)
            if top1.item() == phoneme_vocab['<eos>']:
                break
            outputs.append(top1.item())
            input_token = top1

        idx_to_phoneme = {idx: phoneme for phoneme, idx in phoneme_vocab.items()}
        predicted_phonemes = [idx_to_phoneme[idx] for idx in outputs]
        return ' '.join(predicted_phonemes)

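# predict() decodes greedily: it takes the argmax phoneme at each step and
# stops on <eos> (or after max_target_len steps). Beam search would likely
# yield better pronunciations; greedy keeps the demo simple.
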
# test_username = 'supercalafragalisticexpialadocous'  # alternate test input
test_username = 'barnabassacket'
pronunciation = predict(model, test_username)
print(f'Username: {test_username}')
print(f'Pronunciation: {pronunciation}')

# from https://github.com/margonaut/CMU-to-IPA-Converter/blob/master/cmu_ipa_mapping.rb
CMU_IPA_MAPPING = {
    # Consonants
    "B": "b", "CH": "ʧ", "D": "d", "DH": "ð", "F": "f", "G": "g",
    "HH": "h", "JH": "ʤ", "K": "k", "L": "l", "M": "m", "N": "n",
    "NG": "ŋ", "P": "p", "R": "r", "S": "s", "SH": "ʃ", "T": "t",
    "TH": "θ", "V": "v", "W": "w", "Y": "j", "Z": "z", "ZH": "ʒ",
    # Vowels, keyed by ARPAbet symbol plus stress digit
    "AA0": "ɑ", "AA1": "ɑ", "AA2": "ɑ",
    "AE0": "æ", "AE1": "æ", "AE2": "æ",
    "AH0": "ə", "AH1": "ʌ", "AH2": "ʌ",
    "AO0": "ɔ", "AO1": "ɔ", "AO2": "ɔ",
    "EH0": "ɛ", "EH1": "ɛ", "EH2": "ɛ",
    "ER0": "ɚ", "ER1": "ɝ", "ER2": "ɝ",
    "IH0": "ɪ", "IH1": "ɪ", "IH2": "ɪ",
    "IY0": "i", "IY1": "i", "IY2": "i",
    "UH0": "ʊ", "UH1": "ʊ", "UH2": "ʊ",
    "UW0": "u", "UW1": "u", "UW2": "u",
    "AW0": "aʊ", "AW1": "aʊ", "AW2": "aʊ",
    "AY0": "aɪ", "AY1": "aɪ", "AY2": "aɪ",
    "EY0": "eɪ", "EY1": "eɪ", "EY2": "eɪ",
    "OW0": "oʊ", "OW1": "oʊ", "OW2": "oʊ",
    "OY0": "ɔɪ", "OY1": "ɔɪ", "OY2": "ɔɪ"
}

pronunciation = predict(model, test_username)
ipa_sequence = ''.join([CMU_IPA_MAPPING.get(phoneme, phoneme) for phoneme in pronunciation.split()])
print(f'Username: {test_username}')
print(f'Pronunciation: {ipa_sequence}')

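# Note: the mapping collapses ARPAbet stress digits (AE0/AE1/AE2 all -> æ,
# with AH and ER as the exceptions), and phonemes missing from the table
# pass through unchanged.
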
ssml_template = """<phoneme alphabet="{alphabet}" ph="{phonetics}">{text}</phoneme>"""

class Alphabets:
    IPA = "ipa"
    CMU = "cmu-arpabet"

print(ssml_template.format(alphabet=Alphabets.IPA, phonetics="ˈæktʃuəli", text="actually"))

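# The line above prints:
#   <phoneme alphabet="ipa" ph="ˈæktʃuəli">actually</phoneme>
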
from google.colab import userdata
eleven_labs_key = userdata.get('ELEVENLABS')

from elevenlabs import save
from elevenlabs.client import ElevenLabs
from IPython.display import Audio, display

sound_file = 'test.mp3'

def build_eleven_labs_query(username: str):
    client = ElevenLabs(
        api_key=eleven_labs_key,
    )

    audio = client.generate(
        text=ssml_template.format(
            alphabet=Alphabets.CMU,
            phonetics=predict(model, username),
            text=username
        ),
        voice="Rachel",
        model="eleven_flash_v2"
    )
    save(audio, sound_file)

build_eleven_labs_query(test_username)

display(Audio(sound_file, autoplay=True))

# prompt: get the parameters of a pytorch model

# Assuming 'model' is your Seq2Seq model instance
# Replace with your actual model if named differently

# Method 1: Using model.named_parameters()
for name, param in model.named_parameters():
    print(f"Parameter Name: {name}, Shape: {param.shape}")

# Method 2: Using model.parameters() (without parameter names)
for param in model.parameters():
    print(f"Parameter Shape: {param.shape}")

print(f"Model Parameters: {sum(p.numel() for p in model.parameters())}")

# prompt: visualize the weights

import matplotlib.pyplot as plt

# Assuming 'model' is your Seq2Seq model instance
# Replace with your actual model if named differently

# Collect parameter shapes and names
parameter_shapes = []
parameter_names = []
for name, param in model.named_parameters():
    parameter_shapes.append(np.prod(param.shape))
    parameter_names.append(name)

# Create a bar chart
plt.figure(figsize=(10, 6))
plt.bar(parameter_names, parameter_shapes)
plt.xlabel("Parameter Name")
plt.ylabel("Number of Weights")
plt.title("Distribution of Weights in the Model")
plt.xticks(rotation=90)  # Rotate x-axis labels for better readability
plt.tight_layout()
plt.show()