Xsmos committed on
Commit
f7a0599
·
verified ·
1 Parent(s): 56ca702
Files changed (2) hide show
  1. diffusion.py +2 -2
  2. quantify_results.ipynb +16 -101
diffusion.py CHANGED
@@ -560,7 +560,7 @@ class DDPM21CM:
560
  def save(self, ep):
561
  # save model
562
  # if self.accelerator.is_main_process:
563
- if torch.cuda.current_device() == 0:
564
  if ep == self.config.n_epoch-1 or (ep+1) % self.config.save_period == 0:
565
  self.nn_model.eval()
566
  with torch.no_grad():
@@ -674,7 +674,7 @@ def train(rank, world_size, local_world_size, master_addr, master_port):
674
  config = TrainConfig()
675
  config.device = f"cuda:{rank}"
676
  config.world_size = local_world_size
677
-
678
  #[3200]#[200]#[1600,3200,6400,12800,25600]
679
  #for i, num_image in enumerate(num_train_image_list):
680
  #config.num_image = num_image
 
560
  def save(self, ep):
561
  # save model
562
  # if self.accelerator.is_main_process:
563
+ if self.config.global_rank == 0:# and torch.cuda.current_device() == 0:
564
  if ep == self.config.n_epoch-1 or (ep+1) % self.config.save_period == 0:
565
  self.nn_model.eval()
566
  with torch.no_grad():
 
674
  config = TrainConfig()
675
  config.device = f"cuda:{rank}"
676
  config.world_size = local_world_size
677
+ config.global_rank = global_rank
678
  #[3200]#[200]#[1600,3200,6400,12800,25600]
679
  #for i, num_image in enumerate(num_train_image_list):
680
  #config.num_image = num_image
quantify_results.ipynb CHANGED
@@ -76,121 +76,36 @@
76
  },
