Koyd111 commited on
Commit
32b6996
·
verified ·
1 Parent(s): 30c4f9b

Upload 8 files

Browse files
Files changed (8) hide show
  1. cipher_8bit.py +75 -0
  2. data_pairs.pkl +3 -0
  3. french_dataset.py +53 -0
  4. infer.py +193 -0
  5. invariants.py +227 -0
  6. preprocess_dataset.py +47 -0
  7. symbols.pkl +3 -0
  8. test2.py +267 -0
cipher_8bit.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pickle
3
+ import random
4
+ from collections import Counter
5
+
6
def cut_string_into_pairs(text_corpus):
    """Split *text_corpus* into consecutive two-character chunks.

    An odd-length corpus has its final character padded with '_' so the
    last chunk is still two characters wide.
    """
    chunks = [text_corpus[pos:pos + 2] for pos in range(0, len(text_corpus) - 1, 2)]
    if len(text_corpus) % 2:
        # Pad the trailing lone character so every chunk is a pair.
        chunks.append(text_corpus[-1] + '_')
    return chunks
13
+
14
def get_symbols(text_corpus, max_characters=256):
    """Build the symbol inventory used by the 8-bit substitution cipher.

    Every distinct single character of *text_corpus* becomes a symbol, and
    the remaining slots up to *max_characters* are filled with the most
    frequent character pairs.

    Fix: the original ignored *max_characters* and hard-coded 256 in the
    ``most_common`` call; the parameter is now honoured (default 256 keeps
    the previous behavior for existing callers).
    """
    single_characters = list(set(text_corpus))
    pairs = [pair for pair, _ in
             Counter(cut_string_into_pairs(text_corpus)).most_common(max_characters - len(single_characters))]
    return single_characters + pairs
19
+
20
def substitution_cipher(symbols, random_seed):
    """Map each symbol to a unique random integer in [0, len(symbols)).

    Seeding with *random_seed* makes the cipher reproducible: the same
    seed and symbol list always yield the same substitution rule.
    """
    random.seed(random_seed)
    # A random permutation of 0..len(symbols)-1 gives each symbol a
    # distinct 8-bit encoding (for up to 256 symbols).
    shuffled_codes = random.sample(list(range(len(symbols))), len(symbols))
    return dict(zip(symbols, shuffled_codes))
30
+
31
def encode_text_with_indices(rule, symbols, text):
    """Encode *text* with *rule*, preferring two-character symbols.

    Scans the text left to right; at each position the two-character
    symbol starting there is tried first, then the single character.
    Returns ``(encoded_text, indices)``: the integer encodings taken from
    *rule* and the positions of the matched symbols within *symbols*.
    Characters covered by no symbol emit the sentinel value 256 in both
    lists.
    """
    # Reverse lookup: symbol -> its position in the symbols list.
    symbol_positions = {symbol: position for position, symbol in enumerate(symbols)}
    encodings = []
    positions = []
    cursor = 0
    text_length = len(text)
    while cursor < text_length:
        # Greedily try the pair starting at the cursor.
        if cursor + 1 < text_length:
            candidate = text[cursor] + text[cursor + 1]
            if candidate in rule:
                encodings.append(rule[candidate])
                positions.append(symbol_positions[candidate])
                cursor += 2
                continue
        single = text[cursor]
        if single in rule:
            encodings.append(rule[single])
            positions.append(symbol_positions[single])
        else:
            # Unknown character: 256 marks "no encoding" in both outputs.
            encodings.append(256)
            positions.append(256)
        cursor += 1
    return encodings, positions
64
+
65
# Function to load or save the symbols as a pickle
def load_or_save_symbols(symbols, pickle_file_path="symbols.pkl"):
    """Return the cached symbol list, creating the cache on first use.

    If *pickle_file_path* already exists its pickled contents are loaded
    and returned (the *symbols* argument is ignored); otherwise *symbols*
    is pickled to that path and returned unchanged.
    """
    if not os.path.exists(pickle_file_path):
        print("Pickle file not found. Saving symbols...")
        with open(pickle_file_path, 'wb') as handle:
            pickle.dump(symbols, handle)
        return symbols
    with open(pickle_file_path, 'rb') as handle:
        print("Loading symbols from pickle file...")
        return pickle.load(handle)
data_pairs.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f1f11c77420da01df9611b237623d76640a0871a8f2a25def2f614f2b5233ec1
3
+ size 1964900268
french_dataset.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import os
3
+ import requests
4
+
5
def clean_text(combined_text):
    """Strip *combined_text* down to an allowed character set.

    Keeps Latin letters (including French accented characters),
    whitespace, and common punctuation; everything else — digits,
    symbols, control characters — is dropped.
    """
    allowed_chars_pattern = r"[A-Za-zÀ-ÿéèêàâçîôùûïëô\s\.,;!?()\"'\-\n]"
    kept = re.findall(allowed_chars_pattern, combined_text)
    return "".join(kept)
13
+
14
def fetch_gutenberg_text(book_id):
    """Download the UTF-8 text of a Project Gutenberg book.

    Returns the response body on HTTP 200, or "" on any failure
    (non-200 status, connection error, or timeout).

    Fix: the original called ``requests.get`` with no timeout — a single
    stalled book could hang the whole dataset build forever — and let
    network errors propagate as uncaught exceptions instead of the
    documented empty-string failure result.
    """
    url = f"http://www.gutenberg.org/ebooks/{book_id}.txt.utf-8"
    try:
        # Bounded wait so one unresponsive download cannot stall the run.
        response = requests.get(url, timeout=30)
    except requests.RequestException:
        print(f"Failed to fetch book ID {book_id}")
        return ""
    if response.status_code == 200:
        return response.text
    print(f"Failed to fetch book ID {book_id}")
    return ""
