nitt commited on
Commit
8c09ad6
·
1 Parent(s): 43d88f2

Upload LSTM-XOR.ipynb

Browse files
Files changed (1) hide show
  1. LSTM-XOR.ipynb +1 -0
LSTM-XOR.ipynb ADDED
@@ -0,0 +1 @@
 
 
1
+ {"cells":[{"cell_type":"code","metadata":{"source_hash":"cc84e8cb","execution_start":1697650504006,"execution_millis":32,"deepnote_to_be_reexecuted":false,"cell_id":"1545444907bd4f1cbcab601f6cf3ad31","deepnote_cell_type":"code"},"source":"import operator\nimport functools\nimport random\n\ndef foldl(func, acc, xs):\n return functools.reduce(func, xs, acc) \n","block_group":"336274182c9942d2bb3c6992efeda573","execution_count":1,"outputs":[]},{"cell_type":"code","metadata":{"source_hash":"6becc9d4","execution_start":1697650504460,"execution_millis":22,"deepnote_to_be_reexecuted":false,"cell_id":"574f8a2c047a45c9886f6f3eee577ca5","deepnote_cell_type":"code"},"source":"foldl(operator.add, 0, [1,2,3,4,5,6,7,8,9,10])\n","block_group":"6b61ed77d90847858444d6c56105b2ff","execution_count":2,"outputs":[{"output_type":"execute_result","execution_count":2,"data":{"text/plain":"55"},"metadata":{}}]},{"cell_type":"code","metadata":{"source_hash":"7e397e1f","execution_start":1697650505083,"execution_millis":23,"deepnote_to_be_reexecuted":false,"cell_id":"6cd4a90d6d83455faf22e0d68dc9701d","deepnote_cell_type":"code"},"source":"# parity check, function = xor, acc = 0 sequece, 0s and 1s\nrandom.seed(1)\n\nprint(\" Bitstirng | Parity \")\nprint(\"-\"*34)\n\n#generating a random 12 digit binary string\n\nfor _ in range(1):\n seq = [random.randint(0,1) for _ in range(12)]\n print(f\"{''.join(str(b) for b in seq)} | {foldl(operator.xor, 0, seq)}\")\n\n\n","block_group":"f84e4492348348438ddd81eb3817dc84","execution_count":3,"outputs":[{"name":"stdout","text":" Bitstirng | Parity \n----------------------------------\n001011110010 | 0\n","output_type":"stream"}]},{"cell_type":"code","metadata":{"source_hash":"6909b4f6","execution_start":1697650505834,"execution_millis":45,"deepnote_to_be_reexecuted":false,"cell_id":"efbc6075d0684d0a91029989bc21d0ea","deepnote_cell_type":"code"},"source":"random.seed(1)\n\ndef traceXOR(a,b):\n \"\"\"\n shows the intermediate steps of \n xor function on \n a sequence\n \"\"\"\n\n result = operator.xor(a, b)\n print(f\"{a} XOR {b} = {result}\")\n return result\n\nprint(foldl(traceXOR, 0, [1, 0, 0, 1, 1]))","block_group":"926b7678328b4460bf48112cbb328e13","execution_count":4,"outputs":[{"name":"stdout","text":"0 XOR 1 = 1\n1 XOR 0 = 1\n1 XOR 0 = 1\n1 XOR 1 = 0\n0 XOR 1 = 1\n1\n","output_type":"stream"}]},{"cell_type":"code","metadata":{"source_hash":"83b0a053","execution_start":1697650506606,"execution_millis":34,"deepnote_to_be_reexecuted":false,"cell_id":"7d5b7e3924724b509193cf92b7d57d9b","deepnote_cell_type":"code"},"source":"\"\"\" The math behind the above XOR operations is '0' is the initial accumulator value\nso 0 XOR 1 == 1\nnow the accumulator value has been changed to 1 \nso the second operation\n1 XOR 0 = 1\n1 XOR 0 = 1\n1 XOR 1 = 0\n0 XOR 1 = 1\n1----> is the final accumulator value \n\"\"\"\n","block_group":"e00cfcfe2edb45db8ee52a4c041d709d","execution_count":5,"outputs":[{"output_type":"execute_result","execution_count":5,"data":{"text/plain":"\" The math behind the above XOR operations is '0' is the initial accumulator value\\nso 0 XOR 1 == 1\\nnow the accumulator value has been changed to 1 \\nso the second operation\\n1 XOR 0 = 1\\n1 XOR 0 = 1\\n1 XOR 1 = 0\\n0 XOR 1 = 1\\n1----> is the final accumulator value \\n\""},"metadata":{}}]},{"cell_type":"code","metadata":{"source_hash":"d92d8e78","execution_start":1697650507376,"execution_millis":12,"deepnote_to_be_reexecuted":false,"cell_id":"9e3d0b4fe91e4e3c99e8b413bb72ecb3","deepnote_cell_type":"code"},"source":"import torch\nimport torch.nn as nn\nimport torch.utils.data as data\nfrom torch.utils.data import DataLoader\ndevice = 'cuda' if torch.cuda.is_available() else 'cpu'","block_group":"8820493545e34761b14d80ee0feea9aa","execution_count":6,"outputs":[]},{"cell_type":"code","metadata":{"source_hash":"474968de","execution_start":1697650508100,"execution_millis":16,"deepnote_to_be_reexecuted":false,"cell_id":"a437ff3d838e4179a933964269e6a16e","deepnote_cell_type":"code"},"source":"#DATA\nTraining_Size = 100000\nValidation_Size = 10000\nBit_len = 50\nVariable_Len = True\n\n# Model Parameters\nInputSize = 1\nHiddenSize = 2\nNumberLayers = 1\n\n#Training Parameters\nBatchSize = 8\nEpochs = 8\nLearningRate = 0.01 #Default ADAM = 0.001\nThreshold = 0.0001","block_group":"83dc9716560c453fb178b71603ce8458","execution_count":7,"outputs":[]},{"cell_type":"code","metadata":{"source_hash":"4dc228c8","execution_start":1697652322927,"execution_millis":22,"deepnote_to_be_reexecuted":false,"cell_id":"5b59064196fe4b26ae7f000801c67de0","deepnote_cell_type":"code"},"source":"class XOR(data.Dataset):\n def __init__(self,sample_size = Validation_Size, bit_len = Bit_len, variable=False):\n self.bit_len = Bit_len\n self.sample_size = Validation_Size\n self.variable = Variable_Len\n self.features, self.labels = self.generate_data(sample_size,bit_len)\n\n\n def __getitem__(self, index):\n return self.features[index, :], self.labels[index]\n \n def __len__(self):\n return len(self.features)\n\n def generate_data(self,sample_size, seq_length = Bit_len):\n bits = torch.randint(2, size=(sample_size, seq_length, 1)).float()\n if self.variable:\n # we generate random integers and pad the bits with zeros\n # to mimic variable bit string lengths \n # padding with zeros as they do not provide information\n \n # pad = torch.randint(seq_length, size=(sample_size, ))\n\n # for idx, p in enumerate(pad):\n # bits[idx, p:] = 0.\n\n # TODO: vectorize instead of loop?\n\n # Generate random integers for padding positions\n pad = torch.randint(seq_length, size=(sample_size, 1))\n \"\"\"\n # Create a mask for zero-padding\n mask = torch.arange(seq_length).expand(sample_size, seq_length) >= pad\n\n # Apply the mask to set values after the padding positions to 0\n bits = bits * mask.float()\n\n bitsum = bits.cumsum(axis=1)\n \"\"\"\n # if bitsum[i] odd: -> True\n # else: False\n for idx, p in enumerate(pad):\n bits[idx, p:] = 0.\n\n bitsum = bits.cumsum(axis=1)\n parity = (bitsum % 2 != 0).float()\n \n return bits, parity\n\n\n","block_group":"706f2f18b9d24f959aad6cbb8057b51d","execution_count":33,"outputs":[]},{"cell_type":"code","metadata":{"source_hash":"9ba93be5","execution_start":1697652324561,"execution_millis":13,"deepnote_to_be_reexecuted":false,"cell_id":"2a73e999638d477e98927fec66d55fb5","deepnote_cell_type":"code"},"source":"# sample_size = 10\n# seq_length = 2\n# bits = torch.randint(2, size=(sample_size, seq_length, 1)).float()\n# print(bits)","block_group":"6f8f3724d5df4cceb73c5f6dbc496102","execution_count":34,"outputs":[]},{"cell_type":"code","metadata":{"source_hash":"b4c8bf82","execution_start":1697652325411,"execution_millis":17,"deepnote_to_be_reexecuted":false,"cell_id":"25c09b015c4145d589f58507748dcb3a","deepnote_cell_type":"code"},"source":"class XORLSTM(nn.Module):\n def __init__(self, input_size, hidden_size, number_layers):\n super(XORLSTM, self).__init__()\n self.hidden_size = hidden_size\n self.number_layers = number_layers\n\n self.lstm = nn.LSTM(input_size, hidden_size, number_layers, batch_first=True)\n self.fc = nn.Linear(hidden_size, 1)\n self.activation = nn.Sigmoid()\n\n\n def forward(self, x, lengths=True):\n h0 = torch.zeros(self.number_layers, x.size(0), self.hidden_size).to(device)\n c0 = torch.zeros(self.number_layers, x.size(0), self.hidden_size).to(device)\n out_lstm, _ = self.lstm(x,(h0,c0))\n out = self.fc(out_lstm)\n predictions = self.activation(out)\n return predictions\n \n\n","block_group":"c4c0ea2414e04c9eac8f60531bb3887d","execution_count":35,"outputs":[]},{"cell_type":"code","metadata":{"source_hash":"12f24766","execution_start":1697652325892,"execution_millis":19,"deepnote_to_be_reexecuted":false,"cell_id":"376ff2eaaadd4cc7b2ab1aef07ed1c03","deepnote_cell_type":"code"},"source":"model = XORLSTM(InputSize, HiddenSize, NumberLayers).to(device)\ncriterion = nn.BCELoss()\noptimizer = torch.optim.Adam(model.parameters(), lr=LearningRate)","block_group":"b7e26ff7803648859fd772dd35a6cdbd","execution_count":36,"outputs":[]},{"cell_type":"code","metadata":{"source_hash":"9b12bb25","execution_start":1697652577767,"execution_millis":19,"deepnote_to_be_reexecuted":false,"cell_id":"099f01388f924532b6de3eeb4dbb9bc5","deepnote_cell_type":"code"},"source":"def train():\n model.train()\n train_loader = DataLoader(XOR(Training_Size,Bit_len,Variable_Len), batch_size=BatchSize)\n total_step = len(train_loader)\n\n print(\"Training...\\n\")\n print('-'*60)\n\n for epoch in range(1, Epochs+1):\n for step, (features, labels) in enumerate(train_loader):\n features, labels= features.to(device), labels.to(device)\n #forward pass\n output = model(features)\n loss = criterion(output, labels)\n\n #backward pass\n optimizer.zero_grad()\n loss.backward()\n optimizer.step()\n accuracy = ((output>0.5)==(labels>0.5)).type(torch.FloatTensor).mean()\n\n if (step+1) % 250==0:\n print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, Accuracy: {:.3f}' \n .format(epoch, Epochs, step+1, total_step, loss.item(), accuracy))\n print('-'*60)\n\n if abs(accuracy-1.0)<Threshold:\n print(\"Early Stopping\")\n return\n\n if (step+1)==total_step:\n valid_accuracy = validate(model)\n print(\"validation accuracy: {:.4f}\".format(valid_accuracy))\n print('-'*60)\n if abs(valid_accuracy - 1.0) < THRESHOLD:\n print(\"EARLY STOPPING\")\n return \n","block_group":"46d771d5b54e4f47a22b891a33b818f0","execution_count":52,"outputs":[]},{"cell_type":"code","metadata":{"source_hash":"d0af477c","execution_start":1697652578505,"execution_millis":69,"deepnote_to_be_reexecuted":false,"cell_id":"29015ebeb9e74f19911b15a2dda36128","deepnote_cell_type":"code"},"source":"def validate(model):\n valid_loader = DataLoader(\n XOR(Validation_Size, Bit_len, Variable_Len), \n batch_size=BatchSize\n )\n model.eval()\n correct = 0.\n total = 0.\n for features, labels in valid_loader:\n features, labels = features.to(device), labels.to(device)\n\n with torch.no_grad():\n outputs = model(features)\n total += labels.size(0)*labels.size(1)\n correct += ((outputs > 0.5) == (labels > 0.5)).sum().item()\n return correct / total","block_group":"dd15ffd149fd4a2e93d22963ceaca6e8","execution_count":53,"outputs":[]},{"cell_type":"code","metadata":{"source_hash":"661bb006","execution_start":1697652579092,"execution_millis":72772,"deepnote_to_be_reexecuted":false,"cell_id":"1253242104294cbb9c511fde630c2a2d","deepnote_cell_type":"code"},"source":"train()","block_group":"b1244948882b451bb82da78557a08e82","execution_count":54,"outputs":[{"name":"stdout","text":"Training...\n\n------------------------------------------------------------\nEpoch [1/8], Step [250/12500], Loss: 0.6901, Accuracy: 0.530\n------------------------------------------------------------\nEpoch [1/8], Step [500/12500], Loss: 0.6941, Accuracy: 0.355\n------------------------------------------------------------\nEpoch [1/8], Step [750/12500], Loss: 0.6986, Accuracy: 0.370\n------------------------------------------------------------\nEpoch [1/8], Step [1000/12500], Loss: 0.6681, Accuracy: 0.700\n------------------------------------------------------------\nEpoch [1/8], Step [1250/12500], Loss: 0.6953, Accuracy: 0.475\n------------------------------------------------------------\nEpoch [1/8], Step [1500/12500], Loss: 0.6885, Accuracy: 0.610\n------------------------------------------------------------\nEpoch [1/8], Step [1750/12500], Loss: 0.6905, Accuracy: 0.477\n------------------------------------------------------------\nEpoch [1/8], Step [2000/12500], Loss: 0.6889, Accuracy: 0.498\n------------------------------------------------------------\nEpoch [1/8], Step [2250/12500], Loss: 0.6775, Accuracy: 0.598\n------------------------------------------------------------\nEpoch [1/8], Step [2500/12500], Loss: 0.0805, Accuracy: 1.000\n------------------------------------------------------------\nEarly Stopping\n","output_type":"stream"}]},{"cell_type":"code","metadata":{"source_hash":"d991cab6","execution_start":1697652688587,"execution_millis":16,"deepnote_to_be_reexecuted":false,"cell_id":"786e76112b9a4ac4b3f714ff47c380cc","deepnote_cell_type":"code"},"source":"model","block_group":"8ee631856f6e44d8a282a4879082f24d","execution_count":55,"outputs":[{"output_type":"execute_result","execution_count":55,"data":{"text/plain":"XORLSTM(\n (lstm): LSTM(1, 2, batch_first=True)\n (fc): Linear(in_features=2, out_features=1, bias=True)\n (activation): Sigmoid()\n)"},"metadata":{}}]},{"cell_type":"code","metadata":{"source_hash":"7cb97d10","execution_start":1697652725564,"execution_millis":19,"deepnote_to_be_reexecuted":false,"cell_id":"01276f21ccb04d5098f8a214183bba76","deepnote_cell_type":"code"},"source":"model(XOR(1, 2).generate_data(1)[0]).size()","block_group":"70ec5c815ef5442ea6d3dc6acfc58ed3","execution_count":57,"outputs":[{"output_type":"execute_result","execution_count":57,"data":{"text/plain":"torch.Size([1, 50, 1])"},"metadata":{}}]},{"cell_type":"code","metadata":{"cell_id":"f02ac13a4edc44efa2ea8ce00dc3935e","deepnote_cell_type":"code"},"source":"","block_group":"42601e922f7642608ca954360c8dfada","execution_count":null,"outputs":[]},{"cell_type":"markdown","source":"<a style='text-decoration:none;line-height:16px;display:flex;color:#5B5B62;padding:10px;justify-content:end;' href='https://deepnote.com?utm_source=created-in-deepnote-cell&projectId=17d2748e-c8ac-4922-b92b-a46154c07520' target=\"_blank\">\n<img alt='Created in deepnote.com' style='display:inline;max-height:16px;margin:0px;margin-right:7.5px;' src='data:image/svg+xml;base64,PD94bWwgdmVyc2lvbj0iMS4wIiBlbmNvZGluZz0iVVRGLTgiPz4KPHN2ZyB3aWR0aD0iODBweCIgaGVpZ2h0PSI4MHB4IiB2aWV3Qm94PSIwIDAgODAgODAiIHZlcnNpb249IjEuMSIgeG1sbnM9Imh0dHA6Ly93d3cudzMub3JnLzIwMDAvc3ZnIiB4bWxuczp4bGluaz0iaHR0cDovL3d3dy53My5vcmcvMTk5OS94bGluayI+CiAgICA8IS0tIEdlbmVyYXRvcjogU2tldGNoIDU0LjEgKDc2NDkwKSAtIGh0dHBzOi8vc2tldGNoYXBwLmNvbSAtLT4KICAgIDx0aXRsZT5Hcm91cCAzPC90aXRsZT4KICAgIDxkZXNjPkNyZWF0ZWQgd2l0aCBTa2V0Y2guPC9kZXNjPgogICAgPGcgaWQ9IkxhbmRpbmciIHN0cm9rZT0ibm9uZSIgc3Ryb2tlLXdpZHRoPSIxIiBmaWxsPSJub25lIiBmaWxsLXJ1bGU9ImV2ZW5vZGQiPgogICAgICAgIDxnIGlkPSJBcnRib2FyZCIgdHJhbnNmb3JtPSJ0cmFuc2xhdGUoLTEyMzUuMDAwMDAwLCAtNzkuMDAwMDAwKSI+CiAgICAgICAgICAgIDxnIGlkPSJHcm91cC0zIiB0cmFuc2Zvcm09InRyYW5zbGF0ZSgxMjM1LjAwMDAwMCwgNzkuMDAwMDAwKSI+CiAgICAgICAgICAgICAgICA8cG9seWdvbiBpZD0iUGF0aC0yMCIgZmlsbD0iIzAyNjVCNCIgcG9pbnRzPSIyLjM3NjIzNzYyIDgwIDM4LjA0NzY2NjcgODAgNTcuODIxNzgyMiA3My44MDU3NTkyIDU3LjgyMTc4MjIgMzIuNzU5MjczOSAzOS4xNDAyMjc4IDMxLjY4MzE2ODMiPjwvcG9seWdvbj4KICAgICAgICAgICAgICAgIDxwYXRoIGQ9Ik0zNS4wMDc3MTgsODAgQzQyLjkwNjIwMDcsNzYuNDU0OTM1OCA0Ny41NjQ5MTY3LDcxLjU0MjI2NzEgNDguOTgzODY2LDY1LjI2MTk5MzkgQzUxLjExMjI4OTksNTUuODQxNTg0MiA0MS42NzcxNzk1LDQ5LjIxMjIyODQgMjUuNjIzOTg0Niw0OS4yMTIyMjg0IEMyNS40ODQ5Mjg5LDQ5LjEyNjg0NDggMjkuODI2MTI5Niw0My4yODM4MjQ4IDM4LjY0NzU4NjksMzEuNjgzMTY4MyBMNzIuODcxMjg3MSwzMi41NTQ0MjUgTDY1LjI4MDk3Myw2Ny42NzYzNDIxIEw1MS4xMTIyODk5LDc3LjM3NjE0NCBMMzUuMDA3NzE4LDgwIFoiIGlkPSJQYXRoLTIyIiBmaWxsPSIjMDAyODY4Ij48L3BhdGg+CiAgICAgICAgICAgICAgICA8cGF0aCBkPSJNMCwzNy43MzA0NDA1IEwyNy4xMTQ1MzcsMC4yNTcxMTE0MzYgQzYyLjM3MTUxMjMsLTEuOTkwNzE3MDEgODAsMTAuNTAwMzkyNyA4MCwzNy43MzA0NDA1IEM4MCw2NC45NjA0ODgyIDY0Ljc3NjUwMzgsNzkuMDUwMzQxNCAzNC4zMjk1MTEzLDgwIEM0Ny4wNTUzNDg5LDc3LjU2NzA4MDggNTMuNDE4MjY3Nyw3MC4zMTM2MTAzIDUzLjQxODI2NzcsNTguMjM5NTg4NSBDNTMuNDE4MjY3Nyw0MC4xMjg1NTU3IDM2LjMwMzk1NDQsMzcuNzMwNDQwNSAyNS4yMjc0MTcsMzcuNzMwNDQwNSBDMTcuODQzMDU4NiwzNy43MzA0NDA1IDkuNDMzOTE5NjYsMzcuNzMwNDQwNSAwLDM3LjczMDQ0MDUgWiIgaWQ9IlBhdGgtMTkiIGZpbGw9IiMzNzkzRUYiPjwvcGF0aD4KICAgICAgICAgICAgPC9nPgogICAgICAgIDwvZz4KICAgIDwvZz4KPC9zdmc+' > </img>\nCreated in <span style='font-weight:600;margin-left:4px;'>Deepnote</span></a>","metadata":{"created_in_deepnote_cell":true,"deepnote_cell_type":"markdown"}}],"nbformat":4,"nbformat_minor":0,"metadata":{"deepnote":{},"orig_nbformat":2,"deepnote_notebook_id":"957704fd8f424d52b8e1ca71dd4e52bd","deepnote_execution_queue":[]}}