77
  {
78
  "cell_type": "code",
79
- "execution_count": 4,
80
  "metadata": {},
81
  "outputs": [
82
  {
83
  "name": "stdout",
84
  "output_type": "stream",
85
  "text": [
86
- "total 1187288\n",
87
- "drwxr-xr-x 163 bxia34 12288 Jul 28 17:27 \u001b[0m\u001b[01;34mlogs\u001b[0m/\n",
88
- "-rw-r--r-- 1 bxia34 607850537 Jul 28 17:28 model_state-N30-device_count3-epoch4-172.27.145.66\n",
89
- "-rw-r--r-- 1 bxia34 607850537 Jul 28 17:28 model_state-N30-device_count3-epoch4-172.27.145.67\n"
90
  ]
91
  }
92
  ],
93
  "source": [
94
- "ll outputs"
95
- ]
96
- },
97
- {
98
- "cell_type": "code",
99
- "execution_count": 14,
100
- "metadata": {},
101
- "outputs": [],
102
- "source": [
103
  "model0 = torch.load(\"outputs/model_state-N30-device_count3-epoch4-172.27.145.66\")\n",
104
- "model1 = torch.load(\"outputs/model_state-N30-device_count3-epoch4-172.27.145.67\")"
105
- ]
106
- },
107
- {
108
- "cell_type": "code",
109
- "execution_count": 15,
110
- "metadata": {},
111
- "outputs": [
112
- {
113
- "data": {
114
- "text/plain": [
115
- "dict_keys(['epoch', 'unet_state_dict'])"
116
- ]
117
- },
118
- "execution_count": 15,
119
- "metadata": {},
120
- "output_type": "execute_result"
121
- }
122
- ],
123
- "source": [
124
- "model0.keys()"
125
- ]
126
- },
127
- {
128
- "cell_type": "code",
129
- "execution_count": 16,
130
- "metadata": {},
131
- "outputs": [
132
- {
133
- "data": {
134
- "text/plain": [
135
- "4"
136
- ]
137
- },
138
- "execution_count": 16,
139
- "metadata": {},
140
- "output_type": "execute_result"
141
- }
142
- ],
143
- "source": [
144
- "model0['epoch']"
145
- ]
146
- },
147
- {
148
- "cell_type": "code",
149
- "execution_count": 17,
150
- "metadata": {},
151
- "outputs": [],
152
- "source": [
153
  "state0 = model0['unet_state_dict']\n",
154
- "state1 = model1['unet_state_dict']"
155
- ]
156
- },
157
- {
158
- "cell_type": "code",
159
- "execution_count": 18,
160
- "metadata": {},
161
- "outputs": [
162
- {
163
- "data": {
164
- "text/plain": [
165
- "odict_keys(['token_embedding.weight', 'token_embedding.bias', 'time_embed.0.weight', 'time_embed.0.bias', 'time_embed.2.weight', 'time_embed.2.bias', 'input_blocks.0.0.weight', 'input_blocks.0.0.bias', 'input_blocks.1.0.in_layers.0.weight', 'input_blocks.1.0.in_layers.0.bias', 'input_blocks.1.0.in_layers.2.weight', 'input_blocks.1.0.in_layers.2.bias', 'input_blocks.1.0.emb_layers.1.weight', 'input_blocks.1.0.emb_layers.1.bias', 'input_blocks.1.0.out_layers.0.weight', 'input_blocks.1.0.out_layers.0.bias', 'input_blocks.1.0.out_layers.3.weight', 'input_blocks.1.0.out_layers.3.bias', 'input_blocks.2.0.in_layers.0.weight', 'input_blocks.2.0.in_layers.0.bias', 'input_blocks.2.0.in_layers.2.weight', 'input_blocks.2.0.in_layers.2.bias', 'input_blocks.2.0.emb_layers.1.weight', 'input_blocks.2.0.emb_layers.1.bias', 'input_blocks.2.0.out_layers.0.weight', 'input_blocks.2.0.out_layers.0.bias', 'input_blocks.2.0.out_layers.3.weight', 'input_blocks.2.0.out_layers.3.bias', 'input_blocks.3.0.op.weight', 'input_blocks.3.0.op.bias', 'input_blocks.4.0.in_layers.0.weight', 'input_blocks.4.0.in_layers.0.bias', 'input_blocks.4.0.in_layers.2.weight', 'input_blocks.4.0.in_layers.2.bias', 'input_blocks.4.0.emb_layers.1.weight', 'input_blocks.4.0.emb_layers.1.bias', 'input_blocks.4.0.out_layers.0.weight', 'input_blocks.4.0.out_layers.0.bias', 'input_blocks.4.0.out_layers.3.weight', 'input_blocks.4.0.out_layers.3.bias', 'input_blocks.5.0.in_layers.0.weight', 'input_blocks.5.0.in_layers.0.bias', 'input_blocks.5.0.in_layers.2.weight', 'input_blocks.5.0.in_layers.2.bias', 'input_blocks.5.0.emb_layers.1.weight', 'input_blocks.5.0.emb_layers.1.bias', 'input_blocks.5.0.out_layers.0.weight', 'input_blocks.5.0.out_layers.0.bias', 'input_blocks.5.0.out_layers.3.weight', 'input_blocks.5.0.out_layers.3.bias', 'input_blocks.6.0.op.weight', 'input_blocks.6.0.op.bias', 'input_blocks.7.0.in_layers.0.weight', 'input_blocks.7.0.in_layers.0.bias', 'input_blocks.7.0.in_layers.2.weight', 
'input_blocks.7.0.in_layers.2.bias', 'input_blocks.7.0.emb_layers.1.weight', 'input_blocks.7.0.emb_layers.1.bias', 'input_blocks.7.0.out_layers.0.weight', 'input_blocks.7.0.out_layers.0.bias', 'input_blocks.7.0.out_layers.3.weight', 'input_blocks.7.0.out_layers.3.bias', 'input_blocks.7.0.skip_connection.weight', 'input_blocks.7.0.skip_connection.bias', 'input_blocks.7.1.norm.weight', 'input_blocks.7.1.norm.bias', 'input_blocks.7.1.qkv.weight', 'input_blocks.7.1.qkv.bias', 'input_blocks.7.1.proj_out.weight', 'input_blocks.7.1.proj_out.bias', 'input_blocks.8.0.in_layers.0.weight', 'input_blocks.8.0.in_layers.0.bias', 'input_blocks.8.0.in_layers.2.weight', 'input_blocks.8.0.in_layers.2.bias', 'input_blocks.8.0.emb_layers.1.weight', 'input_blocks.8.0.emb_layers.1.bias', 'input_blocks.8.0.out_layers.0.weight', 'input_blocks.8.0.out_layers.0.bias', 'input_blocks.8.0.out_layers.3.weight', 'input_blocks.8.0.out_layers.3.bias', 'input_blocks.8.1.norm.weight', 'input_blocks.8.1.norm.bias', 'input_blocks.8.1.qkv.weight', 'input_blocks.8.1.qkv.bias', 'input_blocks.8.1.proj_out.weight', 'input_blocks.8.1.proj_out.bias', 'input_blocks.9.0.op.weight', 'input_blocks.9.0.op.bias', 'input_blocks.10.0.in_layers.0.weight', 'input_blocks.10.0.in_layers.0.bias', 'input_blocks.10.0.in_layers.2.weight', 'input_blocks.10.0.in_layers.2.bias', 'input_blocks.10.0.emb_layers.1.weight', 'input_blocks.10.0.emb_layers.1.bias', 'input_blocks.10.0.out_layers.0.weight', 'input_blocks.10.0.out_layers.0.bias', 'input_blocks.10.0.out_layers.3.weight', 'input_blocks.10.0.out_layers.3.bias', 'input_blocks.10.1.norm.weight', 'input_blocks.10.1.norm.bias', 'input_blocks.10.1.qkv.weight', 'input_blocks.10.1.qkv.bias', 'input_blocks.10.1.proj_out.weight', 'input_blocks.10.1.proj_out.bias', 'input_blocks.11.0.in_layers.0.weight', 'input_blocks.11.0.in_layers.0.bias', 'input_blocks.11.0.in_layers.2.weight', 'input_blocks.11.0.in_layers.2.bias', 'input_blocks.11.0.emb_layers.1.weight', 
'input_blocks.11.0.emb_layers.1.bias', 'input_blocks.11.0.out_layers.0.weight', 'input_blocks.11.0.out_layers.0.bias', 'input_blocks.11.0.out_layers.3.weight', 'input_blocks.11.0.out_layers.3.bias', 'input_blocks.11.1.norm.weight', 'input_blocks.11.1.norm.bias', 'input_blocks.11.1.qkv.weight', 'input_blocks.11.1.qkv.bias', 'input_blocks.11.1.proj_out.weight', 'input_blocks.11.1.proj_out.bias', 'input_blocks.12.0.op.weight', 'input_blocks.12.0.op.bias', 'input_blocks.13.0.in_layers.0.weight', 'input_blocks.13.0.in_layers.0.bias', 'input_blocks.13.0.in_layers.2.weight', 'input_blocks.13.0.in_layers.2.bias', 'input_blocks.13.0.emb_layers.1.weight', 'input_blocks.13.0.emb_layers.1.bias', 'input_blocks.13.0.out_layers.0.weight', 'input_blocks.13.0.out_layers.0.bias', 'input_blocks.13.0.out_layers.3.weight', 'input_blocks.13.0.out_layers.3.bias', 'input_blocks.14.0.in_layers.0.weight', 'input_blocks.14.0.in_layers.0.bias', 'input_blocks.14.0.in_layers.2.weight', 'input_blocks.14.0.in_layers.2.bias', 'input_blocks.14.0.emb_layers.1.weight', 'input_blocks.14.0.emb_layers.1.bias', 'input_blocks.14.0.out_layers.0.weight', 'input_blocks.14.0.out_layers.0.bias', 'input_blocks.14.0.out_layers.3.weight', 'input_blocks.14.0.out_layers.3.bias', 'middle_block.0.in_layers.0.weight', 'middle_block.0.in_layers.0.bias', 'middle_block.0.in_layers.2.weight', 'middle_block.0.in_layers.2.bias', 'middle_block.0.emb_layers.1.weight', 'middle_block.0.emb_layers.1.bias', 'middle_block.0.out_layers.0.weight', 'middle_block.0.out_layers.0.bias', 'middle_block.0.out_layers.3.weight', 'middle_block.0.out_layers.3.bias', 'middle_block.1.norm.weight', 'middle_block.1.norm.bias', 'middle_block.1.qkv.weight', 'middle_block.1.qkv.bias', 'middle_block.1.proj_out.weight', 'middle_block.1.proj_out.bias', 'middle_block.2.in_layers.0.weight', 'middle_block.2.in_layers.0.bias', 'middle_block.2.in_layers.2.weight', 'middle_block.2.in_layers.2.bias', 'middle_block.2.emb_layers.1.weight', 
'middle_block.2.emb_layers.1.bias', 'middle_block.2.out_layers.0.weight', 'middle_block.2.out_layers.0.bias', 'middle_block.2.out_layers.3.weight', 'middle_block.2.out_layers.3.bias', 'output_blocks.0.0.in_layers.0.weight', 'output_blocks.0.0.in_layers.0.bias', 'output_blocks.0.0.in_layers.2.weight', 'output_blocks.0.0.in_layers.2.bias', 'output_blocks.0.0.emb_layers.1.weight', 'output_blocks.0.0.emb_layers.1.bias', 'output_blocks.0.0.out_layers.0.weight', 'output_blocks.0.0.out_layers.0.bias', 'output_blocks.0.0.out_layers.3.weight', 'output_blocks.0.0.out_layers.3.bias', 'output_blocks.0.0.skip_connection.weight', 'output_blocks.0.0.skip_connection.bias', 'output_blocks.1.0.in_layers.0.weight', 'output_blocks.1.0.in_layers.0.bias', 'output_blocks.1.0.in_layers.2.weight', 'output_blocks.1.0.in_layers.2.bias', 'output_blocks.1.0.emb_layers.1.weight', 'output_blocks.1.0.emb_layers.1.bias', 'output_blocks.1.0.out_layers.0.weight', 'output_blocks.1.0.out_layers.0.bias', 'output_blocks.1.0.out_layers.3.weight', 'output_blocks.1.0.out_layers.3.bias', 'output_blocks.1.0.skip_connection.weight', 'output_blocks.1.0.skip_connection.bias', 'output_blocks.2.0.in_layers.0.weight', 'output_blocks.2.0.in_layers.0.bias', 'output_blocks.2.0.in_layers.2.weight', 'output_blocks.2.0.in_layers.2.bias', 'output_blocks.2.0.emb_layers.1.weight', 'output_blocks.2.0.emb_layers.1.bias', 'output_blocks.2.0.out_layers.0.weight', 'output_blocks.2.0.out_layers.0.bias', 'output_blocks.2.0.out_layers.3.weight', 'output_blocks.2.0.out_layers.3.bias', 'output_blocks.2.0.skip_connection.weight', 'output_blocks.2.0.skip_connection.bias', 'output_blocks.2.1.conv.weight', 'output_blocks.2.1.conv.bias', 'output_blocks.3.0.in_layers.0.weight', 'output_blocks.3.0.in_layers.0.bias', 'output_blocks.3.0.in_layers.2.weight', 'output_blocks.3.0.in_layers.2.bias', 'output_blocks.3.0.emb_layers.1.weight', 'output_blocks.3.0.emb_layers.1.bias', 'output_blocks.3.0.out_layers.0.weight', 
'output_blocks.3.0.out_layers.0.bias', 'output_blocks.3.0.out_layers.3.weight', 'output_blocks.3.0.out_layers.3.bias', 'output_blocks.3.0.skip_connection.weight', 'output_blocks.3.0.skip_connection.bias', 'output_blocks.3.1.norm.weight', 'output_blocks.3.1.norm.bias', 'output_blocks.3.1.qkv.weight', 'output_blocks.3.1.qkv.bias', 'output_blocks.3.1.proj_out.weight', 'output_blocks.3.1.proj_out.bias', 'output_blocks.4.0.in_layers.0.weight', 'output_blocks.4.0.in_layers.0.bias', 'output_blocks.4.0.in_layers.2.weight', 'output_blocks.4.0.in_layers.2.bias', 'output_blocks.4.0.emb_layers.1.weight', 'output_blocks.4.0.emb_layers.1.bias', 'output_blocks.4.0.out_layers.0.weight', 'output_blocks.4.0.out_layers.0.bias', 'output_blocks.4.0.out_layers.3.weight', 'output_blocks.4.0.out_layers.3.bias', 'output_blocks.4.0.skip_connection.weight', 'output_blocks.4.0.skip_connection.bias', 'output_blocks.4.1.norm.weight', 'output_blocks.4.1.norm.bias', 'output_blocks.4.1.qkv.weight', 'output_blocks.4.1.qkv.bias', 'output_blocks.4.1.proj_out.weight', 'output_blocks.4.1.proj_out.bias', 'output_blocks.5.0.in_layers.0.weight', 'output_blocks.5.0.in_layers.0.bias', 'output_blocks.5.0.in_layers.2.weight', 'output_blocks.5.0.in_layers.2.bias', 'output_blocks.5.0.emb_layers.1.weight', 'output_blocks.5.0.emb_layers.1.bias', 'output_blocks.5.0.out_layers.0.weight', 'output_blocks.5.0.out_layers.0.bias', 'output_blocks.5.0.out_layers.3.weight', 'output_blocks.5.0.out_layers.3.bias', 'output_blocks.5.0.skip_connection.weight', 'output_blocks.5.0.skip_connection.bias', 'output_blocks.5.1.norm.weight', 'output_blocks.5.1.norm.bias', 'output_blocks.5.1.qkv.weight', 'output_blocks.5.1.qkv.bias', 'output_blocks.5.1.proj_out.weight', 'output_blocks.5.1.proj_out.bias', 'output_blocks.5.2.conv.weight', 'output_blocks.5.2.conv.bias', 'output_blocks.6.0.in_layers.0.weight', 'output_blocks.6.0.in_layers.0.bias', 'output_blocks.6.0.in_layers.2.weight', 'output_blocks.6.0.in_layers.2.bias', 
'output_blocks.6.0.emb_layers.1.weight', 'output_blocks.6.0.emb_layers.1.bias', 'output_blocks.6.0.out_layers.0.weight', 'output_blocks.6.0.out_layers.0.bias', 'output_blocks.6.0.out_layers.3.weight', 'output_blocks.6.0.out_layers.3.bias', 'output_blocks.6.0.skip_connection.weight', 'output_blocks.6.0.skip_connection.bias', 'output_blocks.6.1.norm.weight', 'output_blocks.6.1.norm.bias', 'output_blocks.6.1.qkv.weight', 'output_blocks.6.1.qkv.bias', 'output_blocks.6.1.proj_out.weight', 'output_blocks.6.1.proj_out.bias', 'output_blocks.7.0.in_layers.0.weight', 'output_blocks.7.0.in_layers.0.bias', 'output_blocks.7.0.in_layers.2.weight', 'output_blocks.7.0.in_layers.2.bias', 'output_blocks.7.0.emb_layers.1.weight', 'output_blocks.7.0.emb_layers.1.bias', 'output_blocks.7.0.out_layers.0.weight', 'output_blocks.7.0.out_layers.0.bias', 'output_blocks.7.0.out_layers.3.weight', 'output_blocks.7.0.out_layers.3.bias', 'output_blocks.7.0.skip_connection.weight', 'output_blocks.7.0.skip_connection.bias', 'output_blocks.7.1.norm.weight', 'output_blocks.7.1.norm.bias', 'output_blocks.7.1.qkv.weight', 'output_blocks.7.1.qkv.bias', 'output_blocks.7.1.proj_out.weight', 'output_blocks.7.1.proj_out.bias', 'output_blocks.8.0.in_layers.0.weight', 'output_blocks.8.0.in_layers.0.bias', 'output_blocks.8.0.in_layers.2.weight', 'output_blocks.8.0.in_layers.2.bias', 'output_blocks.8.0.emb_layers.1.weight', 'output_blocks.8.0.emb_layers.1.bias', 'output_blocks.8.0.out_layers.0.weight', 'output_blocks.8.0.out_layers.0.bias', 'output_blocks.8.0.out_layers.3.weight', 'output_blocks.8.0.out_layers.3.bias', 'output_blocks.8.0.skip_connection.weight', 'output_blocks.8.0.skip_connection.bias', 'output_blocks.8.1.norm.weight', 'output_blocks.8.1.norm.bias', 'output_blocks.8.1.qkv.weight', 'output_blocks.8.1.qkv.bias', 'output_blocks.8.1.proj_out.weight', 'output_blocks.8.1.proj_out.bias', 'output_blocks.8.2.conv.weight', 'output_blocks.8.2.conv.bias', 'output_blocks.9.0.in_layers.0.weight', 
'output_blocks.9.0.in_layers.0.bias', 'output_blocks.9.0.in_layers.2.weight', 'output_blocks.9.0.in_layers.2.bias', 'output_blocks.9.0.emb_layers.1.weight', 'output_blocks.9.0.emb_layers.1.bias', 'output_blocks.9.0.out_layers.0.weight', 'output_blocks.9.0.out_layers.0.bias', 'output_blocks.9.0.out_layers.3.weight', 'output_blocks.9.0.out_layers.3.bias', 'output_blocks.9.0.skip_connection.weight', 'output_blocks.9.0.skip_connection.bias', 'output_blocks.10.0.in_layers.0.weight', 'output_blocks.10.0.in_layers.0.bias', 'output_blocks.10.0.in_layers.2.weight', 'output_blocks.10.0.in_layers.2.bias', 'output_blocks.10.0.emb_layers.1.weight', 'output_blocks.10.0.emb_layers.1.bias', 'output_blocks.10.0.out_layers.0.weight', 'output_blocks.10.0.out_layers.0.bias', 'output_blocks.10.0.out_layers.3.weight', 'output_blocks.10.0.out_layers.3.bias', 'output_blocks.10.0.skip_connection.weight', 'output_blocks.10.0.skip_connection.bias', 'output_blocks.11.0.in_layers.0.weight', 'output_blocks.11.0.in_layers.0.bias', 'output_blocks.11.0.in_layers.2.weight', 'output_blocks.11.0.in_layers.2.bias', 'output_blocks.11.0.emb_layers.1.weight', 'output_blocks.11.0.emb_layers.1.bias', 'output_blocks.11.0.out_layers.0.weight', 'output_blocks.11.0.out_layers.0.bias', 'output_blocks.11.0.out_layers.3.weight', 'output_blocks.11.0.out_layers.3.bias', 'output_blocks.11.0.skip_connection.weight', 'output_blocks.11.0.skip_connection.bias', 'output_blocks.11.1.conv.weight', 'output_blocks.11.1.conv.bias', 'output_blocks.12.0.in_layers.0.weight', 'output_blocks.12.0.in_layers.0.bias', 'output_blocks.12.0.in_layers.2.weight', 'output_blocks.12.0.in_layers.2.bias', 'output_blocks.12.0.emb_layers.1.weight', 'output_blocks.12.0.emb_layers.1.bias', 'output_blocks.12.0.out_layers.0.weight', 'output_blocks.12.0.out_layers.0.bias', 'output_blocks.12.0.out_layers.3.weight', 'output_blocks.12.0.out_layers.3.bias', 'output_blocks.12.0.skip_connection.weight', 'output_blocks.12.0.skip_connection.bias', 
'output_blocks.13.0.in_layers.0.weight', 'output_blocks.13.0.in_layers.0.bias', 'output_blocks.13.0.in_layers.2.weight', 'output_blocks.13.0.in_layers.2.bias', 'output_blocks.13.0.emb_layers.1.weight', 'output_blocks.13.0.emb_layers.1.bias', 'output_blocks.13.0.out_layers.0.weight', 'output_blocks.13.0.out_layers.0.bias', 'output_blocks.13.0.out_layers.3.weight', 'output_blocks.13.0.out_layers.3.bias', 'output_blocks.13.0.skip_connection.weight', 'output_blocks.13.0.skip_connection.bias', 'output_blocks.14.0.in_layers.0.weight', 'output_blocks.14.0.in_layers.0.bias', 'output_blocks.14.0.in_layers.2.weight', 'output_blocks.14.0.in_layers.2.bias', 'output_blocks.14.0.emb_layers.1.weight', 'output_blocks.14.0.emb_layers.1.bias', 'output_blocks.14.0.out_layers.0.weight', 'output_blocks.14.0.out_layers.0.bias', 'output_blocks.14.0.out_layers.3.weight', 'output_blocks.14.0.out_layers.3.bias', 'output_blocks.14.0.skip_connection.weight', 'output_blocks.14.0.skip_connection.bias', 'out.0.weight', 'out.0.bias', 'out.2.weight', 'out.2.bias'])"
166
- ]
167
- },
168
- "execution_count": 18,
169
- "metadata": {},
170
- "output_type": "execute_result"
171
- }
172
- ],
173
- "source": [
174
- "state0.keys()"
175
- ]
176
- },
177
- {
178
- "cell_type": "code",
179
- "execution_count": 19,
180
- "metadata": {},
181
- "outputs": [
182
- {
183
- "name": "stdout",
184
- "output_type": "stream",
185
- "text": [
186
- "end\n"
187
- ]
188
- }
189
- ],
190
- "source": [
191
  "for key in state0.keys():\n",
192
  " # print(key)\n",
193
- " if not torch.equal(state1[key],state0[key]):\n",
194
  " print(key, \"different\")\n",
195
  " # break\n",
196
  " # else:\n",
 
76
  },
77
  {
78
  "cell_type": "code",
79
+ "execution_count": 20,
80
  "metadata": {},
81
  "outputs": [
82
  {
83
  "name": "stdout",
84
  "output_type": "stream",
85
  "text": [
86
+ "total 1187284\n",
87
+ "drwxr-xr-x 164 bxia34 12288 Jul 28 17:34 \u001b[0m\u001b[01;34mlogs\u001b[0m/\n",
88
+ "-rw-r--r-- 1 bxia34 607850537 Jul 28 17:35 model_state-N30-device_count3-epoch4-172.27.145.66\n",
89
+ "-rw-r--r-- 1 bxia34 607850537 Jul 28 17:35 model_state-N30-device_count3-epoch4-172.27.145.67\n"
90
  ]
91
  }
92
  ],
93
  "source": [
94
+ "ll outputs\n",
 
 
 
 
 
 
 
 
95
  "model0 = torch.load(\"outputs/model_state-N30-device_count3-epoch4-172.27.145.66\")\n",
96
+ "model1 = torch.load(\"outputs/model_state-N30-device_count3-epoch4-172.27.145.67\")\n",
97
+ "model00 = torch.load(\"outputs/model_state-N30-device_count3-epoch4-172.27.145.66\")\n",
98
+ "model11 = torch.load(\"outputs/model_state-N30-device_count3-epoch4-172.27.145.67\")\n",
99
+ "model0.keys()\n",
100
+ "model0['epoch']\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
  "state0 = model0['unet_state_dict']\n",
102
+ "state1 = model1['unet_state_dict']\n",
103
+ "state00 = model00['unet_state_dict']\n",
104
+ "state11 = model11['unet_state_dict']\n",
105
+ "state0.keys()\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
  "for key in state0.keys():\n",
107
  " # print(key)\n",
108
+ " if not torch.equal(state00[key],state11[key]):\n",
109
  " print(key, \"different\")\n",
110
  " # break\n",
111
  " # else:\n",