22
+
23
+ def get_full_dataset(file_path: str = "full_dataset.txt") -> str:
24
+ # List of French text IDs from Project Gutenberg
25
+ french_text_ids = ['796', '797', '798', '799', '800', '801', '802', '803', '1256', '1339', '1910', '2419', '2650', '2682', '2998', '2999', '3000', '3456', '3644', '3645', '4548', '4559', '4561', '4562', '4563', '4564', '4565', '4566', '4567', '4568', '4569', '4570', '4647', '4648', '4649', '4650', '4651', '4688', '4708', '4717', '4718', '4740', '4741', '4771', '4772', '4785', '4791', '4933', '4935', '4936', '4968', '5081', '5082', '5095', '5096', '5097', '5104', '5105', '5126', '5130', '5138', '5147', '5154', '5158', '5178', '5250', '5258', '5318', '5423', '5644', '5711', '5781', '5892', '6099', '6309', '6318', '6319', '6377', '6470', '6484', '6497', '6501', '6558', '6691', '6739', '6838', '6966', '6994', '7012', '7173', '7263', '7268', '7442', '7461', '7462', '7770', '7771', '7772', '7809', '7812', '7818', '7854', '8074', '8173', '8174', '8175', '8186', '8416', '8453', '8454', '8490', '8520', '8524', '8541', '8560', '8561', '8563', '8650', '8692', '8693', '8712', '8719', '8739', '8822', '8863', '8864', '8876', '8907', '8946', '9053', '9261', '9262', '9453', '9637', '9638', '9639', '9643', '9644', '9645', '9655', '9818', '9824', '9891', '9892', '9893', '9894', '9945', '9976', '10053', '10061', '10160', '10263', '10289', '10290', '10346', '10384', '10385', '10442', '10604', '10678', '10680', '10682', '10683', '10685', '10687', '10689', '10697', '10746', '10764', '10768', '10774', '10775', '10824', '10841', '10906', '10953', '10982', '11035', '11036', '11037', '11038', '11040', '11042', '11046', '11048', '11049', '11131', '11132', '11175', '11176', '11178', '11199', '11300', '11301', '11380', '11423', '11434', '11450', '11453', '11484', '11494', '11495', '11586', '11588', '11589', '11590', '11596', '11597', '11621', '11622', '11645', '11646', '11650', '11678', '11714', '11744', '11747', '11766', '11767', '11769', '11770', '11893', '11905', '11927', '11928', '11964', '12005', '12011', '12065', '12072', '12080', '12105', '12120', '12174', '12230', '12246', '12247', 
'12250', '12251', '12258', '12271', '12284', '12289', '12295', '12301', '12331', '12332', '12338', '12356', '12365', '12367', '12399', '12401', '12437', '12447', '12448', '12451', '12459', '12487', '12488', '12489', '12533', '12534', '12562', '12566', '12602', '12603', '12620', '12646', '12665', '12666', '12726', '12727', '12749', '12751', '12752', '12782', '12783', '12812', '12829', '12837', '12862', '12865', '12869', '12889', '12893', '12949', '12950', '12969', '12979', '12993', '12999', '13013', '13016', '13024', '13025', '13027', '13036', '13038', '13059', '13070', '13096', '13122', '13149', '13187', '13189', '13190', '13192', '13198', '13207', '13219', '13221', '13230', '13231', '13247', '13256', '13258', '13263', '13284', '13299', '13303', '13336', '13339', '13374', '13380', '13383', '13385', '13400', '13431', '13456', '13475', '13478', '13490', '13523', '13524', '13525', '13557', '13562', '13592', '13594', '13598', '13607', '13622', '13628', '13629', '13653', '13654', '13668', '13671', '13676', '13703', '13704', '13727', '13734', '13735', '13737', '13743', '13744', '13765', '13771', '13772', '13792', '13793', '13794', '13795', '13798', '13804', '13807', '13808', '13818', '13819', '13825', '13834', '13837', '13838', '13839', '13846', '13848', '13855', '13856', '13857', '13861', '13862', '13863', '13866', '13868', '13875', '13892', '13914', '13917', '13938', '13947', '13948', '13949', '13950', '13951', '13952', '13965', '13981', '14030', '14038', '14059', '14069', '14071', '14113', '14151', '14155', '14156', '14157', '14158', '14159', '14162', '14163', '14192', '14247', '14251', '14258', '14259', '14285', '14286', '14287', '14288', '14309', '14310', '14343', '14372', '14397', '14398', '14399', '14404', '14512', '14536', '14537', '14538', '14539', '14541', '14564', '14609', '14677', '14683', '14688', '14692', '14693', '14702', '14703', '14704', '14705', '14713', '14720', '14751', '14788', '14789', '14790', '14791', '14792', '14793', '14799', '14803', '14804', 
'14805', '14806', '14810', '14820', '14827', '14828', '14905', '14911', '14912', '14913', '14918', '15032', '15057', '15058', '15059', '15060', '15071', '15075', '15107', '15112', '15113', '15146', '15150', '15152', '15203', '15208', '15212', '15226', '15235', '15239', '15286', '15287', '15288', '15295', '15296', '15297', '15303', '15305', '15312', '15324', '15361', '15371', '15372', '15375', '15388', '15397', '15433', '15434', '15458', '15459', '15462', '15463', '15543', '15554', '15555', '15556', '15557', '15558', '15574', '15579', '15584', '15589', '15593', '15598', '15626', '15635', '15642', '15645', '15646', '15686', '15732', '15739', '15804', '15805', '15811', '15815', '15816', '15823', '15844', '15846', '15847', '15848', '15849', '15871', '15882', '15885', '15907', '15942', '15943', '16020', '16021', '16022', '16023', '16066', '16067', '16128', '16210', '16234', '16235', '16236', '16237', '16238', '16239', '16240', '16260', '16286', '16336', '16388', '16421', '16465', '16492', '16499', '16649', '16709', '16710', '16743', '16758', '16789', '16795', '16796', '16812', '16813', '16814', '16815', '16816', '16817', '16818', '16819', '16820', '16824', '16825', '16826', '16827', '16828', '16848', '16849', '16850', '16851', '16852', '16862', '16874', '16875', '16876', '16883', '16884', '16885', '16886', '16887', '16888', '16901', '16934', '16988', '16989', '17004', '17010', '17044', '17098', '17105', '17106', '17123', '17140', '17184', '17225', '17230', '17231', '17232', '17233', '17234', '17235', '17236', '17238', '17240', '17242', '17248', '17251', '17252', '17258', '17261', '17264', '17267', '17271', '17281', '17285', '17298', '17311', '17319', '17335', '17344', '17345', '17353', '17360', '17363', '17372', '17399', '17419', '17420', '17457', '17458', '17459', '17489', '17493', '17494', '17501', '17505', '17509', '17516', '17517', '17518', '17519', '17533', '17538', '17540', '17541', '17542', '17543', '17550', '17551', '17552', '17553', '17555', '17557', '17561', 
'17565', '17573', '17577', '17578', '17589', '17590', '17602', '17605', '17623', '17631', '17632', '17640', '17641', '17643', '17646', '17656', '17660', '17661', '17662', '17668', '17670', '17673', '17675', '17676', '17688', '17691', '17692', '17693', '17696', '17707', '17708', '17709', '17714', '17715', '17716', '17717', '17734', '17736', '17738', '17739', '17746', '17747', '17752', '17757', '17758', '17791', '17794', '17795', '17796', '17798', '17808', '17809', '17810', '17828', '17830', '17831', '17832', '17840', '17868', '17869', '17879', '17880', '17899', '17911', '17914', '17915', '17916', '17930', '17940', '17941', '17942', '17947', '17949', '17950', '17951', '17963', '17980', '17983', '17984', '17989', '17990', '17991', '17992', '18003', '18006', '18014', '18015', '18024', '18027', '18028', '18029', '18034', '18055', '18059', '18061', '18064', '18067', '18073', '18074', '18075', '18081', '18083', '18084', '18085', '18089', '18090', '18092', '18106', '18108', '18111', '18112', '18121', '18123', '18133', '18142', '18143', '18152', '18159', '18162', '18169', '18179', '18197', '18199', '18200', '18205', '18208', '18211', '18215', '18244', '18245', '18262', '18263', '18271', '18294', '18295', '18296', '18302', '18311', '18312', '18313', '18321', '18340', '18353', '18358', '18367', '18368', '18401', '18402', '18403', '18404', '18407', '18415', '18416', '18427', '18446', '18454', '18455', '18490', '18491', '18494', '18518', '18535', '18537', '18543', '18583', '18585', '18586', '18610', '18611', '18623', '18627', '18669', '18672', '18692', '18693', '18695', '18697', '18715', '18716', '18717', '18718', '18724', '18727', '18738', '18749', '18771', '18773', '18797', '18806', '18812', '18825', '18826', '18849', '18864', '18865', '18889', '18890', '18899', '18918', '18919', '18920', '18921', '18922', '18923', '18924', '18925', '18940', '18942', '18944', '18962', '18983', '18995', '18996', '19008', '19021', '19035', '19045', '19075', '19112', '19124', '19149', '19152', 
'19184', '19186', '19187', '19201', '19219', '19227', '19228', '19232', '19233', '19234', '19248', '19249', '19266', '19344', '19345', '19431', '19440', '19454', '19455', '19483', '19497', '19519', '19536', '19540', '19588', '19604', '19631', '19657', '19662', '19689', '19700', '19738', '19756', '19820', '19842', '19854', '19862', '19919', '19920', '19954', '19956', '19972', '19982', '19984', '19992', '20013', '20077', '20079', '20108', '20143', '20199', '20234', '20244', '20246', '20254', '20262', '20325', '20372', '20396', '20398', '20414', '20415', '20440', '20441', '20457', '20479', '20490', '20498', '20507', '20554', '20562', '20564', '20568', '20577', '20623', '20635', '20640', '20664', '20690', '20700', '20703', '20705', '20720', '20761', '20773', '20790', '20823', '20824', '20825', '20829', '20864', '20865', '20886', '20894', '20895', '20949', '20950', '20964', '20966', '20971', '20972', '20973', '20974', '21001', '21013', '21017', '21023', '21124', '21191', '21199', '21215', '21221', '21257', '21277', '21343', '21413', '21544', '21669', '21792', '21804', '21825', '21856', '21896', '21912', '21940', '21966', '22007', '22011', '22016', '22039', '22048', '22054', '22068', '22077', '22078', '22111', '22192', '22253', '22262', '22266', '22268', '22356', '22383', '22384', '22385', '22386', '22388', '22393', '22394', '22416', '22429', '22543', '22548', '22551', '22552', '22558', '22572', '22575', '22613', '22618', '22633', '22741', '22751', '22760', '22768', '22769', '22780', '22813', '22830', '22889', '22917', '22918', '22971', '23019', '23020', '23047', '23098', '23158', '23199', '23202', '23211', '23279', '23285', '23289', '23423', '23444', '23463', '23484', '23508', '23520', '23566', '23567', '23578', '23582', '23583', '23589', '23596', '23610', '23615', '23616', '23618', '23654', '23801', '23828', '23830', '23848', '23917', '23939', '23940', '23953', '23954', '24007', '24045', '24081', '24123', '24217', '24243', '24255', '24257', '24260', '24300', '24305', 
'24320', '24325', '24369', '24383', '24424', '24490', '24515', '24546', '24549', '24555', '24573', '24636', '24766', '24768', '24809', '24850', '24861', '24867', '24888', '24908', '24915', '24924', '24960', '24962', '25036', '25039', '25097', '25149', '25227', '25276', '25310', '25335', '25364', '25370', '25378', '25382', '25394', '25403', '25434', '25435', '25503', '25526', '25575', '25576', '25615', '25616', '25680', '25694', '25704', '25707', '25715', '25734', '25752', '25755', '25756', '25839', '25850', '25863', '25949', '25981', '26082', '26091', '26092', '26101', '26118', '26124', '26211', '26296', '26336', '26350', '26351', '26352', '26353', '26362', '26363', '26370', '26375', '26376', '26394', '26400', '26415', '26418', '26432', '26435', '26436', '26456', '26476', '26488', '26489', '26504', '26510', '26511', '26515', '26531', '26562', '26566', '26567', '26571', '26607', '26608', '26609', '26614', '26634', '26648', '26680', '26681', '26685', '26710', '26712', '26720', '26721', '26749', '26757', '26758', '26759', '26762', '26763', '26764', '26765', '26766', '26780', '26804', '26806', '26807', '26808', '26809', '26810', '26811', '26812', '26813', '26814', '26815', '26816', '26817', '26818', '26819', '26820', '26821', '26822', '26823', '26824', '26825', '26826', '26827', '26859', '26863', '26868', '26891', '26894', '26943', '27018', '27029', '27031', '27033', '27036', '27037', '27038', '27040', '27041', '27042', '27043', '27044', '27046', '27047', '27131', '27132', '27133', '27134', '27144', '27186', '27191', '27215', '27267', '27269', '27278', '27281', '27282', '27283', '27296', '27303', '27304', '27308', '27313', '27314', '27345', '27379', '27380', '27381', '27427', '27451', '27566', '27573', '27574', '27610', '27617', '27619', '27623', '27625', '27626', '27627', '27641', '27644', '27694', '27772', '27773', '27774', '27782', '27806', '27807', '27808', '27809', '27828', '27831', '27837', '27840', '27843', '27844', '27854', '27855', '27876', '27878', '27904', 
'27905', '27931', '27936', '27970', '27976', '28078', '28080', '28081', '28082', '28113', '28114', '28124', '28150', '28151', '28176', '28200', '28210', '28211', '28217', '28227', '28230', '28249', '28254', '28258', '28286', '28332', '28358', '28370', '28373', '28397', '28412', '28427', '28429', '28485', '28519', '28523', '28534', '28559', '28568', '28578', '28602', '28603', '28604', '28605', '28622', '28623', '28624', '28702', '28718', '28787', '28788', '28789', '28827', '28828', '28829', '28891', '28930', '28937', '28977', '29012', '29013', '29049', '29052', '29094', '29114', '29164', '29169', '29175', '29179', '29191', '29251', '29279', '29282', '29302', '29332', '29397', '29398', '29476', '29523', '29536', '29537', '29538', '29539', '29549', '29565', '29613', '29651', '29755', '29758', '29772', '29775', '29781', '29783', '29800', '29802', '29805', '29823', '29825', '29826', '29843', '29844', '29857', '29887', '29900', '29918', '29922', '29923', '29924', '29925', '29933', '29937', '29943', '29950', '29956', '29985', '29986', '30008', '30009', '30013', '30021', '30046', '30067', '30117', '30144', '30195', '30196', '30211', '30226', '30268', '30317', '30363', '30395', '30423', '30484', '30512', '30513', '30514', '30515', '30516', '30517', '30518', '30519', '30520', '30521', '30553', '30582', '30587', '30602', '30603', '30604', '30633', '30638', '30654', '30696', '30702', '30703', '30779', '30781', '30782', '30783', '30784', '30785', '30786', '30787', '30788', '30789', '30831', '30904', '30906', '30912', '30913', '30915', '30917', '30918', '30922', '30923', '30930', '30949', '30977', '30978', '30994', '31022', '31032', '31042', '31054', '31069', '31070', '31117', '31137', '31154', '31176', '31260', '31295', '31432', '31440', '31441', '31474', '31475', '31505', '31559', '31600', '31628', '31634', '31636', '31720', '31725', '31746', '31800', '31805', '31817', '31846', '31863', '31881', '31882', '31883', '31904', '31918', '31931', '31939', '31944', '31947', '31952', 
'31983', '31988', '32056', '32065', '32113', '32138', '32194', '32244', '32297', '32298', '32348', '32349', '32509', '32621', '32640', '32643', '32798', '32808', '32854', '32948', '32952', '32963', '33031', '33032', '33033', '33069', '33083', '33106', '33132', '33157', '33184', '33205', '33229', '33250', '33258', '33315', '33316', '33339', '33378', '33388', '33408', '33414', '33422', '33440', '33454', '33462', '33463', '33489', '33518', '33534', '33539', '33580', '33590', '33595', '33633', '33655', '33675', '33692', '33693', '33699', '33711', '33734', '33738', '33744', '33745', '33746', '33796', '33807', '33808', '33832', '33840', '33851', '33856', '33861', '33869', '33875', '33881', '33882', '33893', '33894', '33895', '34008', '34119', '34204', '34212', '34231', '34264', '34285', '34301', '34332', '34342', '34349', '34351', '34354', '34363', '34364', '34382', '34385', '34389', '34422', '34432', '34435', '34445', '34451', '34456', '34469', '34496', '34516', '34528', '34547', '34559', '34560', '34561', '34564', '34608', '34620', '34633', '34648', '34692', '34693', '34708', '34715', '34783', '34800', '34803', '34841', '34872', '34918', '34976', '34991', '34998', '35005', '35010', '35019', '35028', '35052', '35054', '35064', '35089', '35100', '35103', '35124', '35129', '35134', '35150', '35151', '35163', '35166', '35200', '35209', '35223', '35235', '35262', '35267', '35285', '35286', '35309', '35315', '35319', '35343', '35376', '35390', '35404', '35406', '35444', '35445', '35446', '35476', '35482', '35498', '35525', '35568', '35609', '35643', '35657', '35718', '35732', '35754', '35766', '35798', '35814', '35825', '35827', '35840', '35854', '35871', '35876', '35878', '35880', '35897', '35908', '35919', '35929', '35938', '35951', '35955', '35969', '35971', '35979', '35986', '35988', '36001', '36011', '36025', '36058', '36086', '36207', '36260', '36315', '36316', '36331', '36334', '36352', '36357', '36369', '36371', '36380', '36394', '36413', '36416', '36436', '36437', 
'36447', '36454', '36455', '36460', '36468', '36469', '36477', '36482', '36510', '36528', '36596', '36630', '36635', '36647', '36676', '36704', '36706', '36708', '36729', '36738', '36742', '36777', '36780', '36786', '36806', '36807', '36812', '36814', '36826', '36868', '36894', '36900', '36909', '36910', '36911', '36938', '36941', '36947', '36972', '36978', '36987', '37011', '37040', '37050', '37051', '37052', '37053', '37075', '37076', '37084', '37096', '37133', '37135', '37138', '37183', '37184', '37201', '37248', '37273', '37305', '37306', '37319', '37384', '37401', '37417', '37428', '37468', '37473', '37491', '37506', '37524', '37526', '37534', '37555', '37567', '37569', '37577', '37601', '37604', '37608', '37616', '37617', '37630', '37634', '37654', '37678', '37733', '37760', '37762', '37769', '37771', '37798', '37799', '37805', '37836', '37851', '37874', '37886', '37896', '37914', '37941', '37951', '37971', '37987', '37989', '37990', '38002', '38031', '38034', '38042', '38057', '38058', '38059', '38074', '38089', '38118', '38122', '38150', '38159', '38166', '38210', '38225', '38242', '38243', '38244', '38256', '38257', '38258', '38271', '38313', '38316', '38320', '38335', '38358', '38361', '38362', '38385', '38392', '38400', '38435', '38442', '38493', '38499', '38527', '38543', '38548', '38576', '38581', '38618', '38639', '38643', '38660', '38674', '38696', '38704', '38705', '38706', '38712', '38725', '38729', '38734', '38736', '38737', '38756', '38760', '38778', '38797', '38842', '38849', '38867', '38868', '38883', '38912', '38913', '38925', '38935', '38971', '38974', '38987', '38996', '39016', '39024', '39031', '39071', '39101', '39117', '39153', '39156', '39165', '39171', '39173', '39183', '39201', '39220', '39240', '39241', '39242', '39256', '39311', '39314', '39320', '39325', '39327', '39328', '39331', '39335', '39360', '39363', '39405', '39410', '39429', '39436', '39449', '39481', '39492', '39512', '39513', '39533', '39555', '39589', '39637', '39654', 
'39679', '39687', '39694', '39719', '39738', '39739', '39765', '39774', '39811', '39825', '39835', '39836', '39837', '39876', '39877', '39880', '39884', '39885', '39889', '39901', '39921', '39953', '39976', '40011', '40025', '40049', '40052', '40085', '40086', '40095', '40107', '40172', '40189', '40193', '40194', '40195', '40213', '40239', '40247', '40248', '40272', '40279', '40308', '40374', '40399', '40413', '40422', '40456', '40489', '40496', '40516', '40530', '40551', '40625', '40694', '40695', '40707', '40720', '40750', '40763', '40768', '40778', '40855', '40877', '40902', '40916', '41038', '41039', '41054', '41056', '41065', '41066', '41079', '41080', '41087', '41099', '41100', '41101', '41112', '41113', '41114', '41116', '41121', '41123', '41124', '41147', '41155', '41208', '41211', '41226', '41251', '41307', '41322', '41336', '41385', '41413', '41428', '41486', '41577', '41578', '41614', '41644', '41692', '41731', '41738', '41769', '41815', '41843', '41872', '41903', '41943', '41965', '41967', '41968', '41969', '41970', '41984', '41986', '41991', '41992', '42021', '42036', '42063', '42064', '42088', '42126', '42131', '42141', '42151', '42177', '42186', '42192', '42211', '42256', '42292', '42297', '42298', '42300', '42310', '42377', '42421', '42432', '42463', '42464', '42472', '42487', '42497', '42524', '42525', '42554', '42561', '42563', '42586', '42590', '42592', '42594', '42615', '42624', '42627', '42635', '42636', '42637', '42648', '42659', '42662', '42663', '42675', '42694', '42695', '42711', '42743', '42744', '42765', '42798', '42836', '42852', '42896', '42924', '42939', '42986', '43003', '43004', '43047', '43196', '43233', '43239', '43277', '43291', '43294', '43306', '43307', '43308', '43309', '43310', '43311', '43312', '43313', '43315', '43321', '43389', '43408', '43436', '43441', '43475', '43476', '43501', '43535', '43554', '43561', '43567', '43632', '43652', '43676', '43698', '43712', '43718', '43734', '43748', '43752', '43760', '43761', '43767', 
'43772', '43782', '43784', '43787', '43822', '43823', '43839', '43848', '43849', '43851', '43871', '43889', '43901', '43923', '43924', '43926', '43956', '43960', '43964', '43980', '44017', '44023', '44054', '44068', '44070', '44095', '44098', '44139', '44141', '44142', '44156', '44160', '44161', '44162', '44180', '44181', '44187', '44198', '44199', '44232', '44236', '44242', '44244', '44255', '44260', '44277', '44285', '44300', '44301', '44310', '44323', '44346', '44354', '44356', '44357', '44359', '44373', '44390', '44402', '44403', '44407', '44453', '44467', '44468', '44473', '44483', '44488', '44504', '44519', '44543', '44589', '44613', '44617', '44620', '44634', '44664', '44675', '44676', '44689', '44696', '44697', '44710', '44713', '44715', '44723', '44749', '44756', '44762', '44807', '44812', '44861', '44874', '44877', '44906', '44911', '44931', '44958', '44960', '45012', '45013', '45014', '45015', '45031', '45036', '45058', '45060', '45119', '45121', '45150', '45176', '45207', '45211', '45212', '45213', '45312', '45323', '45374', '45411', '45428', '45437', '45458', '45468', '45501', '45513', '45533', '45550', '45575', '45590', '45607', '45633', '45650', '45655', '45679', '45692', '45704', '45715', '45722', '45753', '45787', '45855', '45864', '45879', '45886', '45892', '45911', '45935', '45947', '45953', '45967', '45969', '45970', '45979', '46003', '46022', '46033', '46065', '46111', '46137', '46142', '46183', '46224', '46245', '46253', '46282', '46285', '46314', '46350', '46357', '46364', '46373', '46387', '46444', '46447', '46461', '46469', '46470', '46490', '46525', '46527', '46541', '46557', '46598', '46604', '46625', '46626', '46627', '46628', '46631', '46632', '46640', '46646', '46673', '46683', '46687', '46718', '46721', '46747', '46765', '46828', '46832', '46837', '46842', '46870', '46876', '46891', '46902', '46907', '46916', '46931', '46932', '46991', '47009', '47042', '47062', '47124', '47131', '47133', '47164', '47171', '47188', '47196', '47207', 
'47255', '47301', '47321', '47324', '47329', '47345', '47350', '47352', '47360', '47395', '47404', '47414', '47459', '47468', '47469', '47473', '47477', '47479', '47514', '47522', '47552', '47623', '47632', '47645', '47680', '47712', '47717', '47720', '47740', '47743', '47756', '47783', '47802', '47846', '47903', '47918', '47919', '47963', '47969', '47982', '47984', '47986', '47987', '48004', '48006', '48011', '48061', '48082', '48135', '48212', '48257', '48259', '48279', '48282', '48359', '48368', '48383', '48401', '48421', '48432', '48462', '48477', '48518', '48519', '48520', '48529', '48581', '48590', '48650', '48683', '48684', '48687', '48781', '48830', '48852', '48855', '48878', '48881', '48933', '49004', '49005', '49168', '49293', '49312', '49355', '49397', '49398', '49399', '49408', '49409', '49441', '49446', '49482', '49502', '49542', '49573', '49574', '49575', '49586', '49619', '49712', '49725', '49743', '49761', '49773', '49783', '49813', '49855', '49878', '49887', '49943', '49977', '50007', '50024', '50069', '50083', '50144', '50154', '50173', '50208', '50211', '50267', '50278', '50340', '50356', '50398', '50435', '50447', '50485', '50513', '50521', '50580', '50593', '50594', '50605', '50633', '50655', '50664', '50708', '50718', '50725', '50743', '50786', '50805', '50838', '50926', '50930', '50974', '50997', '51005', '51012', '51013', '51014', '51023', '51083', '51084', '51120', '51144', '51149', '51156', '51162', '51178', '51179', '51214', '51225', '51227', '51236', '51237', '51266', '51269', '51312', '51313', '51338', '51364', '51372', '51373', '51381', '51405', '51423', '51505', '51515', '51516', '51591', '51606', '51612', '51631', '51632', '51659', '51666', '51703', '51709', '51725', '51787', '51802', '51826', '51837', '51876', '51977', '52006', '52011', '52065', '52123', '52140', '52145', '52282', '52288', '52331', '52332', '52376', '52379', '52380', '52428', '52431', '52443', '52487', '52488', '52489', '52511', '52520', '52565', '52585', '52592', 
'52611', '52629', '52635', '52669', '52707', '52753', '52797', '52829', '52831', '52843', '52880', '52893', '52919', '52927', '52929', '52933', '52950', '52990', '53082', '53147', '53183', '53247', '53279', '53284', '53309', '53310', '53321', '53331', '53374', '53399', '53439', '53497', '53502', '53503', '53523', '53536', '53540', '53564', '53595', '53599', '53634', '53640', '53694', '53710', '53722', '53737', '53749', '53761', '53779', '53805', '53848', '53950', '54002', '54020', '54035', '54087', '54089', '54182', '54183', '54202', '54231', '54308', '54311', '54339', '54387', '54393', '54397', '54419', '54421', '54456', '54467', '54482', '54528', '54551', '54600', '54659', '54723', '54819', '54835', '54873', '54952', '54963', '54972', '54975', '54983', '54997', '55023', '55028', '55072', '55111', '55135', '55136', '55167', '55175', '55233', '55259', '55265', '55367', '55430', '55446', '55456', '55483', '55501', '55517', '55533', '55554', '55569', '55637', '55638', '55639', '55640', '55659', '55700', '55733', '55766', '55836', '55855', '55860', '55869', '55879', '56054', '56072', '56149', '56173', '56212', '56265', '56309', '56327', '56374', '56390', '56393', '56461', '56473', '56474', '56511', '56515', '56518', '56523', '56545', '56558', '56622', '56645', '56646', '56668', '56708', '56801', '56808', '56820', '56859', '56892', '56918', '56919', '56952', '56990', '56997', '57019', '57023', '57075', '57182', '57193', '57204', '57262', '57270', '57360', '57373', '57419', '57420', '57425', '57429', '57430', '57449', '57454', '57461', '57506', '57525', '57547', '57567', '57607', '57655', '57656', '57687', '57712', '57719', '57743', '57745', '57746', '57747', '57762', '57766', '57775', '57788', '57789', '57790', '57824', '57832', '57839', '57856', '57865', '57866', '57870', '57878', '57879', '57880', '57892', '57919', '57964', '57992', '58063', '58069', '58084', '58088', '58090', '58149', '58154', '58170', '58211', '58244', '58251', '58254', '58259', '58260', '58265', 
'58290', '58299', '58309', '58317', '58362', '58366', '58427', '58467', '58476', '58501', '58532', '58661', '58698', '58706', '58801', '58818', '58828', '59073', '59140', '59151', '59327', '59365', '59431', '59442', '59525', '59719', '59781', '59859', '59889', '59905', '59926', '59992', '59996', '60030', '60033', '60052', '60079', '60080', '60106', '60199', '60204', '60223', '60323', '60347', '60354', '60355', '60366', '60383', '60394', '60417', '60450', '60551', '60571', '60589', '60594', '60610', '60616', '60618', '60635', '60656', '60665', '60666', '60680', '60711', '60720', '60738', '60739', '60746', '60752', '60798', '60806', '60810', '60812', '60821', '60827', '60828', '60841', '60847', '60850', '60863', '60882', '60891', '60892', '60896', '60916', '60918', '60924', '60986', '60998', '61008', '61035', '61039', '61059', '61075', '61088', '61091', '61108', '61134', '61160', '61162', '61181', '61218', '61227', '61239', '61248', '61258', '61270', '61274', '61311', '61318', '61351', '61357', '61373', '61383', '61390', '61404', '61408', '61411', '61418', '61458', '61460', '61472', '61485', '61489', '61500', '61524', '61527', '61528', '61538', '61565', '61571', '61602', '61627', '61636', '61639', '61642', '61664', '61666', '61675', '61691', '61697', '61718', '61733', '61738', '61745', '61747', '61768', '61772', '61789', '61793', '61816', '61856', '61868', '61876', '61905', '61920', '61961', '61965', '61970', '61980', '61994', '62007', '62013', '62021', '62024', '62051', '62092', '62100', '62106', '62108', '62114', '62120', '62147', '62179', '62196', '62207', '62215', '62216', '62243', '62248', '62272', '62281', '62289', '62290', '62318', '62337', '62340', '62368', '62393', '62399', '62404', '62405', '62406', '62508', '62581', '62615', '62679', '62753', '62787', '62812', '62874', '62922', '62933', '62960', '63063', '63129', '63137', '63141', '63144', '63147', '63151', '63161', '63167', '63174', '63179', '63185', '63187', '63193', '63201', '63206', '63220', '63222', 
'63244', '63248', '63253', '63259', '63260', '63267', '63271', '63284', '63285', '63296', '63303', '63305', '63319', '63329', '63341', '63349', '63380', '63409', '63435', '63470', '63478', '63496', '63543', '63575', '63576', '63634', '63646', '63647', '63661', '63674', '63679', '63712', '63714', '63726', '63734', '63762', '63768', '63773', '63777', '63794', '63804', '63846', '63881', '63887', '63904', '63905', '63906', '63914', '63915', '63920', '63937', '63951', '63979', '64008', '64023', '64034', '64065', '64066', '64067', '64084', '64086', '64091', '64113', '64116', '64119', '64129', '64130', '64144', '64145', '64153', '64165', '64168', '64202', '64213', '64220', '64222', '64232', '64274', '64289', '64290', '64305', '64322', '64369', '64371', '64423', '64427', '64428', '64453', '64582', '64590', '64598', '64614', '64645', '64663', '64674', '64706', '64719', '64721', '64729', '64737', '64760', '64776', '64787', '64792', '64798', '64805', '64811', '64821', '64830', '64838', '64850', '64853', '64862', '64867', '64886', '64902', '64909', '64912', '64913', '64920', '64935', '64939', '64949', '64952', '64966', '64976', '64983', '64990', '65025', '65031', '65052', '65059', '65082', '65092', '65096', '65105', '65109', '65111', '65118', '65150', '65178', '65183', '65187', '65213', '65219', '65224', '65251', '65263', '65266', '65273', '65275', '65301', '65339', '65349', '65366', '65402', '65403', '65405', '65419', '65420', '65421', '65434', '65448', '65449', '65457', '65491', '65494', '65522', '65530', '65546', '65578', '65595', '65684', '65751', '65845', '65878', '65969', '65989', '65990', '66035', '66117', '66261', '66430', '66480', '66483', '66505', '66557', '66645', '66652', '66664', '66665', '66672', '66674', '66681', '66682', '66695', '66704', '66709', '66715', '66725', '66761', '66771', '66776', '66781', '66787', '66793', '66795', '66805', '66810', '66817', '66827', '66834', '66839', '66845', '66852', '66853', '66878', '66883', '66894', '66897', '66912', '66924', 
'66927', '66929', '66965', '66978', '66980', '66985', '66995', '67010', '67024', '67052', '67074', '67080', '67086', '67094', '67096', '67099', '67102', '67119', '67120', '67129', '67136', '67158', '67163', '67205', '67233', '67243', '67245', '67252', '67253', '67260', '67264', '67273', '67275', '67281', '67283', '67320', '67327', '67382', '67387', '67400', '67402', '67427', '67469', '67620', '67719', '67741', '67779', '67848', '67853', '67860', '67863', '67867', '67884', '67889', '67910', '67924', '67927', '67930', '67932', '67933', '67940', '67963', '67999', '68010', '68036', '68086', '68138', '68199', '68245', '68264', '68265', '68271', '68281', '68286', '68297', '68298', '68303', '68307', '68312', '68314', '68327', '68355', '68356', '68476', '68485', '68487', '68501', '68510', '68516', '68528', '68560', '68581', '68588', '68591', '68595', '68602', '68603', '68606', '68618', '68621', '68622', '68632', '68634', '68636', '68637', '68675', '68701', '68710', '68719', '68728', '68764', '68818', '68821', '68839', '68848', '68856', '68865', '68871', '68872', '68885', '69007', '69031', '69059', '69089', '69098', '69165', '69182', '69189', '69194', '69301', '69437', '69536', '69540', '69541', '69561', '69568', '69583', '69597', '69605', '69618', '69621', '69626', '69631', '69743', '69770', '69794', '69802', '69848', '69860', '69863', '69872', '69887', '69915', '69917', '69918', '69936', '69982', '70042', '70061', '70065', '70071', '70081', '70084', '70088', '70094', '70095', '70102', '70117', '70126', '70144', '70161', '70162', '70186', '70201', '70202', '70241', '70260', '70268', '70299', '70312', '70319', '70327', '70340', '70351', '70354', '70363', '70369', '70396', '70422', '70423', '70504', '70505', '70534', '70538', '70612', '70650', '70661', '70676', '70746', '70747', '70752', '70753', '70754', '70762', '70778', '70789', '70801', '70818', '70821', '70825', '70846', '70878', '70891', '70900', '70904', '70917', '70954', '70956', '70989', '70995', '71022', '71023', 
'71033', '71053', '71054', '71059', '71062', '71093', '71110', '71123', '71143', '71172', '71199', '71204', '71206', '71208', '71225', '71227', '71236', '71251', '71272', '71287', '71293', '71294', '71296', '71298', '71299', '71300', '71331', '71341', '71420', '71429', '71445', '71455', '71457', '71473', '71487', '71499', '71504', '71510', '71553', '71555', '71588', '71635', '71658', '71686', '71715', '71731', '71737', '71738', '71764', '71773', '71778', '71820', '71831', '71832', '71878', '71901', '71911', '71926', '71932', '72023', '72024', '72034', '72035', '72036', '72055', '72071', '72091', '72127', '72145', '72151', '72160', '72169', '72189', '72194', '72205', '72244', '72251', '72252', '72256', '72260', '72366', '72370', '72385', '72393', '72414', '72486', '72534', '72605', '72618', '72619', '72620', '72630', '72637', '72641', '72643', '72671', '72702', '72712', '72726', '72759', '72773', '72784', '72798', '72799', '72802', '72804', '72805', '72808', '72819', '72836', '72844', '72867', '72872', '72885', '72889', '72905', '72912', '72931', '72939', '72954', '72973', '72975', '72978', '72982', '72984', '72990', '72992', '73002', '73007', '73010', '73018', '73022', '73024', '73033', '73039', '73058', '73077', '73097', '73103', '73118', '73164', '73168', '73191', '73199', '73206', '73210', '73232', '73235', '73252', '73279', '73282', '73291', '73294', '73329', '73333', '73348', '73349', '73370', '73384', '73406', '73436', '73443', '73454', '73478', '73481', '73526', '73552', '73562', '73571', '73632', '73633', '73653', '73689', '73690', '73732', '73768', '73831', '73832', '73847', '73864', '73865', '73874', '73884', '73889', '73894', '73904', '73912', '73918', '73928', '73946', '73948', '73967', '73977', '73985', '73998', '73999', '74003', '74034', '74035', '74040', '74052', '74061', '74080', '74081', '74090', '74091', '74092', '74093', '74094', '74106', '74126', '74128', '74131', '74141', '74143', '74157', '74164', '74208', '74209', '74216', '74225', '74228', 
'74285', '74286', '74290', '74327', '74329', '74351', '74364', '74384', '74398', '74404', '74411', '74412', '74420', '74426', '74429', '74434', '74446', '74448', '74453', '74455', '74459', '74465', '74481', '74487', '74488', '74495', '74499', '74506', '74509', '74513', '74562', '74569', '74580', '74599', '74628']
26
+ # Check if the file exists
27
+ if os.path.exists(file_path):
28
+ # Load and return the existing text file
29
+ with open(file_path, 'r', encoding='utf-8') as file:
30
+ return file.read()
31
+
32
+ # Initialize an empty string to hold all the text
33
+ combined_text = ""
34
+
35
+ # Fetch and append each text
36
+ print("total books", len(french_text_ids))
37
+ i = 0
38
+ for text_id in french_text_ids:
39
+ i += 1
40
+ if i % 15 == 0:
41
+ print(i)
42
+ text = fetch_gutenberg_text(text_id)
43
+ if text:
44
+ combined_text += text + "\n\n" # Separate texts with double newlines
45
+
46
+ # Clean the text before returning
47
+ full_text = clean_text(combined_text)
48
+
49
+ # Save the combined text to a file
50
+ with open(file_path, 'w', encoding='utf-8') as file:
51
+ file.write(full_text)
52
+
53
+ return full_text
infer.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.utils.data import Dataset, DataLoader
4
+ from preprocess_dataset import preprocess_text
5
+ from torch import Tensor
6
+ from torch.nn import Transformer
7
+ import math
8
+ import bitsandbytes as bnb
9
+ from invariants import get_data_pairs
10
+ from french_dataset import get_full_dataset
11
+ import numpy as np
12
+
13
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MAX_SEQUENCE_LENGTH = 512
# Building the training pairs at import time is expensive on first run;
# get_data_pairs caches its result in data_pairs.pkl afterwards.
training_pairs = get_data_pairs(get_full_dataset())
# Special-token ids: source vocab is 700 wide, target vocab 258 wide.
# NOTE(review): test2.py uses 257/258 for the target-side PAD/BOS while this
# file uses 256/257 — confirm which convention the saved checkpoint expects.
PAD_IDX_SRC, PAD_IDX_TGT, BOS_IDX_SRC, BOS_IDX_TGT = 698, 256, 699, 257
17
+
18
class CipherDataset(Dataset):
    """Fixed-length (cipher-encoding, symbol-index) pairs for training.

    Every sequence gets its BOS token prefixed and is then padded or
    truncated so each item is exactly MAX_SEQUENCE_LENGTH tokens long.
    """

    def __init__(self, pairs):
        # Mirror the original training subset: only the first 3000 pairs.
        subset = pairs[:3000]
        self.encodings = [
            self._fit(entry[0], BOS_IDX_SRC, PAD_IDX_SRC) for entry in subset
        ]
        self.symbol_indices = [
            self._fit(entry[1], BOS_IDX_TGT, PAD_IDX_TGT) for entry in subset
        ]

    @staticmethod
    def _fit(seq, bos, pad):
        # Keep at most MAX_SEQUENCE_LENGTH - 1 payload tokens after BOS,
        # then right-pad so the result is always MAX_SEQUENCE_LENGTH long.
        body = seq[:MAX_SEQUENCE_LENGTH - 1]
        return [bos] + body + [pad] * (MAX_SEQUENCE_LENGTH - 1 - len(body))

    def __len__(self):
        return len(self.encodings)

    def __getitem__(self, idx):
        return torch.tensor(self.encodings[idx]), torch.tensor(self.symbol_indices[idx])
