Upload sd_token_similarity_calculator.ipynb
Browse files
sd_token_similarity_calculator.ipynb
CHANGED
|
@@ -117,7 +117,7 @@
|
|
| 117 |
"id": "Ch9puvwKH1s3",
|
| 118 |
"collapsed": true,
|
| 119 |
"cellView": "form",
|
| 120 |
-
"outputId": "
|
| 121 |
"colab": {
|
| 122 |
"base_uri": "https://localhost:8080/"
|
| 123 |
}
|
|
@@ -133,7 +133,7 @@
|
|
| 133 |
"remote: Counting objects: 100% (7/7), done.\u001b[K\n",
|
| 134 |
"remote: Compressing objects: 100% (7/7), done.\u001b[K\n",
|
| 135 |
"remote: Total 10 (delta 1), reused 0 (delta 0), pack-reused 3 (from 1)\u001b[K\n",
|
| 136 |
-
"Unpacking objects: 100% (10/10), 306.93 KiB |
|
| 137 |
"/content/sd_tokens\n"
|
| 138 |
]
|
| 139 |
}
|
|
@@ -345,9 +345,7 @@
|
|
| 345 |
"model = CLIPModel.from_pretrained(\"openai/clip-vit-large-patch14\")\n",
|
| 346 |
"inputs = processor(images=image_A, return_tensors=\"pt\")\n",
|
| 347 |
"image_features = model.get_image_features(**inputs)\n",
|
| 348 |
-
"
|
| 349 |
-
"A = text_encoding_A[0]\n",
|
| 350 |
-
"_A = LA.vector_norm(A, ord=2)\n",
|
| 351 |
"prompt_A = \"the image\"\n",
|
| 352 |
"name_A = prompt_A\n",
|
| 353 |
"#-----#\n",
|
|
@@ -390,7 +388,6 @@
|
|
| 390 |
" C = token[id_C]\n",
|
| 391 |
" _C = LA.vector_norm(C, ord=2)\n",
|
| 392 |
" name_C = vocab[id_C]\n",
|
| 393 |
-
"\n",
|
| 394 |
" is_Prefix = 0\n",
|
| 395 |
"\n",
|
| 396 |
"\n",
|
|
@@ -421,10 +418,11 @@
|
|
| 421 |
" name_CB = must_start_with + ' ' + name_C.strip() + '-' + name_B.strip() + ' ' + must_end_with\n",
|
| 422 |
" #-----#\n",
|
| 423 |
" ids_CB = processor.tokenizer(text=name_CB, padding=use_token_padding, return_tensors=\"pt\")\n",
|
| 424 |
-
"
|
| 425 |
-
"
|
| 426 |
-
"
|
| 427 |
-
"
|
|
|
|
| 428 |
" #-----#\n",
|
| 429 |
" if restrictions == \"Prefix only\":\n",
|
| 430 |
" result = sim_CB\n",
|
|
@@ -434,10 +432,11 @@
|
|
| 434 |
" #-----#\n",
|
| 435 |
" name_BC = must_start_with + name_B + name_C + must_end_with\n",
|
| 436 |
" ids_BC = processor.tokenizer(text=name_BC, padding=use_token_padding, return_tensors=\"pt\")\n",
|
| 437 |
-
"
|
| 438 |
-
"
|
| 439 |
-
"
|
| 440 |
-
"
|
|
|
|
| 441 |
" #-----#\n",
|
| 442 |
"\n",
|
| 443 |
" result = sim_CB\n",
|
|
@@ -504,8 +503,8 @@
|
|
| 504 |
"#------#\n",
|
| 505 |
"trails = (trails + \"&&&&\").replace(\"|&&&&\", \"}\").replace(\"</w>\", \" \").replace(\"{&&&&\", \"\")\n",
|
| 506 |
"aheads = (aheads + \"&&&&\").replace(\"|&&&&\", \"}\").replace(\"</w>\", \" \").replace(\"{&&&&\", \"\")\n",
|
| 507 |
-
"max_sim_ahead=max_sim_ahead
|
| 508 |
-
"max_sim_ahead=max_sim_trail
|
| 509 |
"#-----#\n",
|
| 510 |
"print(f\"place these items ahead of prompt : {aheads}\")\n",
|
| 511 |
"print(\"\")\n",
|
|
@@ -530,11 +529,14 @@
|
|
| 530 |
" if index == 3 : name = name + max_name_ahead + must_contain + max_name_trail\n",
|
| 531 |
" name = name + must_end_with\n",
|
| 532 |
" #----#\n",
|
| 533 |
-
"
|
| 534 |
-
"
|
| 535 |
-
"
|
| 536 |
-
"
|
| 537 |
-
"
|
|
|
|
|
|
|
|
|
|
| 538 |
" names[index] = name\n",
|
| 539 |
"#------#\n",
|
| 540 |
"\n",
|
|
@@ -542,12 +544,11 @@
|
|
| 542 |
"\n",
|
| 543 |
"for index in range(NUM_PERMUTATIONS):\n",
|
| 544 |
" print(names[indices[index].item()])\n",
|
| 545 |
-
" print(f'similiarity = {round(sorted[index].item()
|
| 546 |
" print('------')\n",
|
| 547 |
"\n",
|
| 548 |
"\n",
|
| 549 |
-
"\n"
|
| 550 |
-
""
|
| 551 |
],
|
| 552 |
"metadata": {
|
| 553 |
"collapsed": true,
|
|
|
|
| 117 |
"id": "Ch9puvwKH1s3",
|
| 118 |
"collapsed": true,
|
| 119 |
"cellView": "form",
|
| 120 |
+
"outputId": "8101e515-49f2-41d4-b03b-4195d56f50de",
|
| 121 |
"colab": {
|
| 122 |
"base_uri": "https://localhost:8080/"
|
| 123 |
}
|
|
|
|
| 133 |
"remote: Counting objects: 100% (7/7), done.\u001b[K\n",
|
| 134 |
"remote: Compressing objects: 100% (7/7), done.\u001b[K\n",
|
| 135 |
"remote: Total 10 (delta 1), reused 0 (delta 0), pack-reused 3 (from 1)\u001b[K\n",
|
| 136 |
+
"Unpacking objects: 100% (10/10), 306.93 KiB | 1.19 MiB/s, done.\n",
|
| 137 |
"/content/sd_tokens\n"
|
| 138 |
]
|
| 139 |
}
|
|
|
|
| 345 |
"model = CLIPModel.from_pretrained(\"openai/clip-vit-large-patch14\")\n",
|
| 346 |
"inputs = processor(images=image_A, return_tensors=\"pt\")\n",
|
| 347 |
"image_features = model.get_image_features(**inputs)\n",
|
| 348 |
+
"image_features = image_features / image_features.norm(p=2, dim=-1, keepdim=True)\n",
|
|
|
|
|
|
|
| 349 |
"prompt_A = \"the image\"\n",
|
| 350 |
"name_A = prompt_A\n",
|
| 351 |
"#-----#\n",
|
|
|
|
| 388 |
" C = token[id_C]\n",
|
| 389 |
" _C = LA.vector_norm(C, ord=2)\n",
|
| 390 |
" name_C = vocab[id_C]\n",
|
|
|
|
| 391 |
" is_Prefix = 0\n",
|
| 392 |
"\n",
|
| 393 |
"\n",
|
|
|
|
| 418 |
" name_CB = must_start_with + ' ' + name_C.strip() + '-' + name_B.strip() + ' ' + must_end_with\n",
|
| 419 |
" #-----#\n",
|
| 420 |
" ids_CB = processor.tokenizer(text=name_CB, padding=use_token_padding, return_tensors=\"pt\")\n",
|
| 421 |
+
" text_features = model.get_text_features(**ids_CB)\n",
|
| 422 |
+
" text_features = text_features / text_features.norm(p=2, dim=-1, keepdim=True)\n",
|
| 423 |
+
" logit_scale = model.logit_scale.exp()\n",
|
| 424 |
+
" torch.matmul(text_features, image_features.t()) * logit_scale\n",
|
| 425 |
+
" sim_CB = torch.nn.functional.cosine_similarity(text_features, image_features) * logit_scale\n",
|
| 426 |
" #-----#\n",
|
| 427 |
" if restrictions == \"Prefix only\":\n",
|
| 428 |
" result = sim_CB\n",
|
|
|
|
| 432 |
" #-----#\n",
|
| 433 |
" name_BC = must_start_with + name_B + name_C + must_end_with\n",
|
| 434 |
" ids_BC = processor.tokenizer(text=name_BC, padding=use_token_padding, return_tensors=\"pt\")\n",
|
| 435 |
+
" text_features = model.get_text_features(**ids_BC)\n",
|
| 436 |
+
" text_features = text_features / text_features.norm(p=2, dim=-1, keepdim=True)\n",
|
| 437 |
+
" logit_scale = model.logit_scale.exp()\n",
|
| 438 |
+
" torch.matmul(text_features, image_features.t()) * logit_scale\n",
|
| 439 |
+
" sim_BC = torch.nn.functional.cosine_similarity(text_features, image_features) * logit_scale\n",
|
| 440 |
" #-----#\n",
|
| 441 |
"\n",
|
| 442 |
" result = sim_CB\n",
|
|
|
|
| 503 |
"#------#\n",
|
| 504 |
"trails = (trails + \"&&&&\").replace(\"|&&&&\", \"}\").replace(\"</w>\", \" \").replace(\"{&&&&\", \"\")\n",
|
| 505 |
"aheads = (aheads + \"&&&&\").replace(\"|&&&&\", \"}\").replace(\"</w>\", \" \").replace(\"{&&&&\", \"\")\n",
|
| 506 |
+
"max_sim_ahead=max_sim_ahead\n",
|
| 507 |
+
"max_sim_ahead=max_sim_trail\n",
|
| 508 |
"#-----#\n",
|
| 509 |
"print(f\"place these items ahead of prompt : {aheads}\")\n",
|
| 510 |
"print(\"\")\n",
|
|
|
|
| 529 |
" if index == 3 : name = name + max_name_ahead + must_contain + max_name_trail\n",
|
| 530 |
" name = name + must_end_with\n",
|
| 531 |
" #----#\n",
|
| 532 |
+
" ids = processor.tokenizer(text=name, padding=use_token_padding, return_tensors=\"pt\")\n",
|
| 533 |
+
"\n",
|
| 534 |
+
" text_features = model.get_text_features(**ids)\n",
|
| 535 |
+
" text_features = text_features / text_features.norm(p=2, dim=-1, keepdim=True)\n",
|
| 536 |
+
" logit_scale = model.logit_scale.exp()\n",
|
| 537 |
+
" torch.matmul(text_features, image_features.t()) * logit_scale\n",
|
| 538 |
+
" sim = torch.nn.functional.cosine_similarity(text_features, image_features) * logit_scale\n",
|
| 539 |
+
" dots[index] = sim\n",
|
| 540 |
" names[index] = name\n",
|
| 541 |
"#------#\n",
|
| 542 |
"\n",
|
|
|
|
| 544 |
"\n",
|
| 545 |
"for index in range(NUM_PERMUTATIONS):\n",
|
| 546 |
" print(names[indices[index].item()])\n",
|
| 547 |
+
" print(f'similiarity = {round(sorted[index].item(),2)} %')\n",
|
| 548 |
" print('------')\n",
|
| 549 |
"\n",
|
| 550 |
"\n",
|
| 551 |
+
"\n"
|
|
|
|
| 552 |
],
|
| 553 |
"metadata": {
|
| 554 |
"collapsed": true,
|