Spaces:
Runtime error
Runtime error
style: reformat
Browse files
tools/inference/inference_pipeline.ipynb
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|
tools/inference/log_inference_samples.ipynb
CHANGED
|
@@ -31,11 +31,14 @@
|
|
| 31 |
"metadata": {},
|
| 32 |
"outputs": [],
|
| 33 |
"source": [
|
| 34 |
-
"run_ids = [
|
| 35 |
-
"ENTITY, PROJECT =
|
| 36 |
-
"VQGAN_REPO, VQGAN_COMMIT_ID =
|
| 37 |
-
"
|
| 38 |
-
"
|
|
|
|
|
|
|
|
|
|
| 39 |
"add_clip_32 = False"
|
| 40 |
]
|
| 41 |
},
|
|
@@ -63,8 +66,8 @@
|
|
| 63 |
"num_images = 128\n",
|
| 64 |
"top_k = 8\n",
|
| 65 |
"text_normalizer = TextNormalizer()\n",
|
| 66 |
-
"padding_item =
|
| 67 |
-
"seed = random.randint(0, 2**32-1)\n",
|
| 68 |
"key = jax.random.PRNGKey(seed)\n",
|
| 69 |
"api = wandb.Api()"
|
| 70 |
]
|
|
@@ -100,12 +103,15 @@
|
|
| 100 |
"def p_decode(indices, params):\n",
|
| 101 |
" return vqgan.decode_code(indices, params=params)\n",
|
| 102 |
"\n",
|
|
|
|
| 103 |
"@partial(jax.pmap, axis_name=\"batch\")\n",
|
| 104 |
"def p_clip16(inputs, params):\n",
|
| 105 |
" logits = clip16(params=params, **inputs).logits_per_image\n",
|
| 106 |
" return logits\n",
|
| 107 |
"\n",
|
|
|
|
| 108 |
"if add_clip_32:\n",
|
|
|
|
| 109 |
" @partial(jax.pmap, axis_name=\"batch\")\n",
|
| 110 |
" def p_clip32(inputs, params):\n",
|
| 111 |
" logits = clip32(params=params, **inputs).logits_per_image\n",
|
|
@@ -119,13 +125,13 @@
|
|
| 119 |
"metadata": {},
|
| 120 |
"outputs": [],
|
| 121 |
"source": [
|
| 122 |
-
"with open(
|
| 123 |
" samples = [l.strip() for l in f.readlines()]\n",
|
| 124 |
" # make list multiple of batch_size by adding elements\n",
|
| 125 |
" samples_to_add = [padding_item] * (-len(samples) % batch_size)\n",
|
| 126 |
" samples.extend(samples_to_add)\n",
|
| 127 |
" # reshape\n",
|
| 128 |
-
" samples = [samples[i:i+batch_size] for i in range(0, len(samples), batch_size)]"
|
| 129 |
]
|
| 130 |
},
|
| 131 |
{
|
|
@@ -138,9 +144,17 @@
|
|
| 138 |
"def get_artifact_versions(run_id, latest_only=False):\n",
|
| 139 |
" try:\n",
|
| 140 |
" if latest_only:\n",
|
| 141 |
-
" return [
|
|
|
|
|
|
|
|
|
|
|
|
|
| 142 |
" else:\n",
|
| 143 |
-
" return api.artifact_versions(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 144 |
" except:\n",
|
| 145 |
" return []"
|
| 146 |
]
|
|
@@ -153,7 +167,7 @@
|
|
| 153 |
"outputs": [],
|
| 154 |
"source": [
|
| 155 |
"def get_training_config(run_id):\n",
|
| 156 |
-
" training_run = api.run(f
|
| 157 |
" config = training_run.config\n",
|
| 158 |
" return config"
|
| 159 |
]
|
|
@@ -168,8 +182,8 @@
|
|
| 168 |
"# retrieve inference run details\n",
|
| 169 |
"def get_last_inference_version(run_id):\n",
|
| 170 |
" try:\n",
|
| 171 |
-
" inference_run = api.run(f
|
| 172 |
-
" return inference_run.summary.get(
|
| 173 |
" except:\n",
|
| 174 |
" return None"
|
| 175 |
]
|
|
@@ -183,7 +197,6 @@
|
|
| 183 |
"source": [
|
| 184 |
"# compile functions - needed only once per run\n",
|
| 185 |
"def pmap_model_function(model):\n",
|
| 186 |
-
" \n",
|
| 187 |
" @partial(jax.pmap, axis_name=\"batch\")\n",
|
| 188 |
" def _generate(tokenized_prompt, key, params):\n",
|
| 189 |
" return model.generate(\n",
|
|
@@ -195,7 +208,7 @@
|
|
| 195 |
" top_k=gen_top_k,\n",
|
| 196 |
" top_p=gen_top_p\n",
|
| 197 |
" )\n",
|
| 198 |
-
"
|
| 199 |
" return _generate"
|
| 200 |
]
|
| 201 |
},
|
|
@@ -222,13 +235,21 @@
|
|
| 222 |
"training_config = get_training_config(run_id)\n",
|
| 223 |
"run = None\n",
|
| 224 |
"p_generate = None\n",
|
| 225 |
-
"model_files = [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 226 |
"for artifact in artifact_versions:\n",
|
| 227 |
-
" print(f
|
| 228 |
" version = int(artifact.version[1:])\n",
|
| 229 |
" results16, results32 = [], []\n",
|
| 230 |
-
" columns = [
|
| 231 |
-
"
|
| 232 |
" if latest_only:\n",
|
| 233 |
" assert last_inference_version is None or version > last_inference_version\n",
|
| 234 |
" else:\n",
|
|
@@ -236,14 +257,23 @@
|
|
| 236 |
" # we should start from v0\n",
|
| 237 |
" assert version == 0\n",
|
| 238 |
" elif version <= last_inference_version:\n",
|
| 239 |
-
" print(
|
|
|
|
|
|
|
| 240 |
" else:\n",
|
| 241 |
" # check we are logging the correct version\n",
|
| 242 |
" assert version == last_inference_version + 1\n",
|
| 243 |
"\n",
|
| 244 |
" # start/resume corresponding run\n",
|
| 245 |
" if run is None:\n",
|
| 246 |
-
" run = wandb.init(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 247 |
"\n",
|
| 248 |
" # work in temporary directory\n",
|
| 249 |
" with tempfile.TemporaryDirectory() as tmp:\n",
|
|
@@ -264,64 +294,109 @@
|
|
| 264 |
"\n",
|
| 265 |
" # process one batch of captions\n",
|
| 266 |
" for batch in tqdm(samples):\n",
|
| 267 |
-
" processed_prompts =
|
|
|
|
|
|
|
|
|
|
|
|
|
| 268 |
"\n",
|
| 269 |
" # repeat the prompts to distribute over each device and tokenize\n",
|
| 270 |
" processed_prompts = processed_prompts * jax.device_count()\n",
|
| 271 |
-
" tokenized_prompt = tokenizer(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 272 |
" tokenized_prompt = shard(tokenized_prompt)\n",
|
| 273 |
"\n",
|
| 274 |
" # generate images\n",
|
| 275 |
" images = []\n",
|
| 276 |
-
" pbar = tqdm(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 277 |
" for i in pbar:\n",
|
| 278 |
" key, subkey = jax.random.split(key)\n",
|
| 279 |
-
" encoded_images = p_generate(
|
|
|
|
|
|
|
| 280 |
" encoded_images = encoded_images.sequences[..., 1:]\n",
|
| 281 |
" decoded_images = p_decode(encoded_images, vqgan_params)\n",
|
| 282 |
-
" decoded_images = decoded_images.clip(0
|
|
|
|
|
|
|
| 283 |
" for img in decoded_images:\n",
|
| 284 |
-
" images.append(
|
|
|
|
|
|
|
| 285 |
"\n",
|
| 286 |
-
" def add_clip_results(results, processor, p_clip, clip_params)
|
| 287 |
-
" clip_inputs = processor(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 288 |
" # each shard will have one prompt, images need to be reorganized to be associated to the correct shard\n",
|
| 289 |
-
" images_per_prompt_indices = np.asarray(
|
| 290 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 291 |
" clip_inputs = shard(clip_inputs)\n",
|
| 292 |
" logits = p_clip(clip_inputs, clip_params)\n",
|
| 293 |
" logits = logits.reshape(-1, num_images)\n",
|
| 294 |
" top_scores = logits.argsort()[:, -top_k:][..., ::-1]\n",
|
| 295 |
" logits = jax.device_get(logits)\n",
|
| 296 |
" # add to results table\n",
|
| 297 |
-
" for i, (idx, scores, sample) in enumerate(
|
| 298 |
-
"
|
|
|
|
|
|
|
|
|
|
| 299 |
" cur_images = [images[x] for x in images_per_prompt_indices + i]\n",
|
| 300 |
-
" top_images = [
|
|
|
|
|
|
|
|
|
|
| 301 |
" results.append([sample] + top_images)\n",
|
| 302 |
-
"
|
| 303 |
" # get clip scores\n",
|
| 304 |
-
" pbar.set_description(
|
| 305 |
" add_clip_results(results16, processor16, p_clip16, clip16_params)\n",
|
| 306 |
-
"
|
| 307 |
" # get clip 32 scores\n",
|
| 308 |
" if add_clip_32:\n",
|
| 309 |
-
" pbar.set_description(
|
| 310 |
" add_clip_results(results32, processor32, p_clip32, clip32_params)\n",
|
| 311 |
"\n",
|
| 312 |
" pbar.close()\n",
|
| 313 |
"\n",
|
| 314 |
-
" \n",
|
| 315 |
-
"\n",
|
| 316 |
" # log results\n",
|
| 317 |
" table = wandb.Table(columns=columns, data=results16)\n",
|
| 318 |
-
" run.log({
|
| 319 |
" wandb.finish()\n",
|
| 320 |
-
"
|
| 321 |
-
" if add_clip_32
|
| 322 |
-
" run = wandb.init(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 323 |
" table = wandb.Table(columns=columns, data=results32)\n",
|
| 324 |
-
" run.log({
|
| 325 |
" wandb.finish()\n",
|
| 326 |
" run = None # ensure we don't log on this run"
|
| 327 |
]
|
|
|
|
| 31 |
"metadata": {},
|
| 32 |
"outputs": [],
|
| 33 |
"source": [
|
| 34 |
+
"run_ids = [\"63otg87g\"]\n",
|
| 35 |
+
"ENTITY, PROJECT = \"dalle-mini\", \"dalle-mini\" # used only for training run\n",
|
| 36 |
+
"VQGAN_REPO, VQGAN_COMMIT_ID = (\n",
|
| 37 |
+
" \"dalle-mini/vqgan_imagenet_f16_16384\",\n",
|
| 38 |
+
" \"e93a26e7707683d349bf5d5c41c5b0ef69b677a9\",\n",
|
| 39 |
+
")\n",
|
| 40 |
+
"latest_only = True # log only latest or all versions\n",
|
| 41 |
+
"suffix = \"\" # mainly for duplicate inference runs with a deleted version\n",
|
| 42 |
"add_clip_32 = False"
|
| 43 |
]
|
| 44 |
},
|
|
|
|
| 66 |
"num_images = 128\n",
|
| 67 |
"top_k = 8\n",
|
| 68 |
"text_normalizer = TextNormalizer()\n",
|
| 69 |
+
"padding_item = \"NONE\"\n",
|
| 70 |
+
"seed = random.randint(0, 2 ** 32 - 1)\n",
|
| 71 |
"key = jax.random.PRNGKey(seed)\n",
|
| 72 |
"api = wandb.Api()"
|
| 73 |
]
|
|
|
|
| 103 |
"def p_decode(indices, params):\n",
|
| 104 |
" return vqgan.decode_code(indices, params=params)\n",
|
| 105 |
"\n",
|
| 106 |
+
"\n",
|
| 107 |
"@partial(jax.pmap, axis_name=\"batch\")\n",
|
| 108 |
"def p_clip16(inputs, params):\n",
|
| 109 |
" logits = clip16(params=params, **inputs).logits_per_image\n",
|
| 110 |
" return logits\n",
|
| 111 |
"\n",
|
| 112 |
+
"\n",
|
| 113 |
"if add_clip_32:\n",
|
| 114 |
+
"\n",
|
| 115 |
" @partial(jax.pmap, axis_name=\"batch\")\n",
|
| 116 |
" def p_clip32(inputs, params):\n",
|
| 117 |
" logits = clip32(params=params, **inputs).logits_per_image\n",
|
|
|
|
| 125 |
"metadata": {},
|
| 126 |
"outputs": [],
|
| 127 |
"source": [
|
| 128 |
+
"with open(\"samples.txt\", encoding=\"utf8\") as f:\n",
|
| 129 |
" samples = [l.strip() for l in f.readlines()]\n",
|
| 130 |
" # make list multiple of batch_size by adding elements\n",
|
| 131 |
" samples_to_add = [padding_item] * (-len(samples) % batch_size)\n",
|
| 132 |
" samples.extend(samples_to_add)\n",
|
| 133 |
" # reshape\n",
|
| 134 |
+
" samples = [samples[i : i + batch_size] for i in range(0, len(samples), batch_size)]"
|
| 135 |
]
|
| 136 |
},
|
| 137 |
{
|
|
|
|
| 144 |
"def get_artifact_versions(run_id, latest_only=False):\n",
|
| 145 |
" try:\n",
|
| 146 |
" if latest_only:\n",
|
| 147 |
+
" return [\n",
|
| 148 |
+
" api.artifact(\n",
|
| 149 |
+
" type=\"bart_model\", name=f\"{ENTITY}/{PROJECT}/model-{run_id}:latest\"\n",
|
| 150 |
+
" )\n",
|
| 151 |
+
" ]\n",
|
| 152 |
" else:\n",
|
| 153 |
+
" return api.artifact_versions(\n",
|
| 154 |
+
" type_name=\"bart_model\",\n",
|
| 155 |
+
" name=f\"{ENTITY}/{PROJECT}/model-{run_id}\",\n",
|
| 156 |
+
" per_page=10000,\n",
|
| 157 |
+
" )\n",
|
| 158 |
" except:\n",
|
| 159 |
" return []"
|
| 160 |
]
|
|
|
|
| 167 |
"outputs": [],
|
| 168 |
"source": [
|
| 169 |
"def get_training_config(run_id):\n",
|
| 170 |
+
" training_run = api.run(f\"{ENTITY}/{PROJECT}/{run_id}\")\n",
|
| 171 |
" config = training_run.config\n",
|
| 172 |
" return config"
|
| 173 |
]
|
|
|
|
| 182 |
"# retrieve inference run details\n",
|
| 183 |
"def get_last_inference_version(run_id):\n",
|
| 184 |
" try:\n",
|
| 185 |
+
" inference_run = api.run(f\"dalle-mini/dalle-mini/{run_id}-clip16{suffix}\")\n",
|
| 186 |
+
" return inference_run.summary.get(\"version\", None)\n",
|
| 187 |
" except:\n",
|
| 188 |
" return None"
|
| 189 |
]
|
|
|
|
| 197 |
"source": [
|
| 198 |
"# compile functions - needed only once per run\n",
|
| 199 |
"def pmap_model_function(model):\n",
|
|
|
|
| 200 |
" @partial(jax.pmap, axis_name=\"batch\")\n",
|
| 201 |
" def _generate(tokenized_prompt, key, params):\n",
|
| 202 |
" return model.generate(\n",
|
|
|
|
| 208 |
" top_k=gen_top_k,\n",
|
| 209 |
" top_p=gen_top_p\n",
|
| 210 |
" )\n",
|
| 211 |
+
"\n",
|
| 212 |
" return _generate"
|
| 213 |
]
|
| 214 |
},
|
|
|
|
| 235 |
"training_config = get_training_config(run_id)\n",
|
| 236 |
"run = None\n",
|
| 237 |
"p_generate = None\n",
|
| 238 |
+
"model_files = [\n",
|
| 239 |
+
" \"config.json\",\n",
|
| 240 |
+
" \"flax_model.msgpack\",\n",
|
| 241 |
+
" \"merges.txt\",\n",
|
| 242 |
+
" \"special_tokens_map.json\",\n",
|
| 243 |
+
" \"tokenizer.json\",\n",
|
| 244 |
+
" \"tokenizer_config.json\",\n",
|
| 245 |
+
" \"vocab.json\",\n",
|
| 246 |
+
"]\n",
|
| 247 |
"for artifact in artifact_versions:\n",
|
| 248 |
+
" print(f\"Processing artifact: {artifact.name}\")\n",
|
| 249 |
" version = int(artifact.version[1:])\n",
|
| 250 |
" results16, results32 = [], []\n",
|
| 251 |
+
" columns = [\"Caption\"] + [f\"Image {i+1}\" for i in range(top_k)]\n",
|
| 252 |
+
"\n",
|
| 253 |
" if latest_only:\n",
|
| 254 |
" assert last_inference_version is None or version > last_inference_version\n",
|
| 255 |
" else:\n",
|
|
|
|
| 257 |
" # we should start from v0\n",
|
| 258 |
" assert version == 0\n",
|
| 259 |
" elif version <= last_inference_version:\n",
|
| 260 |
+
" print(\n",
|
| 261 |
+
" f\"v{version} has already been logged (versions logged up to v{last_inference_version}\"\n",
|
| 262 |
+
" )\n",
|
| 263 |
" else:\n",
|
| 264 |
" # check we are logging the correct version\n",
|
| 265 |
" assert version == last_inference_version + 1\n",
|
| 266 |
"\n",
|
| 267 |
" # start/resume corresponding run\n",
|
| 268 |
" if run is None:\n",
|
| 269 |
+
" run = wandb.init(\n",
|
| 270 |
+
" job_type=\"inference\",\n",
|
| 271 |
+
" entity=\"dalle-mini\",\n",
|
| 272 |
+
" project=\"dalle-mini\",\n",
|
| 273 |
+
" config=training_config,\n",
|
| 274 |
+
" id=f\"{run_id}-clip16{suffix}\",\n",
|
| 275 |
+
" resume=\"allow\",\n",
|
| 276 |
+
" )\n",
|
| 277 |
"\n",
|
| 278 |
" # work in temporary directory\n",
|
| 279 |
" with tempfile.TemporaryDirectory() as tmp:\n",
|
|
|
|
| 294 |
"\n",
|
| 295 |
" # process one batch of captions\n",
|
| 296 |
" for batch in tqdm(samples):\n",
|
| 297 |
+
" processed_prompts = (\n",
|
| 298 |
+
" [text_normalizer(x) for x in batch]\n",
|
| 299 |
+
" if model.config.normalize_text\n",
|
| 300 |
+
" else list(batch)\n",
|
| 301 |
+
" )\n",
|
| 302 |
"\n",
|
| 303 |
" # repeat the prompts to distribute over each device and tokenize\n",
|
| 304 |
" processed_prompts = processed_prompts * jax.device_count()\n",
|
| 305 |
+
" tokenized_prompt = tokenizer(\n",
|
| 306 |
+
" processed_prompts,\n",
|
| 307 |
+
" return_tensors=\"jax\",\n",
|
| 308 |
+
" padding=\"max_length\",\n",
|
| 309 |
+
" truncation=True,\n",
|
| 310 |
+
" max_length=128,\n",
|
| 311 |
+
" ).data\n",
|
| 312 |
" tokenized_prompt = shard(tokenized_prompt)\n",
|
| 313 |
"\n",
|
| 314 |
" # generate images\n",
|
| 315 |
" images = []\n",
|
| 316 |
+
" pbar = tqdm(\n",
|
| 317 |
+
" range(num_images // jax.device_count()),\n",
|
| 318 |
+
" desc=\"Generating Images\",\n",
|
| 319 |
+
" leave=True,\n",
|
| 320 |
+
" )\n",
|
| 321 |
" for i in pbar:\n",
|
| 322 |
" key, subkey = jax.random.split(key)\n",
|
| 323 |
+
" encoded_images = p_generate(\n",
|
| 324 |
+
" tokenized_prompt, shard_prng_key(subkey), model_params\n",
|
| 325 |
+
" )\n",
|
| 326 |
" encoded_images = encoded_images.sequences[..., 1:]\n",
|
| 327 |
" decoded_images = p_decode(encoded_images, vqgan_params)\n",
|
| 328 |
+
" decoded_images = decoded_images.clip(0.0, 1.0).reshape(\n",
|
| 329 |
+
" (-1, 256, 256, 3)\n",
|
| 330 |
+
" )\n",
|
| 331 |
" for img in decoded_images:\n",
|
| 332 |
+
" images.append(\n",
|
| 333 |
+
" Image.fromarray(np.asarray(img * 255, dtype=np.uint8))\n",
|
| 334 |
+
" )\n",
|
| 335 |
"\n",
|
| 336 |
+
" def add_clip_results(results, processor, p_clip, clip_params):\n",
|
| 337 |
+
" clip_inputs = processor(\n",
|
| 338 |
+
" text=batch,\n",
|
| 339 |
+
" images=images,\n",
|
| 340 |
+
" return_tensors=\"np\",\n",
|
| 341 |
+
" padding=\"max_length\",\n",
|
| 342 |
+
" max_length=77,\n",
|
| 343 |
+
" truncation=True,\n",
|
| 344 |
+
" ).data\n",
|
| 345 |
" # each shard will have one prompt, images need to be reorganized to be associated to the correct shard\n",
|
| 346 |
+
" images_per_prompt_indices = np.asarray(\n",
|
| 347 |
+
" range(0, len(images), batch_size)\n",
|
| 348 |
+
" )\n",
|
| 349 |
+
" clip_inputs[\"pixel_values\"] = jnp.concatenate(\n",
|
| 350 |
+
" list(\n",
|
| 351 |
+
" clip_inputs[\"pixel_values\"][images_per_prompt_indices + i]\n",
|
| 352 |
+
" for i in range(batch_size)\n",
|
| 353 |
+
" )\n",
|
| 354 |
+
" )\n",
|
| 355 |
" clip_inputs = shard(clip_inputs)\n",
|
| 356 |
" logits = p_clip(clip_inputs, clip_params)\n",
|
| 357 |
" logits = logits.reshape(-1, num_images)\n",
|
| 358 |
" top_scores = logits.argsort()[:, -top_k:][..., ::-1]\n",
|
| 359 |
" logits = jax.device_get(logits)\n",
|
| 360 |
" # add to results table\n",
|
| 361 |
+
" for i, (idx, scores, sample) in enumerate(\n",
|
| 362 |
+
" zip(top_scores, logits, batch)\n",
|
| 363 |
+
" ):\n",
|
| 364 |
+
" if sample == padding_item:\n",
|
| 365 |
+
" continue\n",
|
| 366 |
" cur_images = [images[x] for x in images_per_prompt_indices + i]\n",
|
| 367 |
+
" top_images = [\n",
|
| 368 |
+
" wandb.Image(cur_images[x], caption=f\"Score: {scores[x]:.2f}\")\n",
|
| 369 |
+
" for x in idx\n",
|
| 370 |
+
" ]\n",
|
| 371 |
" results.append([sample] + top_images)\n",
|
| 372 |
+
"\n",
|
| 373 |
" # get clip scores\n",
|
| 374 |
+
" pbar.set_description(\"Calculating CLIP 16 scores\")\n",
|
| 375 |
" add_clip_results(results16, processor16, p_clip16, clip16_params)\n",
|
| 376 |
+
"\n",
|
| 377 |
" # get clip 32 scores\n",
|
| 378 |
" if add_clip_32:\n",
|
| 379 |
+
" pbar.set_description(\"Calculating CLIP 32 scores\")\n",
|
| 380 |
" add_clip_results(results32, processor32, p_clip32, clip32_params)\n",
|
| 381 |
"\n",
|
| 382 |
" pbar.close()\n",
|
| 383 |
"\n",
|
|
|
|
|
|
|
| 384 |
" # log results\n",
|
| 385 |
" table = wandb.Table(columns=columns, data=results16)\n",
|
| 386 |
+
" run.log({\"Samples\": table, \"version\": version})\n",
|
| 387 |
" wandb.finish()\n",
|
| 388 |
+
"\n",
|
| 389 |
+
" if add_clip_32:\n",
|
| 390 |
+
" run = wandb.init(\n",
|
| 391 |
+
" job_type=\"inference\",\n",
|
| 392 |
+
" entity=\"dalle-mini\",\n",
|
| 393 |
+
" project=\"dalle-mini\",\n",
|
| 394 |
+
" config=training_config,\n",
|
| 395 |
+
" id=f\"{run_id}-clip32{suffix}\",\n",
|
| 396 |
+
" resume=\"allow\",\n",
|
| 397 |
+
" )\n",
|
| 398 |
" table = wandb.Table(columns=columns, data=results32)\n",
|
| 399 |
+
" run.log({\"Samples\": table, \"version\": version})\n",
|
| 400 |
" wandb.finish()\n",
|
| 401 |
" run = None # ensure we don't log on this run"
|
| 402 |
]
|