43
+
44
class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position information to token embeddings.

    Implements the Attention-Is-All-You-Need encoding: even embedding
    dimensions carry sin terms, odd dimensions cos terms.
    """

    def __init__(self, emb_size: int, dropout: float, maxlen: int = MAX_SEQUENCE_LENGTH):
        super(PositionalEncoding, self).__init__()
        # den[i] = 1 / 10000^(2i / emb_size): per-dimension wavelengths.
        den = torch.exp(-torch.arange(0, emb_size, 2) * math.log(10000) / emb_size)
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        pos_embedding = torch.zeros((maxlen, emb_size))
        pos_embedding[:, 0::2] = torch.sin(pos * den)
        pos_embedding[:, 1::2] = torch.cos(pos * den)
        # Reshape to (maxlen, 1, emb_size) so the table broadcasts over the
        # batch dimension of sequence-first (S, N, E) inputs.
        pos_embedding = pos_embedding.unsqueeze(0).transpose(0, 1)

        self.dropout = nn.Dropout(dropout)
        # Buffer, not a parameter: saved with the model but never trained.
        self.register_buffer('pos_embedding', pos_embedding)

    def forward(self, token_embedding: Tensor):
        # Truncate the table to the actual sequence length before adding.
        return self.dropout(token_embedding + self.pos_embedding[:token_embedding.size(0), :])
59
+
60
class TokenEmbedding(nn.Module):
    """Token embedding scaled by sqrt(emb_size) plus positional encoding."""

    def __init__(self, vocab_size: int, emb_size):
        super(TokenEmbedding, self).__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        # Positional-encoding dropout of 0.2 is hard-coded here.
        self.pos_encoder = PositionalEncoding(emb_size, 0.2)
        self.emb_size = emb_size

    def forward(self, tokens: Tensor):
        # Scaling by sqrt(emb_size) follows the original Transformer paper.
        return self.pos_encoder(self.embedding(tokens.long()) * math.sqrt(self.emb_size))
69
+
70
class Seq2SeqTransformer(nn.Module):
    """Encoder-decoder transformer mapping cipher tokens to symbol indices.

    Thin wrapper around torch.nn.Transformer with separate source/target
    token embeddings and a linear generator head over the target vocab.
    """

    def __init__(self,
                 num_encoder_layers: int,
                 num_decoder_layers: int,
                 emb_size: int,
                 nhead: int,
                 src_vocab_size: int,
                 tgt_vocab_size: int,
                 dim_feedforward: int = 512,
                 dropout: float = 0.1):
        super(Seq2SeqTransformer, self).__init__()
        self.transformer = Transformer(d_model=emb_size,
                                       nhead=nhead,
                                       num_encoder_layers=num_encoder_layers,
                                       num_decoder_layers=num_decoder_layers,
                                       dim_feedforward=dim_feedforward,
                                       dropout=dropout)
        # Projects decoder hidden states to target-vocabulary logits.
        self.generator = nn.Linear(emb_size, tgt_vocab_size)
        self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size)
        self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)

    def forward(self,
                src: Tensor,
                trg: Tensor,
                src_mask: Tensor,
                tgt_mask: Tensor,
                src_padding_mask: Tensor,
                tgt_padding_mask: Tensor,
                memory_key_padding_mask: Tensor):
        """Run a full encode/decode pass and return vocabulary logits."""
        src_emb = self.src_tok_emb(src)
        tgt_emb = self.tgt_tok_emb(trg)
        # The explicit None fills the (unused) memory_mask positional slot.
        outs = self.transformer(src_emb, tgt_emb, src_mask, tgt_mask, None,
                                src_padding_mask, tgt_padding_mask, memory_key_padding_mask)
        return self.generator(outs)

    def encode(self, src: Tensor, src_mask: Tensor):
        # Encoder-only pass, used once per sentence by greedy decoding.
        return self.transformer.encoder(self.src_tok_emb(src), src_mask)

    def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor):
        # Decoder-only pass; returns hidden states (generator not applied).
        return self.transformer.decoder(self.tgt_tok_emb(tgt), memory, tgt_mask)
110
+
111
def generate_square_subsequent_mask(sz):
    """Return an (sz, sz) float causal mask: 0.0 at and below the diagonal,
    -inf strictly above it, so attention cannot see future positions."""
    blocked = torch.full((sz, sz), float('-inf'), device=DEVICE)
    return torch.triu(blocked, diagonal=1)
115
+
116
+
117
def create_mask(src, tgt):
    """Build the four masks the transformer forward pass needs.

    Returns (src_mask, tgt_mask, src_padding_mask, tgt_padding_mask): the
    encoder side attends everywhere (all-False mask), the decoder uses a
    causal mask, and the padding masks flag PAD positions per batch row.
    """
    src_seq_len = src.shape[0]
    tgt_seq_len = tgt.shape[0]

    tgt_mask = generate_square_subsequent_mask(tgt_seq_len)
    # All-zeros boolean mask: no encoder position is blocked.
    src_mask = torch.zeros((src_seq_len, src_seq_len),device=DEVICE).type(torch.bool)

    # Transpose to (batch, seq) as PyTorch key_padding_mask expects.
    src_padding_mask = (src == PAD_IDX_SRC).transpose(0, 1)
    tgt_padding_mask = (tgt == PAD_IDX_TGT).transpose(0, 1)
    return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask
127
+
128
torch.manual_seed(22330)

# Hyper-parameter configuration. Training-only values (BATCH_SIZE,
# TOTAL_STEPS, GRAD_CLIP) are kept for reference even though this script
# only runs inference.
# Fix: GRAD_CLIP was defined twice with the same value; the duplicate
# assignment is removed.
GRAD_CLIP = 1.0

SRC_VOCAB_SIZE = 700
TGT_VOCAB_SIZE = 258
EMB_SIZE = 512
NHEAD = 16
FFN_HID_DIM = 768
BATCH_SIZE = 32
NUM_ENCODER_LAYERS = 4
NUM_DECODER_LAYERS = 4
DROPOUT = 0.3  # Increased Dropout
TOTAL_STEPS = 400

# Rebuild the architecture, load the trained weights from cipher.pth, and
# switch to evaluation mode (disables dropout).
model = Seq2SeqTransformer(NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS, EMB_SIZE,
                           NHEAD, SRC_VOCAB_SIZE, TGT_VOCAB_SIZE, FFN_HID_DIM,
                           dropout=DROPOUT)
model.load_state_dict(torch.load("cipher.pth", map_location=DEVICE))
model.to(DEVICE)
model.eval()
150
# function to generate output sequence using greedy algorithm
def greedy_decode(model, src, src_mask, max_len, start_symbol):
    """Greedily decode up to max_len tokens from the encoded source.

    At each step only the single highest-probability next token is kept;
    there is no beam search and no early stop on an end token.
    """
    src = src.to(DEVICE)
    src_mask = src_mask.to(DEVICE)

    memory = model.encode(src, src_mask)
    # Sequence-first layout: ys has shape (generated_len, 1).
    ys = torch.ones(1, 1).fill_(start_symbol).type(torch.long).to(DEVICE)
    for i in range(max_len-1):
        memory = memory.to(DEVICE)
        tgt_mask = (generate_square_subsequent_mask(ys.size(0))
                    .type(torch.bool)).to(DEVICE)
        out = model.decode(ys, memory, tgt_mask)
        out = out.transpose(0, 1)
        # Project only the last time step to vocabulary logits.
        prob = model.generator(out[:, -1])
        _, next_word = torch.max(prob, dim=1)
        next_word = next_word.item()

        # Append the chosen token and feed the grown prefix back in.
        ys = torch.cat([ys,
                        torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=0)
    return ys
170
+
171
from cipher_8bit import load_or_save_symbols, substitution_cipher, encode_text_with_indices
from preprocess_dataset import get_frequency_ranks, get_proximity_array
# actual function to translate input sentence into target language
def translate(model: torch.nn.Module, src_sentence: str):
    """Encipher src_sentence and run the model to recover symbol indices.

    Only the first 20 cipher tokens are decoded; out-of-range outputs are
    mapped to symbol 0 before joining symbols back into text.
    """
    model.eval()
    # Loads the previously saved symbol table (the [] argument is presumably
    # ignored when symbols.pkl already exists — confirm in cipher_8bit).
    symbols = load_or_save_symbols([])
    rule = substitution_cipher(symbols, 1337)
    encodings, _ = encode_text_with_indices(rule, symbols, src_sentence)
    # Sequence-first column vector, truncated to the first 20 tokens.
    encodings = torch.tensor(encodings[:20]).view(-1, 1)
    print(encodings.shape)
    num_tokens = encodings.shape[0]
    src_mask = (torch.zeros(num_tokens, num_tokens)).type(torch.bool)
    # NOTE(review): start_symbol=256 equals PAD_IDX_TGT in this file, while
    # BOS_IDX_TGT is 257 — confirm which index the checkpoint was trained
    # with before trusting the decoded output.
    tgt_tokens = greedy_decode(
        model, encodings, src_mask, max_len=MAX_SEQUENCE_LENGTH, start_symbol=256).flatten()
    # Drop the start token, then clamp special ids (>255) to symbol 0 so
    # the symbol-table lookup below cannot raise IndexError.
    tokens = tgt_tokens.cpu().numpy()[1:]
    for i, x in enumerate(tokens):
        if x > 255:
            tokens[i] = 0
    print(tokens)
    return (''.join([symbols[x] for x in tokens]))
191
+
192
+
193
+ print(translate(model, """Bien sûr ! Voici un texte plus long en français: La beauté de la nature réside dans sa diversité et sa capacité à émerveiller à chaque saison. Chaque paysage, qu’il s’agisse de montagnes imposantes, de forêts mystérieuses ou de rivières sinueuses, possède une âme et une histoire à raconter. Lorsqu’on s’aventure au cœur d’une forêt, l’air frais, imprégné des parfums de bois et de terre humide, nous invite à ralentir et à savourer l’instant. Les feuilles qui bruissent sous nos pas, les oiseaux qui chantent et la lumière qui filtre à travers les arbres créent une atmosphère presque magique, propice à la contemplation. Au printemps, la nature se réveille lentement, offrant un spectacle de couleurs éclatantes : des fleurs qui éclorent, des bourgeons qui germent, des champs qui se parent de verts éclatants. L'été, quant à lui, emporte tout dans un tourbillon de chaleur, de lumière et de vie. Les journées longues et ensoleillées sont idéales pour profiter des bienfaits du plein air, que ce soit à la mer, à la montagne ou simplement dans le jardin. L’automne, avec ses nuances orangées et dorées, est une invitation à la réflexion et à la tranquillité. Les feuilles tombent en tourbillonnant, créant des tapis colorés qui habillent la terre. Puis vient l’hiver, avec son froid piquant et la neige qui transforme le monde en un paysage féerique, silencieux et apaisant. Au-delà de la beauté visuelle, la nature nous enseigne aussi l’humilité et la résilience. Elle nous rappelle que tout est en perpétuel mouvement et que chaque cycle a sa raison d’être. Nous, humains, sommes un petit maillon dans cet écosystème complexe, et il est de notre responsabilité de préserver ce précieux équilibre. En prenant soin de notre environnement, nous garantissons non seulement la survie des espèces, mais aussi notre propre bien-être. La nature, avec sa sagesse silencieuse, continue de nous offrir des leçons de vie, chaque jour, sous nos yeux émerveillés."""))
invariants.py ADDED
@@ -0,0 +1,227 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from cipher_8bit import *
2
+ from french_dataset import get_full_dataset
3
+ import json
4
+ import pickle
5
+
6
def get_pattern_ranks(pattern_frequency_dict):
    """Map each pattern key to its 1-based rank by descending frequency.

    Ties keep the dictionary's iteration order (sorted is stable), and
    every key receives a distinct rank.
    """
    by_frequency = sorted(pattern_frequency_dict.items(),
                          key=lambda item: item[1], reverse=True)
    return {key: rank for rank, (key, _) in enumerate(by_frequency, start=1)}
