Upload 8 files
- cipher_8bit.py +75 -0
- data_pairs.pkl +3 -0
- french_dataset.py +53 -0
- infer.py +193 -0
- invariants.py +227 -0
- preprocess_dataset.py +47 -0
- symbols.pkl +3 -0
- test2.py +267 -0
cipher_8bit.py
ADDED
@@ -0,0 +1,75 @@
import os
import pickle
import random
from collections import Counter

def cut_string_into_pairs(text_corpus):
    pairs = []
    for i in range(0, len(text_corpus) - 1, 2):
        pairs.append(text_corpus[i:i + 2])
    if len(text_corpus) % 2 != 0:
        pairs.append(text_corpus[-1] + '_')  # Pad a trailing odd character
    return pairs

def get_symbols(text_corpus, max_characters=256):
    # Collect every unique single character, then fill the remaining symbol
    # slots with the most common character pairs
    single_characters = list(set(text_corpus))
    pairs = [item for item, _ in Counter(cut_string_into_pairs(text_corpus)).most_common(max_characters - len(single_characters))]
    return single_characters + pairs

def substitution_cipher(symbols, random_seed):
    random.seed(random_seed)
    # Random permutation of 0..len(symbols)-1 (fits in 8 bits for 256 symbols)
    integer_encodings = random.sample(range(len(symbols)), len(symbols))
    substitution_rule = {}
    # Map every symbol to a unique encoding
    for idx, symbol in enumerate(symbols):
        encoding = integer_encodings[idx]  # The random encoding for this symbol
        substitution_rule[symbol] = encoding  # Symbol is the key, encoding the value
    return substitution_rule

def encode_text_with_indices(rule, symbols, text):
    encoded_text = []
    indices = []
    i = 0

    # Reverse mapping from symbols to their indices in the symbol list
    index_dict = dict(zip(symbols, range(len(symbols))))

    while i < len(text):
        # Prefer two-character symbols when the pair exists in the rule
        if i + 1 < len(text):
            pair = text[i] + text[i + 1]
            if pair in rule:
                encoded_text.append(rule[pair])
                indices.append(index_dict[pair])
                i += 2  # Skip the two characters used in the pair
                continue

        # Single-character substitution
        if text[i] in rule:
            encoded_text.append(rule[text[i]])
            indices.append(index_dict[text[i]])
        else:
            # Character not covered by the rule: emit 256, one past the
            # largest 8-bit encoding, as an out-of-vocabulary marker
            encoded_text.append(256)
            indices.append(256)

        i += 1

    return encoded_text, indices

# Load the symbols from a pickle file if present; otherwise save the given list
def load_or_save_symbols(symbols, pickle_file_path="symbols.pkl"):
    if os.path.exists(pickle_file_path):
        with open(pickle_file_path, 'rb') as f:
            print("Loading symbols from pickle file...")
            return pickle.load(f)
    else:
        print("Pickle file not found. Saving symbols...")
        with open(pickle_file_path, 'wb') as f:
            pickle.dump(symbols, f)
        return symbols
data_pairs.pkl
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f1f11c77420da01df9611b237623d76640a0871a8f2a25def2f614f2b5233ec1
size 1964900268
french_dataset.py
ADDED
@@ -0,0 +1,53 @@
import re
import os
import requests

def clean_text(combined_text):
    # Keep only letters (the À-ÿ range already covers the accented French
    # letters), whitespace, and common punctuation
    allowed_chars_pattern = r"[A-Za-zÀ-ÿ\s.,;!?()\"'\-]"

    # Keep only the characters that match the allowed pattern
    cleaned_text = ''.join(re.findall(allowed_chars_pattern, combined_text))

    return cleaned_text

def fetch_gutenberg_text(book_id):
    url = f"http://www.gutenberg.org/ebooks/{book_id}.txt.utf-8"
    response = requests.get(url)
    if response.status_code == 200:
        return response.text
    else:
        print(f"Failed to fetch book ID {book_id}")
        return ""

# List of French text IDs from Project Gutenberg
french_text_ids = ['796', '797', '798', '799', '800', '801', '802', '803', '1256', '1339', '1910', '2419', '2650', '2682', '2998', '2999', '3000', '3456', '3644', '3645', '4548', '4559', '4561', '4562', '4563', '4564', '4565', '4566', '4567', '4568', '4569', '4570', '4647', '4648', '4649', '4650', '4651', '4688', '4708', '4717', '4718', '4740', '4741', '4771', '4772', '4785', '4791', '4933', '4935', '4936', '4968', '5081', '5082', '5095', '5096', '5097', '5104', '5105', '5126', '5130', '5138', '5147', '5154', '5158', '5178', '5250', '5258', '5318', '5423', '5644', '5711', '5781', '5892', '6099', '6309', '6318', '6319', '6377', '6470', '6484', '6497', '6501', '6558', '6691', '6739', '6838', '6966', '6994', '7012', '7173', '7263', '7268', '7442', '7461', '7462', '7770', '7771', '7772', '7809', '7812', '7818', '7854', '8074', '8173', '8174', '8175', '8186', '8416', '8453', '8454', '8490', '8520', '8524', '8541', '8560', '8561', '8563', '8650', '8692', '8693', '8712', '8719', '8739', '8822', '8863', '8864', '8876', '8907', '8946', '9053', '9261', '9262', '9453', '9637', '9638', '9639', '9643', '9644', '9645', '9655', '9818', '9824', '9891', '9892', '9893', '9894', '9945', '9976', '10053', '10061', '10160', '10263', '10289', '10290', '10346', '10384', '10385', '10442', '10604', '10678', '10680', '10682', '10683', '10685', '10687', '10689', '10697', '10746', '10764', '10768', '10774', '10775', '10824', '10841', '10906', '10953', '10982', '11035', '11036', '11037', '11038', '11040', '11042', '11046', '11048', '11049', '11131', '11132', '11175', '11176', '11178', '11199', '11300', '11301', '11380', '11423', '11434', '11450', '11453', '11484', '11494', '11495', '11586', '11588', '11589', '11590', '11596', '11597', '11621', '11622', '11645', '11646', '11650', '11678', '11714', '11744', '11747', '11766', '11767', '11769', '11770', '11893', '11905', '11927', '11928', '11964', '12005', '12011', '12065', '12072', '12080', '12105', '12120', '12174', '12230', '12246', '12247', '12250', '12251', '12258', '12271', '12284', '12289', '12295', '12301', '12331', '12332', '12338', '12356', '12365', '12367', '12399', '12401', '12437', '12447', '12448', '12451', '12459', '12487', '12488', '12489', '12533', '12534', '12562', '12566', '12602', '12603', '12620', '12646', '12665', '12666', '12726', '12727', '12749', '12751', '12752', '12782', '12783', '12812', '12829', '12837', '12862', '12865', '12869', '12889', '12893', '12949', '12950', '12969', '12979', '12993', '12999', '13013', '13016', '13024', '13025', '13027', '13036', '13038', '13059', '13070', '13096', '13122', '13149', '13187', '13189', '13190', '13192', '13198', '13207', '13219', '13221', '13230', '13231', '13247', '13256', '13258', '13263', '13284', '13299', '13303', '13336', '13339', '13374', '13380', '13383', '13385', '13400', '13431', '13456', '13475', '13478', '13490', '13523', '13524', '13525', '13557', '13562', '13592', '13594', '13598', '13607', '13622', '13628', '13629', '13653', '13654', '13668', '13671', '13676', '13703', '13704', '13727', '13734', '13735', '13737', '13743', '13744', '13765', '13771', '13772', '13792', '13793', '13794', '13795', '13798', '13804', '13807', '13808', '13818', '13819', '13825', '13834', '13837', '13838', '13839', '13846', '13848', '13855', '13856', '13857', '13861', '13862', '13863', '13866', '13868', '13875', '13892', '13914', '13917', '13938', '13947', '13948', '13949', '13950', '13951', '13952', '13965', '13981', '14030', '14038', '14059', '14069', '14071', '14113', '14151', '14155', '14156', '14157', '14158', 
'14159', '14162', '14163', '14192', '14247', '14251', '14258', '14259', '14285', '14286', '14287', '14288', '14309', '14310', '14343', '14372', '14397', '14398', '14399', '14404', '14512', '14536', '14537', '14538', '14539', '14541', '14564', '14609', '14677', '14683', '14688', '14692', '14693', '14702', '14703', '14704', '14705', '14713', '14720', '14751', '14788', '14789', '14790', '14791', '14792', '14793', '14799', '14803', '14804', '14805', '14806', '14810', '14820', '14827', '14828', '14905', '14911', '14912', '14913', '14918', '15032', '15057', '15058', '15059', '15060', '15071', '15075', '15107', '15112', '15113', '15146', '15150', '15152', '15203', '15208', '15212', '15226', '15235', '15239', '15286', '15287', '15288', '15295', '15296', '15297', '15303', '15305', '15312', '15324', '15361', '15371', '15372', '15375', '15388', '15397', '15433', '15434', '15458', '15459', '15462', '15463', '15543', '15554', '15555', '15556', '15557', '15558', '15574', '15579', '15584', '15589', '15593', '15598', '15626', '15635', '15642', '15645', '15646', '15686', '15732', '15739', '15804', '15805', '15811', '15815', '15816', '15823', '15844', '15846', '15847', '15848', '15849', '15871', '15882', '15885', '15907', '15942', '15943', '16020', '16021', '16022', '16023', '16066', '16067', '16128', '16210', '16234', '16235', '16236', '16237', '16238', '16239', '16240', '16260', '16286', '16336', '16388', '16421', '16465', '16492', '16499', '16649', '16709', '16710', '16743', '16758', '16789', '16795', '16796', '16812', '16813', '16814', '16815', '16816', '16817', '16818', '16819', '16820', '16824', '16825', '16826', '16827', '16828', '16848', '16849', '16850', '16851', '16852', '16862', '16874', '16875', '16876', '16883', '16884', '16885', '16886', '16887', '16888', '16901', '16934', '16988', '16989', '17004', '17010', '17044', '17098', '17105', '17106', '17123', '17140', '17184', '17225', '17230', '17231', '17232', '17233', '17234', '17235', '17236', '17238', '17240', '17242', '17248', '17251', '17252', '17258', '17261', '17264', '17267', '17271', '17281', '17285', '17298', '17311', '17319', '17335', '17344', '17345', '17353', '17360', '17363', '17372', '17399', '17419', '17420', '17457', '17458', '17459', '17489', '17493', '17494', '17501', '17505', '17509', '17516', '17517', '17518', '17519', '17533', '17538', '17540', '17541', '17542', '17543', '17550', '17551', '17552', '17553', '17555', '17557', '17561', '17565', '17573', '17577', '17578', '17589', '17590', '17602', '17605', '17623', '17631', '17632', '17640', '17641', '17643', '17646', '17656', '17660', '17661', '17662', '17668', '17670', '17673', '17675', '17676', '17688', '17691', '17692', '17693', '17696', '17707', '17708', '17709', '17714', '17715', '17716', '17717', '17734', '17736', '17738', '17739', '17746', '17747', '17752', '17757', '17758', '17791', '17794', '17795', '17796', '17798', '17808', '17809', '17810', '17828', '17830', '17831', '17832', '17840', '17868', '17869', '17879', '17880', '17899', '17911', '17914', '17915', '17916', '17930', '17940', '17941', '17942', '17947', '17949', '17950', '17951', '17963', '17980', '17983', '17984', '17989', '17990', '17991', '17992', '18003', '18006', '18014', '18015', '18024', '18027', '18028', '18029', '18034', '18055', '18059', '18061', '18064', '18067', '18073', '18074', '18075', '18081', '18083', '18084', '18085', '18089', '18090', '18092', '18106', '18108', '18111', '18112', '18121', '18123', '18133', '18142', '18143', '18152', '18159', '18162', '18169', '18179', '18197', '18199', '18200', 
'18205', '18208', '18211', '18215', '18244', '18245', '18262', '18263', '18271', '18294', '18295', '18296', '18302', '18311', '18312', '18313', '18321', '18340', '18353', '18358', '18367', '18368', '18401', '18402', '18403', '18404', '18407', '18415', '18416', '18427', '18446', '18454', '18455', '18490', '18491', '18494', '18518', '18535', '18537', '18543', '18583', '18585', '18586', '18610', '18611', '18623', '18627', '18669', '18672', '18692', '18693', '18695', '18697', '18715', '18716', '18717', '18718', '18724', '18727', '18738', '18749', '18771', '18773', '18797', '18806', '18812', '18825', '18826', '18849', '18864', '18865', '18889', '18890', '18899', '18918', '18919', '18920', '18921', '18922', '18923', '18924', '18925', '18940', '18942', '18944', '18962', '18983', '18995', '18996', '19008', '19021', '19035', '19045', '19075', '19112', '19124', '19149', '19152', '19184', '19186', '19187', '19201', '19219', '19227', '19228', '19232', '19233', '19234', '19248', '19249', '19266', '19344', '19345', '19431', '19440', '19454', '19455', '19483', '19497', '19519', '19536', '19540', '19588', '19604', '19631', '19657', '19662', '19689', '19700', '19738', '19756', '19820', '19842', '19854', '19862', '19919', '19920', '19954', '19956', '19972', '19982', '19984', '19992', '20013', '20077', '20079', '20108', '20143', '20199', '20234', '20244', '20246', '20254', '20262', '20325', '20372', '20396', '20398', '20414', '20415', '20440', '20441', '20457', '20479', '20490', '20498', '20507', '20554', '20562', '20564', '20568', '20577', '20623', '20635', '20640', '20664', '20690', '20700', '20703', '20705', '20720', '20761', '20773', '20790', '20823', '20824', '20825', '20829', '20864', '20865', '20886', '20894', '20895', '20949', '20950', '20964', '20966', '20971', '20972', '20973', '20974', '21001', '21013', '21017', '21023', '21124', '21191', '21199', '21215', '21221', '21257', '21277', '21343', '21413', '21544', '21669', '21792', '21804', '21825', '21856', '21896', '21912', '21940', '21966', '22007', '22011', '22016', '22039', '22048', '22054', '22068', '22077', '22078', '22111', '22192', '22253', '22262', '22266', '22268', '22356', '22383', '22384', '22385', '22386', '22388', '22393', '22394', '22416', '22429', '22543', '22548', '22551', '22552', '22558', '22572', '22575', '22613', '22618', '22633', '22741', '22751', '22760', '22768', '22769', '22780', '22813', '22830', '22889', '22917', '22918', '22971', '23019', '23020', '23047', '23098', '23158', '23199', '23202', '23211', '23279', '23285', '23289', '23423', '23444', '23463', '23484', '23508', '23520', '23566', '23567', '23578', '23582', '23583', '23589', '23596', '23610', '23615', '23616', '23618', '23654', '23801', '23828', '23830', '23848', '23917', '23939', '23940', '23953', '23954', '24007', '24045', '24081', '24123', '24217', '24243', '24255', '24257', '24260', '24300', '24305', '24320', '24325', '24369', '24383', '24424', '24490', '24515', '24546', '24549', '24555', '24573', '24636', '24766', '24768', '24809', '24850', '24861', '24867', '24888', '24908', '24915', '24924', '24960', '24962', '25036', '25039', '25097', '25149', '25227', '25276', '25310', '25335', '25364', '25370', '25378', '25382', '25394', '25403', '25434', '25435', '25503', '25526', '25575', '25576', '25615', '25616', '25680', '25694', '25704', '25707', '25715', '25734', '25752', '25755', '25756', '25839', '25850', '25863', '25949', '25981', '26082', '26091', '26092', '26101', '26118', '26124', '26211', '26296', '26336', '26350', '26351', '26352', '26353', '26362', '26363', 
'26370', '26375', '26376', '26394', '26400', '26415', '26418', '26432', '26435', '26436', '26456', '26476', '26488', '26489', '26504', '26510', '26511', '26515', '26531', '26562', '26566', '26567', '26571', '26607', '26608', '26609', '26614', '26634', '26648', '26680', '26681', '26685', '26710', '26712', '26720', '26721', '26749', '26757', '26758', '26759', '26762', '26763', '26764', '26765', '26766', '26780', '26804', '26806', '26807', '26808', '26809', '26810', '26811', '26812', '26813', '26814', '26815', '26816', '26817', '26818', '26819', '26820', '26821', '26822', '26823', '26824', '26825', '26826', '26827', '26859', '26863', '26868', '26891', '26894', '26943', '27018', '27029', '27031', '27033', '27036', '27037', '27038', '27040', '27041', '27042', '27043', '27044', '27046', '27047', '27131', '27132', '27133', '27134', '27144', '27186', '27191', '27215', '27267', '27269', '27278', '27281', '27282', '27283', '27296', '27303', '27304', '27308', '27313', '27314', '27345', '27379', '27380', '27381', '27427', '27451', '27566', '27573', '27574', '27610', '27617', '27619', '27623', '27625', '27626', '27627', '27641', '27644', '27694', '27772', '27773', '27774', '27782', '27806', '27807', '27808', '27809', '27828', '27831', '27837', '27840', '27843', '27844', '27854', '27855', '27876', '27878', '27904', '27905', '27931', '27936', '27970', '27976', '28078', '28080', '28081', '28082', '28113', '28114', '28124', '28150', '28151', '28176', '28200', '28210', '28211', '28217', '28227', '28230', '28249', '28254', '28258', '28286', '28332', '28358', '28370', '28373', '28397', '28412', '28427', '28429', '28485', '28519', '28523', '28534', '28559', '28568', '28578', '28602', '28603', '28604', '28605', '28622', '28623', '28624', '28702', '28718', '28787', '28788', '28789', '28827', '28828', '28829', '28891', '28930', '28937', '28977', '29012', '29013', '29049', '29052', '29094', '29114', '29164', '29169', '29175', '29179', '29191', '29251', '29279', '29282', '29302', '29332', '29397', '29398', '29476', '29523', '29536', '29537', '29538', '29539', '29549', '29565', '29613', '29651', '29755', '29758', '29772', '29775', '29781', '29783', '29800', '29802', '29805', '29823', '29825', '29826', '29843', '29844', '29857', '29887', '29900', '29918', '29922', '29923', '29924', '29925', '29933', '29937', '29943', '29950', '29956', '29985', '29986', '30008', '30009', '30013', '30021', '30046', '30067', '30117', '30144', '30195', '30196', '30211', '30226', '30268', '30317', '30363', '30395', '30423', '30484', '30512', '30513', '30514', '30515', '30516', '30517', '30518', '30519', '30520', '30521', '30553', '30582', '30587', '30602', '30603', '30604', '30633', '30638', '30654', '30696', '30702', '30703', '30779', '30781', '30782', '30783', '30784', '30785', '30786', '30787', '30788', '30789', '30831', '30904', '30906', '30912', '30913', '30915', '30917', '30918', '30922', '30923', '30930', '30949', '30977', '30978', '30994', '31022', '31032', '31042', '31054', '31069', '31070', '31117', '31137', '31154', '31176', '31260', '31295', '31432', '31440', '31441', '31474', '31475', '31505', '31559', '31600', '31628', '31634', '31636', '31720', '31725', '31746', '31800', '31805', '31817', '31846', '31863', '31881', '31882', '31883', '31904', '31918', '31931', '31939', '31944', '31947', '31952', '31983', '31988', '32056', '32065', '32113', '32138', '32194', '32244', '32297', '32298', '32348', '32349', '32509', '32621', '32640', '32643', '32798', '32808', '32854', '32948', '32952', '32963', '33031', '33032', '33033', '33069', 
'33083', '33106', '33132', '33157', '33184', '33205', '33229', '33250', '33258', '33315', '33316', '33339', '33378', '33388', '33408', '33414', '33422', '33440', '33454', '33462', '33463', '33489', '33518', '33534', '33539', '33580', '33590', '33595', '33633', '33655', '33675', '33692', '33693', '33699', '33711', '33734', '33738', '33744', '33745', '33746', '33796', '33807', '33808', '33832', '33840', '33851', '33856', '33861', '33869', '33875', '33881', '33882', '33893', '33894', '33895', '34008', '34119', '34204', '34212', '34231', '34264', '34285', '34301', '34332', '34342', '34349', '34351', '34354', '34363', '34364', '34382', '34385', '34389', '34422', '34432', '34435', '34445', '34451', '34456', '34469', '34496', '34516', '34528', '34547', '34559', '34560', '34561', '34564', '34608', '34620', '34633', '34648', '34692', '34693', '34708', '34715', '34783', '34800', '34803', '34841', '34872', '34918', '34976', '34991', '34998', '35005', '35010', '35019', '35028', '35052', '35054', '35064', '35089', '35100', '35103', '35124', '35129', '35134', '35150', '35151', '35163', '35166', '35200', '35209', '35223', '35235', '35262', '35267', '35285', '35286', '35309', '35315', '35319', '35343', '35376', '35390', '35404', '35406', '35444', '35445', '35446', '35476', '35482', '35498', '35525', '35568', '35609', '35643', '35657', '35718', '35732', '35754', '35766', '35798', '35814', '35825', '35827', '35840', '35854', '35871', '35876', '35878', '35880', '35897', '35908', '35919', '35929', '35938', '35951', '35955', '35969', '35971', '35979', '35986', '35988', '36001', '36011', '36025', '36058', '36086', '36207', '36260', '36315', '36316', '36331', '36334', '36352', '36357', '36369', '36371', '36380', '36394', '36413', '36416', '36436', '36437', '36447', '36454', '36455', '36460', '36468', '36469', '36477', '36482', '36510', '36528', '36596', '36630', '36635', '36647', '36676', '36704', '36706', '36708', '36729', '36738', '36742', '36777', '36780', '36786', '36806', '36807', '36812', '36814', '36826', '36868', '36894', '36900', '36909', '36910', '36911', '36938', '36941', '36947', '36972', '36978', '36987', '37011', '37040', '37050', '37051', '37052', '37053', '37075', '37076', '37084', '37096', '37133', '37135', '37138', '37183', '37184', '37201', '37248', '37273', '37305', '37306', '37319', '37384', '37401', '37417', '37428', '37468', '37473', '37491', '37506', '37524', '37526', '37534', '37555', '37567', '37569', '37577', '37601', '37604', '37608', '37616', '37617', '37630', '37634', '37654', '37678', '37733', '37760', '37762', '37769', '37771', '37798', '37799', '37805', '37836', '37851', '37874', '37886', '37896', '37914', '37941', '37951', '37971', '37987', '37989', '37990', '38002', '38031', '38034', '38042', '38057', '38058', '38059', '38074', '38089', '38118', '38122', '38150', '38159', '38166', '38210', '38225', '38242', '38243', '38244', '38256', '38257', '38258', '38271', '38313', '38316', '38320', '38335', '38358', '38361', '38362', '38385', '38392', '38400', '38435', '38442', '38493', '38499', '38527', '38543', '38548', '38576', '38581', '38618', '38639', '38643', '38660', '38674', '38696', '38704', '38705', '38706', '38712', '38725', '38729', '38734', '38736', '38737', '38756', '38760', '38778', '38797', '38842', '38849', '38867', '38868', '38883', '38912', '38913', '38925', '38935', '38971', '38974', '38987', '38996', '39016', '39024', '39031', '39071', '39101', '39117', '39153', '39156', '39165', '39171', '39173', '39183', '39201', '39220', '39240', '39241', '39242', '39256', '39311', 
'39314', '39320', '39325', '39327', '39328', '39331', '39335', '39360', '39363', '39405', '39410', '39429', '39436', '39449', '39481', '39492', '39512', '39513', '39533', '39555', '39589', '39637', '39654', '39679', '39687', '39694', '39719', '39738', '39739', '39765', '39774', '39811', '39825', '39835', '39836', '39837', '39876', '39877', '39880', '39884', '39885', '39889', '39901', '39921', '39953', '39976', '40011', '40025', '40049', '40052', '40085', '40086', '40095', '40107', '40172', '40189', '40193', '40194', '40195', '40213', '40239', '40247', '40248', '40272', '40279', '40308', '40374', '40399', '40413', '40422', '40456', '40489', '40496', '40516', '40530', '40551', '40625', '40694', '40695', '40707', '40720', '40750', '40763', '40768', '40778', '40855', '40877', '40902', '40916', '41038', '41039', '41054', '41056', '41065', '41066', '41079', '41080', '41087', '41099', '41100', '41101', '41112', '41113', '41114', '41116', '41121', '41123', '41124', '41147', '41155', '41208', '41211', '41226', '41251', '41307', '41322', '41336', '41385', '41413', '41428', '41486', '41577', '41578', '41614', '41644', '41692', '41731', '41738', '41769', '41815', '41843', '41872', '41903', '41943', '41965', '41967', '41968', '41969', '41970', '41984', '41986', '41991', '41992', '42021', '42036', '42063', '42064', '42088', '42126', '42131', '42141', '42151', '42177', '42186', '42192', '42211', '42256', '42292', '42297', '42298', '42300', '42310', '42377', '42421', '42432', '42463', '42464', '42472', '42487', '42497', '42524', '42525', '42554', '42561', '42563', '42586', '42590', '42592', '42594', '42615', '42624', '42627', '42635', '42636', '42637', '42648', '42659', '42662', '42663', '42675', '42694', '42695', '42711', '42743', '42744', '42765', '42798', '42836', '42852', '42896', '42924', '42939', '42986', '43003', '43004', '43047', '43196', '43233', '43239', '43277', '43291', '43294', '43306', '43307', '43308', '43309', '43310', '43311', '43312', '43313', '43315', '43321', '43389', '43408', '43436', '43441', '43475', '43476', '43501', '43535', '43554', '43561', '43567', '43632', '43652', '43676', '43698', '43712', '43718', '43734', '43748', '43752', '43760', '43761', '43767', '43772', '43782', '43784', '43787', '43822', '43823', '43839', '43848', '43849', '43851', '43871', '43889', '43901', '43923', '43924', '43926', '43956', '43960', '43964', '43980', '44017', '44023', '44054', '44068', '44070', '44095', '44098', '44139', '44141', '44142', '44156', '44160', '44161', '44162', '44180', '44181', '44187', '44198', '44199', '44232', '44236', '44242', '44244', '44255', '44260', '44277', '44285', '44300', '44301', '44310', '44323', '44346', '44354', '44356', '44357', '44359', '44373', '44390', '44402', '44403', '44407', '44453', '44467', '44468', '44473', '44483', '44488', '44504', '44519', '44543', '44589', '44613', '44617', '44620', '44634', '44664', '44675', '44676', '44689', '44696', '44697', '44710', '44713', '44715', '44723', '44749', '44756', '44762', '44807', '44812', '44861', '44874', '44877', '44906', '44911', '44931', '44958', '44960', '45012', '45013', '45014', '45015', '45031', '45036', '45058', '45060', '45119', '45121', '45150', '45176', '45207', '45211', '45212', '45213', '45312', '45323', '45374', '45411', '45428', '45437', '45458', '45468', '45501', '45513', '45533', '45550', '45575', '45590', '45607', '45633', '45650', '45655', '45679', '45692', '45704', '45715', '45722', '45753', '45787', '45855', '45864', '45879', '45886', '45892', '45911', '45935', '45947', '45953', '45967', '45969', 
'45970', '45979', '46003', '46022', '46033', '46065', '46111', '46137', '46142', '46183', '46224', '46245', '46253', '46282', '46285', '46314', '46350', '46357', '46364', '46373', '46387', '46444', '46447', '46461', '46469', '46470', '46490', '46525', '46527', '46541', '46557', '46598', '46604', '46625', '46626', '46627', '46628', '46631', '46632', '46640', '46646', '46673', '46683', '46687', '46718', '46721', '46747', '46765', '46828', '46832', '46837', '46842', '46870', '46876', '46891', '46902', '46907', '46916', '46931', '46932', '46991', '47009', '47042', '47062', '47124', '47131', '47133', '47164', '47171', '47188', '47196', '47207', '47255', '47301', '47321', '47324', '47329', '47345', '47350', '47352', '47360', '47395', '47404', '47414', '47459', '47468', '47469', '47473', '47477', '47479', '47514', '47522', '47552', '47623', '47632', '47645', '47680', '47712', '47717', '47720', '47740', '47743', '47756', '47783', '47802', '47846', '47903', '47918', '47919', '47963', '47969', '47982', '47984', '47986', '47987', '48004', '48006', '48011', '48061', '48082', '48135', '48212', '48257', '48259', '48279', '48282', '48359', '48368', '48383', '48401', '48421', '48432', '48462', '48477', '48518', '48519', '48520', '48529', '48581', '48590', '48650', '48683', '48684', '48687', '48781', '48830', '48852', '48855', '48878', '48881', '48933', '49004', '49005', '49168', '49293', '49312', '49355', '49397', '49398', '49399', '49408', '49409', '49441', '49446', '49482', '49502', '49542', '49573', '49574', '49575', '49586', '49619', '49712', '49725', '49743', '49761', '49773', '49783', '49813', '49855', '49878', '49887', '49943', '49977', '50007', '50024', '50069', '50083', '50144', '50154', '50173', '50208', '50211', '50267', '50278', '50340', '50356', '50398', '50435', '50447', '50485', '50513', '50521', '50580', '50593', '50594', '50605', '50633', '50655', '50664', '50708', '50718', '50725', '50743', '50786', '50805', '50838', '50926', '50930', '50974', '50997', '51005', '51012', '51013', '51014', '51023', '51083', '51084', '51120', '51144', '51149', '51156', '51162', '51178', '51179', '51214', '51225', '51227', '51236', '51237', '51266', '51269', '51312', '51313', '51338', '51364', '51372', '51373', '51381', '51405', '51423', '51505', '51515', '51516', '51591', '51606', '51612', '51631', '51632', '51659', '51666', '51703', '51709', '51725', '51787', '51802', '51826', '51837', '51876', '51977', '52006', '52011', '52065', '52123', '52140', '52145', '52282', '52288', '52331', '52332', '52376', '52379', '52380', '52428', '52431', '52443', '52487', '52488', '52489', '52511', '52520', '52565', '52585', '52592', '52611', '52629', '52635', '52669', '52707', '52753', '52797', '52829', '52831', '52843', '52880', '52893', '52919', '52927', '52929', '52933', '52950', '52990', '53082', '53147', '53183', '53247', '53279', '53284', '53309', '53310', '53321', '53331', '53374', '53399', '53439', '53497', '53502', '53503', '53523', '53536', '53540', '53564', '53595', '53599', '53634', '53640', '53694', '53710', '53722', '53737', '53749', '53761', '53779', '53805', '53848', '53950', '54002', '54020', '54035', '54087', '54089', '54182', '54183', '54202', '54231', '54308', '54311', '54339', '54387', '54393', '54397', '54419', '54421', '54456', '54467', '54482', '54528', '54551', '54600', '54659', '54723', '54819', '54835', '54873', '54952', '54963', '54972', '54975', '54983', '54997', '55023', '55028', '55072', '55111', '55135', '55136', '55167', '55175', '55233', '55259', '55265', '55367', '55430', '55446', '55456', 
'55483', '55501', '55517', '55533', '55554', '55569', '55637', '55638', '55639', '55640', '55659', '55700', '55733', '55766', '55836', '55855', '55860', '55869', '55879', '56054', '56072', '56149', '56173', '56212', '56265', '56309', '56327', '56374', '56390', '56393', '56461', '56473', '56474', '56511', '56515', '56518', '56523', '56545', '56558', '56622', '56645', '56646', '56668', '56708', '56801', '56808', '56820', '56859', '56892', '56918', '56919', '56952', '56990', '56997', '57019', '57023', '57075', '57182', '57193', '57204', '57262', '57270', '57360', '57373', '57419', '57420', '57425', '57429', '57430', '57449', '57454', '57461', '57506', '57525', '57547', '57567', '57607', '57655', '57656', '57687', '57712', '57719', '57743', '57745', '57746', '57747', '57762', '57766', '57775', '57788', '57789', '57790', '57824', '57832', '57839', '57856', '57865', '57866', '57870', '57878', '57879', '57880', '57892', '57919', '57964', '57992', '58063', '58069', '58084', '58088', '58090', '58149', '58154', '58170', '58211', '58244', '58251', '58254', '58259', '58260', '58265', '58290', '58299', '58309', '58317', '58362', '58366', '58427', '58467', '58476', '58501', '58532', '58661', '58698', '58706', '58801', '58818', '58828', '59073', '59140', '59151', '59327', '59365', '59431', '59442', '59525', '59719', '59781', '59859', '59889', '59905', '59926', '59992', '59996', '60030', '60033', '60052', '60079', '60080', '60106', '60199', '60204', '60223', '60323', '60347', '60354', '60355', '60366', '60383', '60394', '60417', '60450', '60551', '60571', '60589', '60594', '60610', '60616', '60618', '60635', '60656', '60665', '60666', '60680', '60711', '60720', '60738', '60739', '60746', '60752', '60798', '60806', '60810', '60812', '60821', '60827', '60828', '60841', '60847', '60850', '60863', '60882', '60891', '60892', '60896', '60916', '60918', '60924', '60986', '60998', '61008', '61035', '61039', '61059', '61075', '61088', '61091', '61108', '61134', '61160', '61162', '61181', '61218', '61227', '61239', '61248', '61258', '61270', '61274', '61311', '61318', '61351', '61357', '61373', '61383', '61390', '61404', '61408', '61411', '61418', '61458', '61460', '61472', '61485', '61489', '61500', '61524', '61527', '61528', '61538', '61565', '61571', '61602', '61627', '61636', '61639', '61642', '61664', '61666', '61675', '61691', '61697', '61718', '61733', '61738', '61745', '61747', '61768', '61772', '61789', '61793', '61816', '61856', '61868', '61876', '61905', '61920', '61961', '61965', '61970', '61980', '61994', '62007', '62013', '62021', '62024', '62051', '62092', '62100', '62106', '62108', '62114', '62120', '62147', '62179', '62196', '62207', '62215', '62216', '62243', '62248', '62272', '62281', '62289', '62290', '62318', '62337', '62340', '62368', '62393', '62399', '62404', '62405', '62406', '62508', '62581', '62615', '62679', '62753', '62787', '62812', '62874', '62922', '62933', '62960', '63063', '63129', '63137', '63141', '63144', '63147', '63151', '63161', '63167', '63174', '63179', '63185', '63187', '63193', '63201', '63206', '63220', '63222', '63244', '63248', '63253', '63259', '63260', '63267', '63271', '63284', '63285', '63296', '63303', '63305', '63319', '63329', '63341', '63349', '63380', '63409', '63435', '63470', '63478', '63496', '63543', '63575', '63576', '63634', '63646', '63647', '63661', '63674', '63679', '63712', '63714', '63726', '63734', '63762', '63768', '63773', '63777', '63794', '63804', '63846', '63881', '63887', '63904', '63905', '63906', '63914', '63915', '63920', '63937', '63951', 
'63979', '64008', '64023', '64034', '64065', '64066', '64067', '64084', '64086', '64091', '64113', '64116', '64119', '64129', '64130', '64144', '64145', '64153', '64165', '64168', '64202', '64213', '64220', '64222', '64232', '64274', '64289', '64290', '64305', '64322', '64369', '64371', '64423', '64427', '64428', '64453', '64582', '64590', '64598', '64614', '64645', '64663', '64674', '64706', '64719', '64721', '64729', '64737', '64760', '64776', '64787', '64792', '64798', '64805', '64811', '64821', '64830', '64838', '64850', '64853', '64862', '64867', '64886', '64902', '64909', '64912', '64913', '64920', '64935', '64939', '64949', '64952', '64966', '64976', '64983', '64990', '65025', '65031', '65052', '65059', '65082', '65092', '65096', '65105', '65109', '65111', '65118', '65150', '65178', '65183', '65187', '65213', '65219', '65224', '65251', '65263', '65266', '65273', '65275', '65301', '65339', '65349', '65366', '65402', '65403', '65405', '65419', '65420', '65421', '65434', '65448', '65449', '65457', '65491', '65494', '65522', '65530', '65546', '65578', '65595', '65684', '65751', '65845', '65878', '65969', '65989', '65990', '66035', '66117', '66261', '66430', '66480', '66483', '66505', '66557', '66645', '66652', '66664', '66665', '66672', '66674', '66681', '66682', '66695', '66704', '66709', '66715', '66725', '66761', '66771', '66776', '66781', '66787', '66793', '66795', '66805', '66810', '66817', '66827', '66834', '66839', '66845', '66852', '66853', '66878', '66883', '66894', '66897', '66912', '66924', '66927', '66929', '66965', '66978', '66980', '66985', '66995', '67010', '67024', '67052', '67074', '67080', '67086', '67094', '67096', '67099', '67102', '67119', '67120', '67129', '67136', '67158', '67163', '67205', '67233', '67243', '67245', '67252', '67253', '67260', '67264', '67273', '67275', '67281', '67283', '67320', '67327', '67382', '67387', '67400', '67402', '67427', '67469', '67620', '67719', '67741', '67779', '67848', '67853', '67860', '67863', '67867', '67884', '67889', '67910', '67924', '67927', '67930', '67932', '67933', '67940', '67963', '67999', '68010', '68036', '68086', '68138', '68199', '68245', '68264', '68265', '68271', '68281', '68286', '68297', '68298', '68303', '68307', '68312', '68314', '68327', '68355', '68356', '68476', '68485', '68487', '68501', '68510', '68516', '68528', '68560', '68581', '68588', '68591', '68595', '68602', '68603', '68606', '68618', '68621', '68622', '68632', '68634', '68636', '68637', '68675', '68701', '68710', '68719', '68728', '68764', '68818', '68821', '68839', '68848', '68856', '68865', '68871', '68872', '68885', '69007', '69031', '69059', '69089', '69098', '69165', '69182', '69189', '69194', '69301', '69437', '69536', '69540', '69541', '69561', '69568', '69583', '69597', '69605', '69618', '69621', '69626', '69631', '69743', '69770', '69794', '69802', '69848', '69860', '69863', '69872', '69887', '69915', '69917', '69918', '69936', '69982', '70042', '70061', '70065', '70071', '70081', '70084', '70088', '70094', '70095', '70102', '70117', '70126', '70144', '70161', '70162', '70186', '70201', '70202', '70241', '70260', '70268', '70299', '70312', '70319', '70327', '70340', '70351', '70354', '70363', '70369', '70396', '70422', '70423', '70504', '70505', '70534', '70538', '70612', '70650', '70661', '70676', '70746', '70747', '70752', '70753', '70754', '70762', '70778', '70789', '70801', '70818', '70821', '70825', '70846', '70878', '70891', '70900', '70904', '70917', '70954', '70956', '70989', '70995', '71022', '71023', '71033', '71053', '71054', 
'71059', '71062', '71093', '71110', '71123', '71143', '71172', '71199', '71204', '71206', '71208', '71225', '71227', '71236', '71251', '71272', '71287', '71293', '71294', '71296', '71298', '71299', '71300', '71331', '71341', '71420', '71429', '71445', '71455', '71457', '71473', '71487', '71499', '71504', '71510', '71553', '71555', '71588', '71635', '71658', '71686', '71715', '71731', '71737', '71738', '71764', '71773', '71778', '71820', '71831', '71832', '71878', '71901', '71911', '71926', '71932', '72023', '72024', '72034', '72035', '72036', '72055', '72071', '72091', '72127', '72145', '72151', '72160', '72169', '72189', '72194', '72205', '72244', '72251', '72252', '72256', '72260', '72366', '72370', '72385', '72393', '72414', '72486', '72534', '72605', '72618', '72619', '72620', '72630', '72637', '72641', '72643', '72671', '72702', '72712', '72726', '72759', '72773', '72784', '72798', '72799', '72802', '72804', '72805', '72808', '72819', '72836', '72844', '72867', '72872', '72885', '72889', '72905', '72912', '72931', '72939', '72954', '72973', '72975', '72978', '72982', '72984', '72990', '72992', '73002', '73007', '73010', '73018', '73022', '73024', '73033', '73039', '73058', '73077', '73097', '73103', '73118', '73164', '73168', '73191', '73199', '73206', '73210', '73232', '73235', '73252', '73279', '73282', '73291', '73294', '73329', '73333', '73348', '73349', '73370', '73384', '73406', '73436', '73443', '73454', '73478', '73481', '73526', '73552', '73562', '73571', '73632', '73633', '73653', '73689', '73690', '73732', '73768', '73831', '73832', '73847', '73864', '73865', '73874', '73884', '73889', '73894', '73904', '73912', '73918', '73928', '73946', '73948', '73967', '73977', '73985', '73998', '73999', '74003', '74034', '74035', '74040', '74052', '74061', '74080', '74081', '74090', '74091', '74092', '74093', '74094', '74106', '74126', '74128', '74131', '74141', '74143', '74157', '74164', '74208', '74209', '74216', '74225', '74228', '74285', '74286', '74290', '74327', '74329', '74351', '74364', '74384', '74398', '74404', '74411', '74412', '74420', '74426', '74429', '74434', '74446', '74448', '74453', '74455', '74459', '74465', '74481', '74487', '74488', '74495', '74499', '74506', '74509', '74513', '74562', '74569', '74580', '74599', '74628']

def get_full_dataset(file_path: str = "full_dataset.txt") -> str:
    # Return the cached corpus if it has already been built
    if os.path.exists(file_path):
        with open(file_path, 'r', encoding='utf-8') as file:
            return file.read()

    # Accumulate all the fetched texts
    combined_text = ""

    print("total books", len(french_text_ids))
    i = 0
    for text_id in french_text_ids:
        i += 1
        if i % 15 == 0:
            print(i)
        text = fetch_gutenberg_text(text_id)
        if text:
            combined_text += text + "\n\n"  # Separate texts with double newlines

    # Clean the text before returning
    full_text = clean_text(combined_text)

    # Save the combined text to a file
    with open(file_path, 'w', encoding='utf-8') as file:
        file.write(full_text)

    return full_text
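
# Minimal sketch (illustrative, not part of the original upload): clean_text
# keeps letters, whitespace, and the listed punctuation, dropping digits and
# typographic quotes.
if __name__ == "__main__":
    sample = "Voilà 123 «mots» et symboles !"
    print(clean_text(sample))  # -> "Voilà  mots et symboles !"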
infer.py
ADDED
@@ -0,0 +1,193 @@
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from preprocess_dataset import preprocess_text
from torch import Tensor
from torch.nn import Transformer
import math
import bitsandbytes as bnb
from invariants import get_data_pairs
from french_dataset import get_full_dataset
import numpy as np

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MAX_SEQUENCE_LENGTH = 512
training_pairs = get_data_pairs(get_full_dataset())
PAD_IDX_SRC, PAD_IDX_TGT, BOS_IDX_SRC, BOS_IDX_TGT = 698, 256, 699, 257

class CipherDataset(Dataset):
    def __init__(self, pairs):
        self.encodings = []
        self.symbol_indices = []
        bulk_dataset = pairs[:3000]
        for entry in bulk_dataset:
            # Prepend BOS, then pad or truncate to MAX_SEQUENCE_LENGTH
            if len(entry[0]) < MAX_SEQUENCE_LENGTH:
                self.encodings.append([BOS_IDX_SRC] + entry[0] + [PAD_IDX_SRC] * (MAX_SEQUENCE_LENGTH - len(entry[0]) - 1))
            elif len(entry[0]) > MAX_SEQUENCE_LENGTH:
                self.encodings.append([BOS_IDX_SRC] + entry[0][:MAX_SEQUENCE_LENGTH - 1])
            else:
                self.encodings.append([BOS_IDX_SRC] + entry[0][:-1])

            if len(entry[1]) < MAX_SEQUENCE_LENGTH:
                self.symbol_indices.append([BOS_IDX_TGT] + entry[1] + [PAD_IDX_TGT] * (MAX_SEQUENCE_LENGTH - len(entry[1]) - 1))
            elif len(entry[1]) > MAX_SEQUENCE_LENGTH:
                self.symbol_indices.append([BOS_IDX_TGT] + entry[1][:MAX_SEQUENCE_LENGTH - 1])
            else:
                self.symbol_indices.append([BOS_IDX_TGT] + entry[1][:-1])

    def __len__(self):
        return len(self.encodings)

    def __getitem__(self, idx):
        return torch.tensor(self.encodings[idx]), torch.tensor(self.symbol_indices[idx])

class PositionalEncoding(nn.Module):
    def __init__(self, emb_size: int, dropout: float, maxlen: int = MAX_SEQUENCE_LENGTH):
        super(PositionalEncoding, self).__init__()
        den = torch.exp(-torch.arange(0, emb_size, 2) * math.log(10000) / emb_size)
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        pos_embedding = torch.zeros((maxlen, emb_size))
        pos_embedding[:, 0::2] = torch.sin(pos * den)
        pos_embedding[:, 1::2] = torch.cos(pos * den)
        pos_embedding = pos_embedding.unsqueeze(0).transpose(0, 1)

        self.dropout = nn.Dropout(dropout)
        self.register_buffer('pos_embedding', pos_embedding)

    def forward(self, token_embedding: Tensor):
        return self.dropout(token_embedding + self.pos_embedding[:token_embedding.size(0), :])

class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size: int, emb_size):
        super(TokenEmbedding, self).__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.pos_encoder = PositionalEncoding(emb_size, 0.2)
        self.emb_size = emb_size

    def forward(self, tokens: Tensor):
        return self.pos_encoder(self.embedding(tokens.long()) * math.sqrt(self.emb_size))

class Seq2SeqTransformer(nn.Module):
    def __init__(self,
                 num_encoder_layers: int,
                 num_decoder_layers: int,
                 emb_size: int,
                 nhead: int,
                 src_vocab_size: int,
                 tgt_vocab_size: int,
                 dim_feedforward: int = 512,
                 dropout: float = 0.1):
        super(Seq2SeqTransformer, self).__init__()
        self.transformer = Transformer(d_model=emb_size,
                                       nhead=nhead,
                                       num_encoder_layers=num_encoder_layers,
                                       num_decoder_layers=num_decoder_layers,
                                       dim_feedforward=dim_feedforward,
                                       dropout=dropout)
        self.generator = nn.Linear(emb_size, tgt_vocab_size)
        self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size)
        self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)

    def forward(self,
                src: Tensor,
                trg: Tensor,
                src_mask: Tensor,
                tgt_mask: Tensor,
                src_padding_mask: Tensor,
                tgt_padding_mask: Tensor,
                memory_key_padding_mask: Tensor):
        src_emb = self.src_tok_emb(src)
        tgt_emb = self.tgt_tok_emb(trg)
        outs = self.transformer(src_emb, tgt_emb, src_mask, tgt_mask, None,
                                src_padding_mask, tgt_padding_mask, memory_key_padding_mask)
        return self.generator(outs)

    def encode(self, src: Tensor, src_mask: Tensor):
        return self.transformer.encoder(self.src_tok_emb(src), src_mask)

    def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor):
        return self.transformer.decoder(self.tgt_tok_emb(tgt), memory, tgt_mask)

def generate_square_subsequent_mask(sz):
    mask = (torch.triu(torch.ones((sz, sz), device=DEVICE)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask

def create_mask(src, tgt):
    src_seq_len = src.shape[0]
    tgt_seq_len = tgt.shape[0]

    tgt_mask = generate_square_subsequent_mask(tgt_seq_len)
    src_mask = torch.zeros((src_seq_len, src_seq_len), device=DEVICE).type(torch.bool)

    src_padding_mask = (src == PAD_IDX_SRC).transpose(0, 1)
    tgt_padding_mask = (tgt == PAD_IDX_TGT).transpose(0, 1)
    return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask
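
# Sanity-check sketch (illustrative, not from the original upload): builds a
# dummy (seq_len, batch) pair and prints the shapes create_mask returns.
# Nothing here runs at import time; call _demo_create_mask() manually.
def _demo_create_mask(seq_len: int = 8, batch: int = 2):
    src = torch.full((seq_len, batch), PAD_IDX_SRC)
    tgt = torch.full((seq_len, batch), PAD_IDX_TGT)
    src_mask, tgt_mask, src_pad, tgt_pad = create_mask(src, tgt)
    # Attention masks are (seq_len, seq_len); padding masks are (batch, seq_len)
    print(src_mask.shape, tgt_mask.shape, src_pad.shape, tgt_pad.shape)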

torch.manual_seed(22330)

SRC_VOCAB_SIZE = 700
TGT_VOCAB_SIZE = 258
EMB_SIZE = 512
NHEAD = 16
FFN_HID_DIM = 768
BATCH_SIZE = 32
NUM_ENCODER_LAYERS = 4
NUM_DECODER_LAYERS = 4
GRAD_CLIP = 1.0
DROPOUT = 0.3  # Increased dropout
TOTAL_STEPS = 400

model = Seq2SeqTransformer(NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS, EMB_SIZE,
                           NHEAD, SRC_VOCAB_SIZE, TGT_VOCAB_SIZE, FFN_HID_DIM,
                           dropout=DROPOUT)
model.load_state_dict(torch.load("cipher.pth", map_location=DEVICE))
model.to(DEVICE)
model.eval()

# Generate the output sequence with greedy decoding
def greedy_decode(model, src, src_mask, max_len, start_symbol):
    src = src.to(DEVICE)
    src_mask = src_mask.to(DEVICE)

    memory = model.encode(src, src_mask)
    ys = torch.ones(1, 1).fill_(start_symbol).type(torch.long).to(DEVICE)
    for i in range(max_len - 1):
        memory = memory.to(DEVICE)
        tgt_mask = (generate_square_subsequent_mask(ys.size(0))
                    .type(torch.bool)).to(DEVICE)
        out = model.decode(ys, memory, tgt_mask)
        out = out.transpose(0, 1)
        prob = model.generator(out[:, -1])
        _, next_word = torch.max(prob, dim=1)
        next_word = next_word.item()

        ys = torch.cat([ys,
                        torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=0)
    return ys

from cipher_8bit import load_or_save_symbols, substitution_cipher, encode_text_with_indices
from preprocess_dataset import get_frequency_ranks, get_proximity_array

# Translate an encoded input sentence back into its symbol sequence
def translate(model: torch.nn.Module, src_sentence: str):
    model.eval()
    symbols = load_or_save_symbols([])
    rule = substitution_cipher(symbols, 1337)
    encodings, _ = encode_text_with_indices(rule, symbols, src_sentence)
    encodings = torch.tensor(encodings[:20]).view(-1, 1)
    print(encodings.shape)
    num_tokens = encodings.shape[0]
    src_mask = (torch.zeros(num_tokens, num_tokens)).type(torch.bool)
    tgt_tokens = greedy_decode(
        model, encodings, src_mask, max_len=MAX_SEQUENCE_LENGTH, start_symbol=256).flatten()
    tokens = tgt_tokens.cpu().numpy()[1:]
    # Clamp special tokens (anything above the 8-bit range) to symbol 0
    for i, x in enumerate(tokens):
        if x > 255:
            tokens[i] = 0
    print(tokens)
    return ''.join([symbols[x] for x in tokens])

print(translate(model, """Bien sûr ! Voici un texte plus long en français: La beauté de la nature réside dans sa diversité et sa capacité à émerveiller à chaque saison. Chaque paysage, qu’il s’agisse de montagnes imposantes, de forêts mystérieuses ou de rivières sinueuses, possède une âme et une histoire à raconter. Lorsqu’on s’aventure au cœur d’une forêt, l’air frais, imprégné des parfums de bois et de terre humide, nous invite à ralentir et à savourer l’instant. Les feuilles qui bruissent sous nos pas, les oiseaux qui chantent et la lumière qui filtre à travers les arbres créent une atmosphère presque magique, propice à la contemplation. Au printemps, la nature se réveille lentement, offrant un spectacle de couleurs éclatantes : des fleurs qui éclorent, des bourgeons qui germent, des champs qui se parent de verts éclatants. L'été, quant à lui, emporte tout dans un tourbillon de chaleur, de lumière et de vie. Les journées longues et ensoleillées sont idéales pour profiter des bienfaits du plein air, que ce soit à la mer, à la montagne ou simplement dans le jardin. L’automne, avec ses nuances orangées et dorées, est une invitation à la réflexion et à la tranquillité. Les feuilles tombent en tourbillonnant, créant des tapis colorés qui habillent la terre. Puis vient l’hiver, avec son froid piquant et la neige qui transforme le monde en un paysage féerique, silencieux et apaisant. Au-delà de la beauté visuelle, la nature nous enseigne aussi l’humilité et la résilience. Elle nous rappelle que tout est en perpétuel mouvement et que chaque cycle a sa raison d’être. Nous, humains, sommes un petit maillon dans cet écosystème complexe, et il est de notre responsabilité de préserver ce précieux équilibre. En prenant soin de notre environnement, nous garantissons non seulement la survie des espèces, mais aussi notre propre bien-être. La nature, avec sa sagesse silencieuse, continue de nous offrir des leçons de vie, chaque jour, sous nos yeux émerveillés."""))
invariants.py
ADDED
@@ -0,0 +1,227 @@
from cipher_8bit import *
from french_dataset import get_full_dataset
import os
import json
import pickle

def get_pattern_ranks(pattern_frequency_dict):
    # Sort patterns by frequency, most common first
    sorted_items = sorted(pattern_frequency_dict.items(), key=lambda x: x[1], reverse=True)

    # Assign ranks starting from 1
    ranked_dict = {}
    rank = 1
    for key, value in sorted_items:
        ranked_dict[key] = rank
        rank += 1
    return ranked_dict

def unique_pattern_identifiers(symbol_index_sequences, save=False, name="data"):
    freq_dict = {}
    if os.path.exists(name + ".json") and save:
        # Load the existing JSON frequency data and rank it
        with open(name + ".json", "r") as json_file:
            loaded_dict = json.load(json_file)
        return get_pattern_ranks(loaded_dict)

    for sequence in symbol_index_sequences:
        raw_pattern_data = find_patterns_and_indices(sequence, remove_subsets=False)
        for pattern in raw_pattern_data:
            key = "-".join(map(str, pattern[0]))
            if key in freq_dict:
                freq_dict[key] += len(pattern[1])
            else:
                freq_dict[key] = len(pattern[1])
    if save:
        with open(name + ".json", "w") as json_file:
            json.dump(freq_dict, json_file, indent=4)

    return get_pattern_ranks(freq_dict)

# Default symbol inventory: single characters plus frequent French bigrams
symbols = ['b', 'j', '\r', 'J', '”', ')', 'Â', 'É', 'ê', '5', 't', '9', 'Y', '%', 'N', 'B', 'V', '\ufeff', 'Ê', '?', '’', 'i', ':', 's', 'C', 'â', 'ï', 'W', 'y', 'p', 'D', '—', '«', 'º', 'A', '3', 'n', '0', 'q', '4', 'e', 'T', 'È', '$', 'U', 'v', '»', 'l', 'P', 'X', 'Z', 'À', 'ç', 'u', '…', 'î', 'L', 'k', 'E', 'R', '2', '_', '8', 'é', 'O', 'Î', '‘', 'a', 'F', 'H', 'c', '[', '(', "'", 'è', 'I', '/', '!', ' ', '°', 'S', '•', '#', 'x', 'à', 'g', '*', 'Q', 'w', '1', 'û', '7', 'G', 'm', '™', 'K', 'z', '\n', 'o', 'ù', ',', 'r', ']', '.', 'M', 'Ç', '“', 'h', '-', 'f', 'ë', '6', ';', 'd', 'ô', 'e ', 's ', 't ', 'es', ' d', '\r\n', 'en', 'qu', ' l', 're', ' p', 'de', 'le', 'nt', 'on', ' c', ', ', ' e', 'ou', ' q', ' s', 'n ', 'ue', 'an', 'te', ' a', 'ai', 'se', 'it', 'me', 'is', 'oi', 'r ', 'er', ' m', 'ce', 'ne', 'et', 'in', 'ns', ' n', 'ur', 'i ', 'a ', 'eu', 'co', 'tr', 'la', 'ar', 'ie', 'ui', 'us', 'ut', 'il', ' t', 'pa', 'au', 'el', 'ti', 'st', 'un', 'em', 'ra', 'e,', 'so', 'or', 'l ', ' f', 'll', 'nd', ' j', 'si', 'ir', 'e\r', 'ss', 'u ', 'po', 'ro', 'ri', 'pr', 's,', 'ma', ' v', ' i', 'di', ' r', 'vo', 'pe', 'to', 'ch', '. ', 've', 'nc', 'om', ' o', 'je', 'no', 'rt', 'à ', 'lu', "'e", 'mo', 'ta', 'as', 'at', 'io', 's\r', 'sa', "u'", 'av', 'os', ' à', ' u', "l'", "'a", 'rs', 'pl', 'é ', '; ', 'ho', 'té', 'ét', 'fa', 'da', 'li', 'su', 't\r', 'ée', 'ré', 'dé', 'ec', 'nn', 'mm', "'i", 'ca', 'uv', '\n\r', 'id', ' b', 'ni', 'bl']
|
| 58 |
+
symbols = load_or_save_symbols(symbols)
|
| 59 |
+
substitution_rule = substitution_cipher(symbols, 1337)
|
| 60 |
+
|
| 61 |
+
def invariate_sequence(sample, ids, vocab_size):
|
| 62 |
+
fill_in = []
|
| 63 |
+
p = find_patterns_and_indices(sample)
|
| 64 |
+
u = find_unique_singles(sample)
|
| 65 |
+
for pattern in p:
|
| 66 |
+
value = "-".join(map(str, pattern[0]))
|
| 67 |
+
for index in pattern[1]:
|
| 68 |
+
fill_in.append([index, ids[value], len(pattern[0])])
|
| 69 |
+
for unique in u:
|
| 70 |
+
fill_in.append([unique[1][0], vocab_size + 1, 0])
|
| 71 |
+
fill_in.sort(key=lambda x: x[0])
|
| 72 |
+
total_list = [0] * 1024
|
| 73 |
+
i=0
|
| 74 |
+
tally = 0
|
| 75 |
+
pattern_count = 0
|
| 76 |
+
while pattern_count < len(fill_in):
|
| 77 |
+
if tally != fill_in[pattern_count][0]:
|
| 78 |
+
total_list[i] = 1
|
| 79 |
+
i+=1
|
| 80 |
+
tally+=1
|
| 81 |
+
continue
|
| 82 |
+
if fill_in[pattern_count][2] == 0:
|
| 83 |
+
total_list[i] = 0
|
| 84 |
+
i+=1
|
| 85 |
+
tally+=1
|
| 86 |
+
pattern_count +=1
|
| 87 |
+
continue
|
| 88 |
+
|
| 89 |
+
if fill_in[pattern_count][2] == 1:
|
| 90 |
+
total_list[i+1] = fill_in[pattern_count][1] + 5
|
| 91 |
+
i+=1
|
| 92 |
+
else:
|
| 93 |
+
total_list[i] = fill_in[pattern_count][2]
|
| 94 |
+
total_list[i+1] = fill_in[pattern_count][1] + 5
|
| 95 |
+
i+=2
|
| 96 |
+
|
| 97 |
+
tally += fill_in[pattern_count][2]
|
| 98 |
+
pattern_count += 1
|
| 99 |
+
|
| 100 |
+
total_list = total_list[:i]
|
| 101 |
+
return total_list
|
| 102 |
+
|
| 103 |
+
dataset = []
|
| 104 |
+
print(len(text_chunks))
|
| 105 |
+
i5 = 0
|
| 106 |
+
for text_i in range(len(text_chunks)):
|
| 107 |
+
i5 += 1
|
| 108 |
+
if i5 % 100 == 0:
|
| 109 |
+
print(i5)
|
| 110 |
+
sample_encodings, sample_indices = encode_text_with_indices(substitution_rule, symbols, text_chunks[text_i])
|
| 111 |
+
sample_encodings = sample_encodings[:512]
|
| 112 |
+
sample_indices = sample_indices[:512]
|
| 113 |
+
if sample_indices.count(256) > 0:
|
| 114 |
+
continue
|
| 115 |
+
|
| 116 |
+
encodings_identifiers = unique_pattern_identifiers([sample_encodings], False)
|
| 117 |
+
encodings_vocab = len(encodings_identifiers.items())
|
| 118 |
+
|
| 119 |
+
encoding_list = invariate_sequence(sample_encodings, encodings_identifiers, encodings_vocab)
|
| 120 |
+
dataset.append([encoding_list, sample_indices])
|
| 121 |
+
with open("data_pairs.pkl", 'wb') as f:
|
| 122 |
+
pickle.dump(dataset, f)
|
| 123 |
+
return dataset
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
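
# Each training pair couples the invariant view of a chunk with its targets:
#   dataset[k] == [invariant_codes, symbol_indices]   # symbol_indices holds 512 entries
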
def filter_subset_pairs(pairs):
    # Drop any [subsequence, indices] pair strictly contained in another pair

    def is_subset_pair(pair1, pair2):
        first1, second1 = set(pair1[0]), set(pair1[1])
        first2, second2 = set(pair2[0]), set(pair2[1])
        return (first1.issubset(first2) and second1.issubset(second2) and
                (len(first1) < len(first2) or len(second1) < len(second2)))

    result = pairs.copy()
    i = len(result) - 1

    # Walk backwards so removals do not shift indices still to be checked
    while i >= 0:
        should_remove = False
        for j, pair2 in enumerate(result):
            if i != j and is_subset_pair(result[i], pair2):
                should_remove = True
                break
        if should_remove:
            result.pop(i)
        i -= 1

    return result

def find_patterns_and_indices(sequence, remove_subsets=True):
    """
    Find all repeating subsequences in a sequence and their indices.
    When remove_subsets is True, excludes occurrences of subsequences that are
    part of a larger repeating subsequence.

    Args:
        sequence (list): Input sequence of numbers
        remove_subsets (bool): Drop occurrences overlapping longer patterns

    Returns:
        list: List of [subsequence, indices] pairs for repeating subsequences
    """
    n = len(sequence)
    result = []

    # Find all subsequences of length 1 to 4 and their start indices
    subsequence_indices = {}
    for length in range(1, 5):
        for i in range(n - length + 1):
            subseq = tuple(sequence[i:i + length])  # tuples are hashable dict keys
            if subseq not in subsequence_indices:
                subsequence_indices[subseq] = []
            subsequence_indices[subseq].append(i)

    # Keep only subsequences that occur more than once
    repeating_subsequences = {
        subseq: indices
        for subseq, indices in subsequence_indices.items()
        if len(indices) > 1
    }

    # Sort subsequences by length (longest first)
    sorted_subsequences = sorted(
        repeating_subsequences.items(),
        key=lambda x: len(x[0]),
        reverse=True
    )

    # Keep track of positions already claimed by longer subsequences
    used_indices = set()

    # Process subsequences from longest to shortest
    for subseq, indices in sorted_subsequences:
        valid_indices = []
        if remove_subsets:
            # Filter out occurrences that overlap positions already claimed
            for idx in indices:
                overlap = any(pos in used_indices
                              for pos in range(idx, idx + len(subseq)))
                if not overlap:
                    valid_indices.append(idx)
                    # Mark all positions in this occurrence as used
                    for pos in range(idx, idx + len(subseq)):
                        used_indices.add(pos)
        else:
            valid_indices = indices

        # Only add the subsequence if it still has multiple valid occurrences
        if len(valid_indices) > 1:
            result.append([list(subseq), valid_indices])

    return result
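
# For example, find_patterns_and_indices([1, 2, 3, 1, 2, 4]) returns
#   [[[1, 2], [0, 3]]]
# The repeating pair (1, 2) claims positions 0-1 and 3-4, so the shorter
# repeats of 1 and 2 on their own are suppressed (remove_subsets=True).
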
def find_unique_singles(sequence):
    # Return a [[element], [index]] pair for every element occurring exactly once
    arr = []
    for i, element in enumerate(sequence):
        if sequence.count(element) == 1:
            arr.append([[element], [i]])

    return arr
preprocess_dataset.py
ADDED
@@ -0,0 +1,47 @@
from cipher_8bit import *
from french_dataset import get_full_dataset

def get_frequency_ranks(encodings, symbols, sequence_len):
    # Count how often each encoding appears in the (truncated) sample...
    symbol_counts = [0] * len(symbols)
    encodings = encodings[:sequence_len]
    for encoding in encodings:
        symbol_counts[encoding] += 1
    # ...then tag every position with the count of its own encoding
    freq_ranks = [0] * sequence_len
    for i in range(len(encodings)):
        freq_ranks[i] = symbol_counts[encodings[i]]
    return freq_ranks

def get_proximity_array(encodings, sequence_len):
    distances = [0] * sequence_len
    encodings = encodings[:sequence_len]
    for i, encoding in enumerate(encodings):
        try:
            # Distance back to the first earlier occurrence of this encoding
            first_idx = encodings.index(encoding, 0, i)
            distances[i] = i - first_idx
        except ValueError:
            # The encoding has not appeared before this position
            distances[i] = 0
    return distances
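
# Example: get_proximity_array([5, 3, 5, 5, 3], 5) -> [0, 0, 2, 3, 3]
# Each entry is the distance back to the *first* earlier occurrence of the
# same encoding (list.index scans from position 0), or 0 if it is new.
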
def preprocess_text(sequence_len=256):
    full_text = get_full_dataset()
    symbols = get_symbols(full_text, 256)
    symbols = load_or_save_symbols(symbols)
    substitution_rule = substitution_cipher(symbols, 1337)

    i = 0
    raw_length = sequence_len * 2  # Overshoot so encoding yields at least sequence_len tokens
    processed_data = []
    while i * raw_length < len(full_text) - raw_length:
        i += 1
        sample_text = full_text[(i - 1) * raw_length: i * raw_length - 1]
        encodings_array, indices = encode_text_with_indices(substitution_rule, symbols, sample_text)
        encodings_array = encodings_array[:sequence_len]
        indices = indices[:sequence_len]

        ranks = get_frequency_ranks(encodings_array, symbols, sequence_len)
        distances = get_proximity_array(encodings_array, sequence_len)
        # (ranks is currently not included in the returned triples)
        processed_data.append([encodings_array, distances, indices])
    return processed_data
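
# Minimal usage sketch (assumes the French corpus behind get_full_dataset is available):
#   data = preprocess_text(sequence_len=256)
#   encodings_array, distances, indices = data[0]
#   len(distances) == 256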
symbols.pkl
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b5f430187556b38b219487b3e44a9d89cf84fc324807897b2579bda60a12badf
size 1230
test2.py
ADDED
@@ -0,0 +1,267 @@
import math
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
from preprocess_dataset import preprocess_text
import bitsandbytes as bnb
from invariants import get_data_pairs
from french_dataset import get_full_dataset
from torch.utils.data import Dataset, DataLoader
from cipher_8bit import load_or_save_symbols, substitution_cipher, encode_text_with_indices

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

MAX_SEQUENCE_LENGTH = 512
# Source side: invariant codes (PAD 698, BOS 699); target side: symbol indices (PAD 257, BOS 258)
PAD_IDX_SRC, PAD_IDX_TGT, BOS_IDX_SRC, BOS_IDX_TGT = 698, 257, 699, 258

training_pairs = get_data_pairs(get_full_dataset())

def create_mask(src):
    # Only a padding mask is needed for the encoder; src arrives as (B, Sx)
    src_padding_mask = (src == PAD_IDX_SRC).transpose(0, 1)  # (Sx, B)
    return src_padding_mask

class CipherDataset(Dataset):
    def __init__(self, pairs):
        self.encodings = []
        self.symbol_indices = []
        # Prepend a BOS token, then pad or truncate both sides to MAX_SEQUENCE_LENGTH
        for entry in pairs:
            if len(entry[0]) < MAX_SEQUENCE_LENGTH:
                self.encodings.append([BOS_IDX_SRC] + entry[0] + [PAD_IDX_SRC] * (MAX_SEQUENCE_LENGTH - len(entry[0]) - 1))
            elif len(entry[0]) > MAX_SEQUENCE_LENGTH:
                self.encodings.append([BOS_IDX_SRC] + entry[0][:MAX_SEQUENCE_LENGTH - 1])
            else:
                self.encodings.append([BOS_IDX_SRC] + entry[0][:-1])

            if len(entry[1]) < MAX_SEQUENCE_LENGTH:
                self.symbol_indices.append([BOS_IDX_TGT] + entry[1] + [PAD_IDX_TGT] * (MAX_SEQUENCE_LENGTH - len(entry[1]) - 1))
            elif len(entry[1]) > MAX_SEQUENCE_LENGTH:
                self.symbol_indices.append([BOS_IDX_TGT] + entry[1][:MAX_SEQUENCE_LENGTH - 1])
            else:
                self.symbol_indices.append([BOS_IDX_TGT] + entry[1][:-1])

    def __len__(self):
        return len(self.encodings)

    def __getitem__(self, idx):
        return torch.tensor(self.encodings[idx]), torch.tensor(self.symbol_indices[idx])


# Wrap data in the simplest possible way to enable PyTorch data fetching
# https://pytorch.org/docs/stable/data.html

BATCH_SIZE = 64
TRAIN_FRAC = 0.995

dataset = CipherDataset(training_pairs)
N = len(training_pairs)
print(N)

# Split into train and val
num_train = int(N * TRAIN_FRAC)
num_val = N - num_train
data_train, data_val = torch.utils.data.random_split(dataset, (num_train, num_val))

dataloader_train = torch.utils.data.DataLoader(data_train, batch_size=BATCH_SIZE)
dataloader_val = torch.utils.data.DataLoader(data_val, batch_size=BATCH_SIZE)

# Sample batch
x, y = next(iter(dataloader_train))
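
# For a full batch both tensors come out padded/truncated to a fixed shape:
#   x.shape == (BATCH_SIZE, MAX_SEQUENCE_LENGTH)   # invariant codes
#   y.shape == (BATCH_SIZE, MAX_SEQUENCE_LENGTH)   # target symbol indices
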
class PositionalEncoding(nn.Module):
    """
    Classic Attention-is-all-you-need positional encoding.
    From PyTorch docs.
    """

    def __init__(self, d_model, dropout=0.1, max_len=512):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)


def generate_square_subsequent_mask(size: int):
    """Generate a triangular (size, size) mask. From PyTorch docs."""
    mask = (torch.triu(torch.ones(size, size)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask
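
# For instance, generate_square_subsequent_mask(3) yields
#   [[0., -inf, -inf],
#    [0.,   0., -inf],
#    [0.,   0.,   0.]]
# so decoder position t can only attend to positions <= t.
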
class Transformer(nn.Module):
    """
    Classic Transformer that both encodes and decodes.

    Prediction-time inference is done greedily.

    NOTE: the decoder start token is BOS_IDX_TGT; predict() generates up to
    max_output_length tokens and does not stop early at an end token.
    """

    def __init__(self, num_classes: int, max_output_length: int, dim: int = 192):
        super().__init__()

        # Parameters
        self.dim = dim
        self.max_output_length = max_output_length
        nhead = 16
        num_layers = 8
        dim_feedforward = dim

        # Encoder part
        self.x_embedding = nn.Embedding(700, dim)   # source vocab: invariant codes plus PAD/BOS
        self.y_embedding = nn.Embedding(259, dim)   # target vocab: 256 symbols + unknown + PAD/BOS
        self.pos_encoder = PositionalEncoding(d_model=self.dim)
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer=nn.TransformerEncoderLayer(d_model=self.dim, nhead=nhead, dim_feedforward=dim_feedforward),
            num_layers=num_layers
        )

        # Decoder part
        self.y_mask = generate_square_subsequent_mask(self.max_output_length)
        self.transformer_decoder = nn.TransformerDecoder(
            decoder_layer=nn.TransformerDecoderLayer(d_model=self.dim, nhead=nhead, dim_feedforward=dim_feedforward),
            num_layers=num_layers
        )
        self.fc = nn.Linear(self.dim, 259)

        # It is empirically important to initialize weights properly
        self.init_weights()

    def init_weights(self):
        initrange = 0.1
        self.x_embedding.weight.data.uniform_(-initrange, initrange)
        self.y_embedding.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()
        self.fc.weight.data.uniform_(-initrange, initrange)

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """
        Input
            x: (B, Sx) with elements in (0, C) where C is num_classes
            y: (B, Sy) with elements in (0, C) where C is num_classes
        Output
            (B, C, Sy) logits
        """
        x_pad_mask = create_mask(x)

        encoded_x = self.encode(x, x_pad_mask)  # (Sx, B, E)
        output = self.decode(y, encoded_x)  # (Sy, B, C)
        return output.permute(1, 2, 0)  # (B, C, Sy)

    def encode(self, x: torch.Tensor, x_pad_mask: torch.Tensor) -> torch.Tensor:
        """
        Input
            x: (B, Sx) with elements in (0, C) where C is num_classes
        Output
            (Sx, B, E) embedding
        """
        x = x.permute(1, 0)  # (Sx, B)
        x = self.x_embedding(x) * math.sqrt(self.dim)  # (Sx, B, E)
        x = self.pos_encoder(x)  # (Sx, B, E)
        # Third positional argument is src_key_padding_mask, shape (B, Sx)
        x = self.transformer_encoder(x, None, x_pad_mask.transpose(0, 1))  # (Sx, B, E)
        return x

    def decode(self, y: torch.Tensor, encoded_x: torch.Tensor) -> torch.Tensor:
        """
        Input
            encoded_x: (Sx, B, E)
            y: (B, Sy) with elements in (0, C) where C is num_classes
        Output
            (Sy, B, C) logits
        """
        y = y.permute(1, 0)  # (Sy, B)
        y = self.y_embedding(y) * math.sqrt(self.dim)  # (Sy, B, E)
        y = self.pos_encoder(y)  # (Sy, B, E)
        Sy = y.shape[0]
        y_mask = self.y_mask[:Sy, :Sy].type_as(encoded_x)  # (Sy, Sy)
        output = self.transformer_decoder(y, encoded_x, y_mask)  # (Sy, B, E)
        output = self.fc(output)  # (Sy, B, C)
        return output

    def predict(self, x: torch.Tensor) -> torch.Tensor:
        """
        Method to use at inference time. Predict y from x one token at a time. This method is greedy
        decoding. Beam search can be used instead for a potential accuracy boost.

        Input
            x: (B, Sx) with elements in (0, C) where C is num_classes
        Output
            (B, max_output_length) predicted tokens
        """
        x_pad_mask = create_mask(x)

        encoded_x = self.encode(x, x_pad_mask)

        output_tokens = (torch.ones((x.shape[0], self.max_output_length))).type_as(x).long()  # (B, max_length)
        output_tokens[:, 0] = BOS_IDX_TGT  # Set start token
        for Sy in range(1, self.max_output_length):
            y = output_tokens[:, :Sy]  # (B, Sy)
            output = self.decode(y, encoded_x)  # (Sy, B, C)
            output = torch.argmax(output, dim=-1)  # (Sy, B)
            output_tokens[:, Sy] = output[-1:]  # Set the last output token
        return output_tokens


class LitModel(pl.LightningModule):
    """Simple PyTorch-Lightning model to train our Transformer."""

    def __init__(self, model):
        super().__init__()
        self.model = model
        self.loss = torch.nn.CrossEntropyLoss(label_smoothing=0.2, reduction='mean')

    def training_step(self, batch, batch_ind):
        x, y = batch
        # Teacher forcing: the model sees the target up to the last token,
        # while the ground truth is shifted one token to the right.
        logits = self.model(x, y[:, :-1])
        loss = self.loss(logits, y[:, 1:])
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_ind):
        x, y = batch
        logits = self.model(x, y[:, :-1])
        loss = self.loss(logits, y[:, 1:])
        self.log("val_loss", loss, prog_bar=True)

    def configure_optimizers(self):
        # 8-bit AdamW from bitsandbytes keeps optimizer state memory low
        return bnb.optim.AdamW8bit(self.parameters(), lr=0.0005, betas=(0.9, 0.99), eps=1e-8, weight_decay=0.01)

# Train, then sanity-check that greedy decoding produces readable output

x, y = next(iter(dataloader_val))

model = Transformer(num_classes=700, max_output_length=y.shape[1])
lit_model = LitModel(model)
trainer = pl.Trainer(max_epochs=1)
trainer.fit(lit_model, dataloader_train, dataloader_val)
torch.save(model.state_dict(), "cipher.pth")

print('Input:', x[:1])
pred = lit_model.model.predict(x[:1])
print('Truth/Pred:')
tokens = torch.cat((y[:1], pred)).cpu().numpy()
symbols = load_or_save_symbols([])  # Loads the saved symbol table from symbols.pkl
print(tokens)
# Map symbol indices back to text; special tokens (>= 256) print as '#'
print(''.join([symbols[x] if x < 256 else "#" for x in tokens[0][1:]]))
print(''.join([symbols[x] if x < 256 else "#" for x in tokens[1][1:]]))
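
# A minimal reload sketch (assumes only the classes defined above; the repo's
# own inference code lives in infer.py):
# model2 = Transformer(num_classes=700, max_output_length=MAX_SEQUENCE_LENGTH)
# model2.load_state_dict(torch.load("cipher.pth"))
# model2.eval()
# with torch.no_grad():
#     decoded = model2.predict(x[:1])  # (1, MAX_SEQUENCE_LENGTH) symbol indices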