18
+
19
def unique_pattern_identifiers(symbol_index_sequences,save=False, name="data"):
    """Rank repeated sub-patterns across sequences by total occurrence count.

    Returns {pattern_key: rank}, where pattern_key is the pattern's values
    joined with '-'.  When save is True, <name>.json doubles as a cache: an
    existing file is loaded instead of recomputing, otherwise the fresh
    frequency dict is written there.
    """
    freq_dict = {}
    # i is a per-sequence counter; its value is never read.
    i = 0
    # NOTE(review): the cache is only consulted when save=True, so the flag
    # effectively means "persist and reuse", not just "save" — confirm intent.
    if os.path.exists(name+".json") and save:
        # Load the existing JSON data into a dictionary
        with open(name+".json", "r") as json_file:
            loaded_dict = json.load(json_file)
        return get_pattern_ranks(loaded_dict)

    for sequence in symbol_index_sequences:
        # remove_subsets=False: count every occurrence, including ones that
        # sit inside longer repeating patterns.
        raw_pattern_data = find_patterns_and_indices(sequence, remove_subsets=False)
        for pattern in raw_pattern_data:
            key = "-".join(map(str, pattern[0]))
            if key in freq_dict:
                freq_dict[key] += len(pattern[1])
            else:
                freq_dict[key] = len(pattern[1])
        i += 1
    if save:
        with open(name+".json", "w") as json_file:
            json.dump(freq_dict, json_file, indent=4)

    return get_pattern_ranks(freq_dict)
42
+
43
def get_data_pairs(full_text):
    """Build (invariant-encoded chunk, symbol-index) training pairs.

    Results are cached in data_pairs.pkl; delete that file to rebuild.
    """
    if os.path.exists("data_pairs.pkl"):
        with open("data_pairs.pkl", 'rb') as f:
            print("Loading training pairs from pickle file...")
            return pickle.load(f)
    text_chunks = []
    chunk_len = 1500
    i=0
    print("starting chunking text")
    # Slice the corpus into fixed-size chunks (the trailing -1 drops the
    # last character of every slice).
    while i * chunk_len < len(full_text) - chunk_len - 1:
        i += 1
        sample_text = full_text[(i - 1) * chunk_len: i * chunk_len - 1]
        text_chunks.append(sample_text)
    symbol_index_sequences = []
    # Fixed symbol inventory: single characters seen in the corpus plus the
    # most frequent character bigrams (256 entries total).
    symbols = ['b', 'j', '\r', 'J', '”', ')', 'Â', 'É', 'ê', '5', 't', '9', 'Y', '%', 'N', 'B', 'V', '\ufeff', 'Ê', '?', '’', 'i', ':', 's', 'C', 'â', 'ï', 'W', 'y', 'p', 'D', '—', '«', 'º', 'A', '3', 'n', '0', 'q', '4', 'e', 'T', 'È', '$', 'U', 'v', '»', 'l', 'P', 'X', 'Z', 'À', 'ç', 'u', '…', 'î', 'L', 'k', 'E', 'R', '2', '_', '8', 'é', 'O', 'Î', '‘', 'a', 'F', 'H', 'c', '[', '(', "'", 'è', 'I', '/', '!', ' ', '°', 'S', '•', '#', 'x', 'à', 'g', '*', 'Q', 'w', '1', 'û', '7', 'G', 'm', '™', 'K', 'z', '\n', 'o', 'ù', ',', 'r', ']', '.', 'M', 'Ç', '“', 'h', '-', 'f', 'ë', '6', ';', 'd', 'ô', 'e ', 's ', 't ', 'es', ' d', '\r\n', 'en', 'qu', ' l', 're', ' p', 'de', 'le', 'nt', 'on', ' c', ', ', ' e', 'ou', ' q', ' s', 'n ', 'ue', 'an', 'te', ' a', 'ai', 'se', 'it', 'me', 'is', 'oi', 'r ', 'er', ' m', 'ce', 'ne', 'et', 'in', 'ns', ' n', 'ur', 'i ', 'a ', 'eu', 'co', 'tr', 'la', 'ar', 'ie', 'ui', 'us', 'ut', 'il', ' t', 'pa', 'au', 'el', 'ti', 'st', 'un', 'em', 'ra', 'e,', 'so', 'or', 'l ', ' f', 'll', 'nd', ' j', 'si', 'ir', 'e\r', 'ss', 'u ', 'po', 'ro', 'ri', 'pr', 's,', 'ma', ' v', ' i', 'di', ' r', 'vo', 'pe', 'to', 'ch', '. ', 've', 'nc', 'om', ' o', 'je', 'no', 'rt', 'à ', 'lu', "'e", 'mo', 'ta', 'as', 'at', 'io', 's\r', 'sa', "u'", 'av', 'os', ' à', ' u', "l'", "'a", 'rs', 'pl', 'é ', '; ', 'ho', 'té', 'ét', 'fa', 'da', 'li', 'su', 't\r', 'ée', 'ré', 'dé', 'ec', 'nn', 'mm', "'i", 'ca', 'uv', '\n\r', 'id', ' b', 'ni', 'bl']
    symbols = load_or_save_symbols(symbols)
    substitution_rule = substitution_cipher(symbols, 1337)

    def invariate_sequence(sample, ids, vocab_size):
        """Re-express a cipher sequence via its repetition structure.

        Per covered position run the output contains: a length marker (for
        patterns longer than one element) followed by the pattern's rank id
        offset by +5; a 0 for symbols occurring exactly once; and a 1 for
        positions claimed by no pattern.
        """
        fill_in = []
        p = find_patterns_and_indices(sample)
        u = find_unique_singles(sample)
        for pattern in p:
            value = "-".join(map(str, pattern[0]))
            for index in pattern[1]:
                # [start position, pattern rank id, pattern length]
                fill_in.append([index, ids[value], len(pattern[0])])
        for unique in u:
            # Length 0 marks a symbol that occurs exactly once.
            fill_in.append([unique[1][0], vocab_size + 1, 0])
        fill_in.sort(key=lambda x: x[0])
        total_list = [0] * 1024
        i=0
        tally = 0
        pattern_count = 0
        while pattern_count < len(fill_in):
            # Position not claimed by any entry: encode as 1 and move on.
            if tally != fill_in[pattern_count][0]:
                total_list[i] = 1
                i+=1
                tally+=1
                continue
            # Unique single: encode as 0.
            if fill_in[pattern_count][2] == 0:
                total_list[i] = 0
                i+=1
                tally+=1
                pattern_count +=1
                continue

            if fill_in[pattern_count][2] == 1:
                # NOTE(review): writes the id at i+1 but advances i by only 1,
                # leaving index i at its 0 default and letting the final
                # total_list[:i] truncation drop the last written id — confirm
                # whether total_list[i] was intended here.
                total_list[i+1] = fill_in[pattern_count][1] + 5
                i+=1
            else:
                # Multi-element pattern: emit its length, then its id.
                total_list[i] = fill_in[pattern_count][2]
                total_list[i+1] = fill_in[pattern_count][1] + 5
                i+=2

            tally += fill_in[pattern_count][2]
            pattern_count += 1

        total_list = total_list[:i]
        return total_list

    dataset = []
    print(len(text_chunks))
    i5 = 0
    for text_i in range(len(text_chunks)):
        i5 += 1
        # Progress indicator every 100 chunks.
        if i5 % 100 == 0:
            print(i5)
        sample_encodings, sample_indices = encode_text_with_indices(substitution_rule, symbols, text_chunks[text_i])
        sample_encodings = sample_encodings[:512]
        sample_indices = sample_indices[:512]
        # Skip chunks containing index 256 — presumably an out-of-table
        # placeholder; verify against encode_text_with_indices.
        if sample_indices.count(256) > 0:
            continue

        # Rank ids are computed per chunk, so the same pattern can map to
        # different ids in different chunks.
        encodings_identifiers = unique_pattern_identifiers([sample_encodings], False)
        encodings_vocab = len(encodings_identifiers.items())

        encoding_list = invariate_sequence(sample_encodings, encodings_identifiers, encodings_vocab)
        dataset.append([encoding_list, sample_indices])
    with open("data_pairs.pkl", 'wb') as f:
        pickle.dump(dataset, f)
    return dataset
124
+
125
+
126
+
127
def filter_subset_pairs(pairs):
    """Drop every pair strictly contained in some other pair.

    A pair (A, B) is dropped when another pair (A', B') satisfies
    set(A) <= set(A') and set(B) <= set(B') with at least one of the two
    containments strict.
    """

    def _strictly_contained(small, big):
        a, b = set(small[0]), set(small[1])
        c, d = set(big[0]), set(big[1])
        return (a.issubset(c) and b.issubset(d) and
                (len(a) < len(c) or len(b) < len(d)))

    kept = pairs.copy()
    # Scan backwards so popping never disturbs indices not yet visited.
    for idx in range(len(kept) - 1, -1, -1):
        if any(pos != idx and _strictly_contained(kept[idx], other)
               for pos, other in enumerate(kept)):
            kept.pop(idx)
    return kept
149
+
150
def find_patterns_and_indices(sequence, remove_subsets=True):
    """
    Find all repeating subsequences in a sequence and their indices.
    Excludes indices of subsequences when they are part of a larger repeating subsequence.

    Args:
        sequence (list): Input sequence of numbers
        remove_subsets (bool): When True, occurrences overlapping positions
            already claimed by a longer pattern are dropped; when False,
            every occurrence of every repeating subsequence is reported.

    Returns:
        list: List of [subsequence, indices] pairs for repeating subsequences
    """
    n = len(sequence)
    result = []

    # Helper function to convert list to tuple for hashability
    def to_tuple(lst):
        return tuple(lst)

    # Collect every subsequence of length 1 through 4 with its start indices.
    # (The original comment claimed length started at 2; the loop below
    # actually includes single elements.)
    subsequence_indices = {}
    for length in range(1, 5):
        for i in range(n - length + 1):
            subseq = to_tuple(sequence[i:i + length])
            if subseq not in subsequence_indices:
                subsequence_indices[subseq] = []
            subsequence_indices[subseq].append(i)

    # Filter out non-repeating subsequences
    repeating_subsequences = {
        subseq: indices
        for subseq, indices in subsequence_indices.items()
        if len(indices) > 1
    }

    # Sort subsequences by length (longest first); sorted() is stable, so
    # same-length patterns keep their discovery order.
    sorted_subsequences = sorted(
        repeating_subsequences.items(),
        key=lambda x: len(x[0]),
        reverse=True
    )

    # Keep track of used indices
    used_indices = set()

    # Process subsequences from longest to shortest
    for subseq, indices in sorted_subsequences:
        # Filter out indices that are already part of longer subsequences
        valid_indices = []
        if remove_subsets:
            for idx in indices:
                # Check if any position in this occurrence overlaps with used indices
                overlap = False
                for pos in range(idx, idx + len(subseq)):
                    if pos in used_indices:
                        overlap = True
                        break
                if not overlap:
                    valid_indices.append(idx)
                    # Mark all positions in this occurrence as used
                    for pos in range(idx, idx + len(subseq)):
                        used_indices.add(pos)
        else:
            valid_indices = indices

        # Only add subsequence if it still has multiple valid occurrences
        if len(valid_indices) > 1:
            result.append([list(subseq), valid_indices])

    return result
219
+
220
def find_unique_singles(sequence):
    """Return [[element], [index]] entries for values occurring exactly once.

    Output order follows each element's position in the sequence, matching
    the [pattern, indices] shape produced by find_patterns_and_indices.
    """
    # Pre-count occurrences in one pass instead of calling sequence.count
    # for every element, which made the original quadratic.
    counts = {}
    for element in sequence:
        counts[element] = counts.get(element, 0) + 1
    return [[[element], [i]]
            for i, element in enumerate(sequence)
            if counts[element] == 1]
preprocess_dataset.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from cipher_8bit import *
2
+ from french_dataset import get_full_dataset
3
+
4
def get_frequency_ranks(encodings, symbols, sequence_len):
    """Per-position occurrence counts of each encoding, padded to a fixed length.

    Counts are taken over the first sequence_len encodings only; positions
    past the end of the input are filled with 0.  (Despite the name, the
    values are raw counts, not ranks.)
    """
    window = encodings[:sequence_len]
    occurrences = [0] * len(symbols)
    for code in window:
        occurrences[code] += 1
    per_position = [occurrences[code] for code in window]
    # Right-pad with zeros so the feature vector always has sequence_len entries.
    return per_position + [0] * (sequence_len - len(window))
14
+
15
def get_proximity_array(encodings, sequence_len):
    """Per-position distance to the previous occurrence of the same encoding.

    distances[i] is i - j where j is the index of the most recent earlier
    occurrence of encodings[i], or 0 when the encoding has not been seen
    before.  Positions past the end of the input stay 0.

    Fixes the original implementation, which used
    ``encodings.index(encoding, 0, i)`` and therefore measured distance to
    the *first* occurrence (despite naming it ``last_idx``), rescanning the
    list for every position (O(n^2)).  This version tracks last-seen
    positions in a dict for a single O(n) pass.
    """
    window = encodings[:sequence_len]
    distances = [0] * sequence_len
    last_seen = {}
    for i, encoding in enumerate(window):
        if encoding in last_seen:
            distances[i] = i - last_seen[encoding]
        last_seen[encoding] = i
    return distances
26
+
27
def preprocess_text(sequence_len=256):
    """Encode the full corpus into per-chunk training features.

    Returns a list of [encodings, distances, indices] triples, one per
    corpus chunk.  NOTE(review): frequency ranks are computed below but not
    included in the returned triples — confirm the omission is intentional.
    """
    full_text = get_full_dataset()
    symbols = get_symbols(full_text, 256)
    symbols = load_or_save_symbols(symbols)
    substitution_rule = substitution_cipher(symbols, 1337)

    i = 0
    raw_length = sequence_len * 2 # Overshoot so when it encodes it takes atleast sequence_len
    processed_data = []
    while i * raw_length < len(full_text) - raw_length:
        i += 1
        sample_text = full_text[(i - 1) * raw_length: i * raw_length - 1]
        encodings_array, indices = encode_text_with_indices(substitution_rule, symbols, sample_text)
        # Truncate both streams to the fixed model sequence length.
        if len(encodings_array) > sequence_len: encodings_array = encodings_array[:sequence_len]
        if len(indices) > sequence_len: indices = indices[:sequence_len]

        ranks = get_frequency_ranks(encodings_array, symbols, sequence_len)  # currently unused in the output
        distances = get_proximity_array(encodings_array, sequence_len)
        processed_data.append([encodings_array, distances, indices])
    return processed_data
47
+
symbols.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b5f430187556b38b219487b3e44a9d89cf84fc324807897b2579bda60a12badf
3
+ size 1230
test2.py ADDED
@@ -0,0 +1,267 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import numpy as np
3
+ import pytorch_lightning as pl
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from preprocess_dataset import preprocess_text
8
+ import bitsandbytes as bnb
9
+ from invariants import get_data_pairs
10
+ from french_dataset import get_full_dataset
11
+ from torch.utils.data import Dataset, DataLoader
12
+ from cipher_8bit import load_or_save_symbols, substitution_cipher, encode_text_with_indices
13
+
14
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
+
16
+ MAX_SEQUENCE_LENGTH = 512
17
+ PAD_IDX_SRC, PAD_IDX_TGT, BOS_IDX_SRC, BOS_IDX_TGT = 698, 257, 699, 258
18
+
19
+ training_pairs = get_data_pairs(get_full_dataset())
20
def create_mask(src):
    """Return the boolean padding mask for a sequence-first src batch.

    Shape (batch, seq): True wherever the token equals PAD_IDX_SRC, the
    layout torch's key_padding_mask arguments expect.
    """
    # The original also built an all-False (seq, seq) attention mask that
    # was never returned or used; that dead computation is removed.
    return (src == PAD_IDX_SRC).transpose(0, 1)
27
class CipherDataset(Dataset):
    """Fixed-length (cipher-encoding, symbol-index) pairs for training.

    Each sequence gets a BOS prefix and is padded or truncated to exactly
    MAX_SEQUENCE_LENGTH tokens.
    """

    def __init__(self, pairs):
        self.encodings = [
            self._fit(entry[0], BOS_IDX_SRC, PAD_IDX_SRC) for entry in pairs
        ]
        self.symbol_indices = [
            self._fit(entry[1], BOS_IDX_TGT, PAD_IDX_TGT) for entry in pairs
        ]

    @staticmethod
    def _fit(seq, bos, pad):
        # Keep at most MAX_SEQUENCE_LENGTH - 1 payload tokens after BOS and
        # right-pad so every item has the same overall length.
        body = seq[:MAX_SEQUENCE_LENGTH - 1]
        return [bos] + body + [pad] * (MAX_SEQUENCE_LENGTH - 1 - len(body))

    def __len__(self):
        return len(self.encodings)

    def __getitem__(self, idx):
        return torch.tensor(self.encodings[idx]), torch.tensor(self.symbol_indices[idx])
53
+
54
+
55
# Wrap data in the simplest possible way to enable PyTorch data fetching
# https://pytorch.org/docs/stable/data.html

BATCH_SIZE = 64
TRAIN_FRAC = 0.995

dataset = CipherDataset(training_pairs)
N = len(training_pairs)
print(N)
# S and C record the sequence length and source vocab size for reference.
S = 512
C = 700
# Split into train and val
num_train = int(N * TRAIN_FRAC)
num_val = N - num_train
# NOTE(review): random_split is not seeded here, so the split can differ
# between runs — confirm that is acceptable for validation metrics.
data_train, data_val = torch.utils.data.random_split(dataset, (num_train, num_val))

dataloader_train = torch.utils.data.DataLoader(data_train, batch_size=BATCH_SIZE)
dataloader_val = torch.utils.data.DataLoader(data_val, batch_size=BATCH_SIZE)

# Sample batch
x, y = next(iter(dataloader_train))
76
class PositionalEncoding(nn.Module):
    """
    Classic Attention-is-all-you-need positional encoding.
    From PyTorch docs.
    """

    def __init__(self, d_model, dropout=0.1, max_len=512):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        # pe[pos, 2i]   = sin(pos / 10000^(2i/d_model))
        # pe[pos, 2i+1] = cos(pos / 10000^(2i/d_model))
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        # Reshape to (max_len, 1, d_model) so the table broadcasts over the
        # batch dimension of sequence-first inputs.
        pe = pe.unsqueeze(0).transpose(0, 1)
        # Buffer, not a parameter: saved with the model but never trained.
        self.register_buffer('pe', pe)

    def forward(self, x):
        # Add the table truncated to the input's sequence length.
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)
97
+
98
+
99
def generate_square_subsequent_mask(size: int):
    """Build a (size, size) causal float mask: 0.0 on/below the diagonal, -inf above.

    Equivalent to the mask helper in the PyTorch Transformer docs.
    """
    causal = torch.full((size, size), float('-inf'))
    # triu with diagonal=1 keeps -inf strictly above the diagonal, zeros the rest.
    return torch.triu(causal, diagonal=1)
104
+
105
+
106
class Transformer(nn.Module):
    """
    Classic encoder-decoder Transformer for sequence-to-sequence decipherment.

    The encoder consumes the cipher-token sequence x; the decoder autoregressively
    produces plaintext-symbol indices y. Prediction-time inference is done greedily
    via predict().

    NOTE(review): the original docstring claimed start token 0 / end token 1, but
    predict() actually seeds with BOS_IDX_TGT — confirm which convention is live.
    NOTE(review): num_classes is accepted but never used; the source vocab size is
    hard-coded to 700 and the target vocab size to 259 below.
    """

    def __init__(self, num_classes: int, max_output_length: int, dim: int = 192):
        super().__init__()

        # Parameters
        self.dim = dim
        self.max_output_length = max_output_length
        nhead = 16
        num_layers = 8
        dim_feedforward = dim

        # Encoder part
        # NOTE(review): 700 / 259 are hard-coded vocab sizes; num_classes is ignored.
        self.x_embedding = nn.Embedding(700, dim)
        self.y_embedding = nn.Embedding(259, dim)
        self.pos_encoder = PositionalEncoding(d_model=self.dim)
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer=nn.TransformerEncoderLayer(d_model=self.dim, nhead=nhead, dim_feedforward=dim_feedforward),
            num_layers=num_layers
        )

        # Decoder part

        # NOTE(review): y_mask is a plain tensor attribute, not register_buffer, so it
        # does not follow .to(device) or appear in state_dict; decode() compensates
        # with type_as at use time.
        self.y_mask = generate_square_subsequent_mask(self.max_output_length)
        self.transformer_decoder = nn.TransformerDecoder(
            decoder_layer=nn.TransformerDecoderLayer(d_model=self.dim, nhead=nhead, dim_feedforward=dim_feedforward),
            num_layers=num_layers
        )
        # Projects decoder states to target-vocab logits (259 classes).
        self.fc = nn.Linear(self.dim, 259)

        # It is empirically important to initialize weights properly
        self.init_weights()

    def init_weights(self):
        # Uniform init for embeddings and the output projection; zero output bias.
        initrange = 0.1
        self.x_embedding.weight.data.uniform_(-initrange, initrange)
        self.y_embedding.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()
        self.fc.weight.data.uniform_(-initrange, initrange)

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """
        Teacher-forced forward pass.

        Input
            x: (B, Sx) with elements in (0, C) where C is num_classes
            y: (B, Sy) with elements in (0, C) where C is num_classes
        Output
            (B, C, Sy) logits — channel-second layout as expected by CrossEntropyLoss
        """
        # create_mask is defined elsewhere in the file; presumably builds a source
        # padding mask from x — verify its shape convention against encode().
        x_pad_mask= create_mask(x)

        encoded_x = self.encode(x, x_pad_mask) # (Sx, B, E)
        output = self.decode(y, encoded_x) # (Sy, B, C)
        return output.permute(1, 2, 0) # (B, C, Sy)

    def encode(self, x: torch.Tensor, x_pad_mask: torch.Tensor) -> torch.Tensor:
        """
        Embed, position-encode, and encode the source sequence.

        Input
            x: (B, Sx) with elements in (0, C) where C is num_classes
            x_pad_mask: padding mask from create_mask — assumed (Sx, B) here since it
                is transposed to (B, Sx) before use; TODO confirm.
        Output
            (Sx, B, E) embedding
        """
        x = x.permute(1, 0) # (Sx, B) — sequence-first layout for the encoder
        # Scale by sqrt(dim) as in the original Transformer paper.
        x = self.x_embedding(x) * math.sqrt(self.dim) # (Sx, B, E)
        x = self.pos_encoder(x) # (Sx, B, E)
        # Positional args: (src, mask=None, src_key_padding_mask=...)
        x = self.transformer_encoder(x, None, x_pad_mask.transpose(0,1)) # (Sx, B, E)
        return x

    def decode(self, y: torch.Tensor, encoded_x: torch.Tensor) -> torch.Tensor:
        """
        Decode target prefix y against encoder memory.

        Input
            encoded_x: (Sx, B, E)
            y: (B, Sy) with elements in (0, C) where C is num_classes
        Output
            (Sy, B, C) logits
        """
        y = y.permute(1, 0) # (Sy, B)
        y = self.y_embedding(y) * math.sqrt(self.dim) # (Sy, B, E)
        y = self.pos_encoder(y) # (Sy, B, E)
        Sy = y.shape[0]
        # Slice the precomputed causal mask to the current prefix length and move it
        # to the working dtype/device (y_mask itself is never moved — see __init__).
        y_mask = self.y_mask[:Sy, :Sy].type_as(encoded_x) # (Sy, Sy)
        output = self.transformer_decoder(y, encoded_x, y_mask) # (Sy, B, E)
        output = self.fc(output) # (Sy, B, C)
        return output

    def predict(self, x: torch.Tensor) -> torch.Tensor:
        """
        Method to use at inference time. Predict y from x one token at a time. This method is greedy
        decoding. Beam search can be used instead for a potential accuracy boost.

        Input
            x: (B, Sx) with elements in (0, C) where C is num_classes
        Output
            (B, max_output_length) predicted token ids (NOT logits)

        NOTE(review): decoding always runs the full max_output_length; there is no
        early stop on an end-of-sequence token. Unwritten positions stay at the
        initial fill value 1.
        """
        x_pad_mask = create_mask(x)

        encoded_x = self.encode(x, x_pad_mask)

        output_tokens = (torch.ones((x.shape[0], self.max_output_length))).type_as(x).long() # (B, max_length)
        output_tokens[:, 0] = BOS_IDX_TGT # Set start token
        for Sy in range(1, self.max_output_length):
            y = output_tokens[:, :Sy] # (B, Sy)
            output = self.decode(y, encoded_x) # (Sy, B, C)
            # Greedy choice per position; only the final position is consumed below.
            output = torch.argmax(output, dim=-1) # (Sy, B)
            output_tokens[:, Sy] = output[-1:] # Set the last output token
        return output_tokens
219
+
220
+
221
class LitModel(pl.LightningModule):
    """PyTorch-Lightning wrapper that trains the Transformer with teacher forcing."""

    def __init__(self, model):
        super().__init__()
        self.model = model
        self.loss = torch.nn.CrossEntropyLoss(label_smoothing=0.2, reduction='mean')

    def _step_loss(self, batch):
        # Teacher forcing: feed the target shifted right (drop last token),
        # score against the target shifted left (drop first token).
        src, tgt = batch
        logits = self.model(src, tgt[:, :-1])
        return self.loss(logits, tgt[:, 1:])

    def training_step(self, batch, batch_ind):
        loss = self._step_loss(batch)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_ind):
        self.log("val_loss", self._step_loss(batch), prog_bar=True)

    def configure_optimizers(self):
        # 8-bit AdamW (bitsandbytes) to reduce optimizer-state memory.
        return bnb.optim.AdamW8bit(self.parameters(), lr=0.0005, betas=(0.9, 0.99), eps=1e-8, weight_decay=0.01)
247
+
248
+
249
# Train for one epoch, then sanity-check that greedy decoding works.

x, y = next(iter(dataloader_val))

model = Transformer(num_classes=700, max_output_length=y.shape[1])
lit_model = LitModel(model)

trainer = pl.Trainer(max_epochs=1)
trainer.fit(lit_model, dataloader_train, dataloader_val)
torch.save(model.state_dict(), "cipher.pth")

print('Input:', x[:1])
pred = lit_model.model.predict(x[:1])
print('Truth/Pred:')

# Row 0 is the ground truth, row 1 the greedy prediction; skip the BOS column.
tokens = torch.cat((y[:1], pred)).cpu().numpy()
symbols = load_or_save_symbols([])
print(tokens)


def _render(row):
    # Map token ids back to text; ids >= 256 (non-symbol tokens) render as '#'.
    return ''.join(symbols[tok] if tok < 256 else "#" for tok in row)


print(_render(tokens[0][1:]))
print(_render(tokens[1][1:]))
267
+