Sharath33 committed
Commit 02ac88d · verified · 1 Parent(s): 637d458

Upload folder using huggingface_hub
README.md ADDED
@@ -0,0 +1,90 @@
# Siamese Network for Few-Shot Image Recognition

Few-shot image recognition using a Siamese Network trained on Omniglot. The model recognises new character classes from as few as one example.

## Results

| Configuration | Accuracy |
|---------------|----------|
| 5-way 1-shot  | 95.10%   |
| 5-way 5-shot  | 97.07%   |
| 10-way 1-shot | 90.05%   |
| 10-way 5-shot | 94.83%   |

Evaluated on 145 test classes that were never seen during training.

## Architecture

- Backbone: pretrained ResNet-18 with the final FC layer stripped → 512-d features
- Embedding head: Linear(512 → 256) → BatchNorm → ReLU → Linear(256 → 128) → L2 normalisation
- Loss: contrastive loss with margin = 1.0
- Distance: cosine similarity on unit-sphere embeddings
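
The embedding head described above can be sketched as a standalone PyTorch module. This is an illustrative re-creation from the bullet list, not a copy of the repo's `model.py` (names such as `EmbeddingHead` are ours):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class EmbeddingHead(nn.Module):
    """Projects 512-d backbone features to L2-normalised 128-d embeddings."""
    def __init__(self, in_dim=512, hidden=256, out_dim=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden),
            nn.BatchNorm1d(hidden),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, out_dim),
        )

    def forward(self, x):
        # L2-normalise so every embedding lies on the unit sphere
        return F.normalize(self.net(x), p=2, dim=1)

head = EmbeddingHead().eval()
with torch.no_grad():
    emb = head(torch.randn(4, 512))   # stand-in for ResNet-18 features
print(emb.shape)                      # torch.Size([4, 128])
print(emb.norm(dim=1))                # each norm ~1.0 after normalisation
```

Because the embeddings are unit-normalised, cosine similarity reduces to a plain dot product between them.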

## Project Structure

    siamese-few-shot/
    ├── src/
    │   ├── dataset.py          # SiamesePairDataset + EpisodeDataset
    │   ├── model.py            # EmbeddingNet + SiameseNet
    │   ├── loss.py             # ContrastiveLoss
    │   ├── train.py            # training + validation loop
    │   ├── run_training.py     # main training entry point
    │   ├── eval.py             # N-way K-shot episodic evaluation
    │   └── demo.py             # Gradio demo
    ├── checkpoints/
    │   ├── best.pt
    │   └── siamese_embedding.onnx
    ├── data/
    │   └── class_split.json
    ├── requirements.txt
    └── README.md

## Quickstart

    git clone https://huggingface.co/<your-username>/siamese-few-shot
    cd siamese-few-shot
    pip install -r requirements.txt

    # Run the Gradio demo
    cd src && python demo.py

    # Run episodic evaluation
    cd src && python eval.py

    # Retrain from scratch
    cd src && python run_training.py

## Training Details

- Dataset: Omniglot (background split, 964 classes)
- Train / val / test split: 70% / 15% / 15% of classes
- Epochs: 30
- Batch size: 32
- Optimiser: Adam, lr = 1e-3
- Scheduler: CosineAnnealingLR
- Augmentation: RandomCrop, RandomHorizontalFlip, ColorJitter
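
The optimiser/scheduler pairing above takes only a few lines to set up. A minimal sketch with a toy stand-in model (illustrative only; the repo's actual loop lives in `src/train.py`):

```python
import torch
import torch.nn as nn

model = nn.Linear(512, 128)   # toy stand-in for the Siamese network
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=30)

lrs = []
for epoch in range(30):       # 30 epochs, as listed above
    # ... train + validate one epoch here ...
    optimizer.step()          # scheduler.step() expects an optimiser step first
    scheduler.step()
    lrs.append(optimizer.param_groups[0]["lr"])

print(f"{lrs[0]:.2e} -> {lrs[-1]:.2e}")  # lr anneals from ~1e-3 down towards 0
```

With `T_max` equal to the epoch count, the learning rate follows one half cosine from 1e-3 to (near) zero over the run.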

## Requirements

    torch>=2.0
    torchvision>=0.15
    timm
    gradio
    onnx
    onnxruntime-gpu
    pillow
    numpy
    matplotlib
    scikit-learn
    tqdm
    wandb

## Demo

Upload any two handwritten character images. The model returns a cosine similarity score and a same / different class decision.

Trained on Latin, Greek, Cyrillic, Japanese, and 25 other alphabets via the Omniglot dataset. Also tested on Indian-script characters (Tamil, Hindi, Telugu, Kannada, Bengali, Malayalam, Gujarati, Punjabi).
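
The same/different decision is just a threshold on cosine similarity. A toy NumPy sketch, assuming the 0.5 threshold used in `src/demo.py`:

```python
import numpy as np

def cosine_similarity(a, b):
    # cos(a, b) = a·b / (|a| |b|)
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

def decide(emb1, emb2, threshold=0.5):
    sim = cosine_similarity(emb1, emb2)
    return ("Same class" if sim > threshold else "Different class"), sim

rng = np.random.default_rng(0)
e = rng.normal(size=128)            # toy 128-d embedding
label, sim = decide(e, e)           # identical embeddings → similarity 1.0
print(label, round(sim, 4))
```

In the real demo the two embeddings come from the trained network, so the threshold separates genuinely matching character pairs rather than random vectors.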
checkpoints/best.pt ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:893f176f41f335d1b59d30e6d18ed197878f525e9eb4a1b56aee720a02df9b79
size 136221898
checkpoints/siamese_embedding.onnx ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a0ed0762b57ebf00fa1f780f6a748c0142db71bbb0362db25abbc8033f1d8031
size 45362272
data/class_split.json ADDED
@@ -0,0 +1,972 @@
1
+ {
2
+ "train": [
3
+ 771,
4
+ 149,
5
+ 586,
6
+ 502,
7
+ 49,
8
+ 272,
9
+ 752,
10
+ 927,
11
+ 877,
12
+ 787,
13
+ 404,
14
+ 56,
15
+ 483,
16
+ 804,
17
+ 489,
18
+ 436,
19
+ 690,
20
+ 85,
21
+ 150,
22
+ 33,
23
+ 460,
24
+ 533,
25
+ 608,
26
+ 318,
27
+ 289,
28
+ 185,
29
+ 38,
30
+ 810,
31
+ 564,
32
+ 53,
33
+ 861,
34
+ 329,
35
+ 444,
36
+ 587,
37
+ 737,
38
+ 16,
39
+ 202,
40
+ 320,
41
+ 365,
42
+ 536,
43
+ 188,
44
+ 766,
45
+ 178,
46
+ 636,
47
+ 349,
48
+ 184,
49
+ 247,
50
+ 951,
51
+ 198,
52
+ 789,
53
+ 18,
54
+ 708,
55
+ 858,
56
+ 588,
57
+ 380,
58
+ 936,
59
+ 742,
60
+ 279,
61
+ 385,
62
+ 177,
63
+ 477,
64
+ 673,
65
+ 262,
66
+ 397,
67
+ 683,
68
+ 241,
69
+ 717,
70
+ 264,
71
+ 774,
72
+ 157,
73
+ 697,
74
+ 907,
75
+ 848,
76
+ 782,
77
+ 808,
78
+ 236,
79
+ 647,
80
+ 189,
81
+ 760,
82
+ 806,
83
+ 856,
84
+ 611,
85
+ 443,
86
+ 509,
87
+ 702,
88
+ 948,
89
+ 903,
90
+ 835,
91
+ 293,
92
+ 560,
93
+ 360,
94
+ 725,
95
+ 316,
96
+ 201,
97
+ 902,
98
+ 884,
99
+ 461,
100
+ 28,
101
+ 779,
102
+ 788,
103
+ 336,
104
+ 290,
105
+ 334,
106
+ 793,
107
+ 450,
108
+ 659,
109
+ 638,
110
+ 641,
111
+ 807,
112
+ 664,
113
+ 924,
114
+ 381,
115
+ 844,
116
+ 456,
117
+ 504,
118
+ 585,
119
+ 24,
120
+ 327,
121
+ 1,
122
+ 625,
123
+ 572,
124
+ 904,
125
+ 843,
126
+ 76,
127
+ 561,
128
+ 515,
129
+ 909,
130
+ 485,
131
+ 845,
132
+ 887,
133
+ 364,
134
+ 811,
135
+ 480,
136
+ 308,
137
+ 501,
138
+ 700,
139
+ 231,
140
+ 311,
141
+ 463,
142
+ 401,
143
+ 422,
144
+ 457,
145
+ 174,
146
+ 609,
147
+ 868,
148
+ 466,
149
+ 435,
150
+ 715,
151
+ 559,
152
+ 606,
153
+ 510,
154
+ 651,
155
+ 720,
156
+ 724,
157
+ 785,
158
+ 92,
159
+ 110,
160
+ 482,
161
+ 424,
162
+ 227,
163
+ 915,
164
+ 741,
165
+ 129,
166
+ 554,
167
+ 325,
168
+ 744,
169
+ 494,
170
+ 88,
171
+ 430,
172
+ 770,
173
+ 945,
174
+ 722,
175
+ 631,
176
+ 26,
177
+ 326,
178
+ 426,
179
+ 601,
180
+ 447,
181
+ 50,
182
+ 355,
183
+ 680,
184
+ 173,
185
+ 147,
186
+ 657,
187
+ 622,
188
+ 148,
189
+ 939,
190
+ 79,
191
+ 358,
192
+ 200,
193
+ 183,
194
+ 229,
195
+ 384,
196
+ 524,
197
+ 544,
198
+ 256,
199
+ 716,
200
+ 119,
201
+ 475,
202
+ 251,
203
+ 29,
204
+ 595,
205
+ 121,
206
+ 287,
207
+ 61,
208
+ 301,
209
+ 729,
210
+ 5,
211
+ 617,
212
+ 547,
213
+ 894,
214
+ 193,
215
+ 434,
216
+ 727,
217
+ 748,
218
+ 17,
219
+ 769,
220
+ 842,
221
+ 391,
222
+ 575,
223
+ 138,
224
+ 191,
225
+ 4,
226
+ 299,
227
+ 20,
228
+ 213,
229
+ 64,
230
+ 719,
231
+ 959,
232
+ 192,
233
+ 886,
234
+ 857,
235
+ 302,
236
+ 425,
237
+ 369,
238
+ 282,
239
+ 474,
240
+ 409,
241
+ 940,
242
+ 753,
243
+ 66,
244
+ 353,
245
+ 885,
246
+ 876,
247
+ 221,
248
+ 34,
249
+ 648,
250
+ 421,
251
+ 710,
252
+ 925,
253
+ 879,
254
+ 144,
255
+ 503,
256
+ 419,
257
+ 232,
258
+ 13,
259
+ 398,
260
+ 239,
261
+ 84,
262
+ 898,
263
+ 172,
264
+ 670,
265
+ 454,
266
+ 776,
267
+ 813,
268
+ 298,
269
+ 537,
270
+ 249,
271
+ 333,
272
+ 908,
273
+ 634,
274
+ 530,
275
+ 180,
276
+ 730,
277
+ 866,
278
+ 795,
279
+ 10,
280
+ 199,
281
+ 396,
282
+ 548,
283
+ 31,
284
+ 889,
285
+ 801,
286
+ 277,
287
+ 414,
288
+ 577,
289
+ 393,
290
+ 784,
291
+ 372,
292
+ 78,
293
+ 499,
294
+ 557,
295
+ 531,
296
+ 571,
297
+ 41,
298
+ 851,
299
+ 800,
300
+ 361,
301
+ 139,
302
+ 491,
303
+ 392,
304
+ 39,
305
+ 705,
306
+ 140,
307
+ 929,
308
+ 684,
309
+ 427,
310
+ 809,
311
+ 126,
312
+ 132,
313
+ 713,
314
+ 350,
315
+ 493,
316
+ 957,
317
+ 386,
318
+ 743,
319
+ 258,
320
+ 313,
321
+ 374,
322
+ 162,
323
+ 304,
324
+ 286,
325
+ 819,
326
+ 330,
327
+ 68,
328
+ 458,
329
+ 620,
330
+ 495,
331
+ 780,
332
+ 961,
333
+ 237,
334
+ 265,
335
+ 90,
336
+ 803,
337
+ 830,
338
+ 54,
339
+ 635,
340
+ 528,
341
+ 354,
342
+ 124,
343
+ 660,
344
+ 921,
345
+ 862,
346
+ 679,
347
+ 294,
348
+ 831,
349
+ 212,
350
+ 883,
351
+ 37,
352
+ 312,
353
+ 694,
354
+ 125,
355
+ 23,
356
+ 472,
357
+ 596,
358
+ 101,
359
+ 115,
360
+ 406,
361
+ 914,
362
+ 120,
363
+ 47,
364
+ 171,
365
+ 317,
366
+ 324,
367
+ 411,
368
+ 263,
369
+ 145,
370
+ 43,
371
+ 471,
372
+ 86,
373
+ 513,
374
+ 534,
375
+ 896,
376
+ 345,
377
+ 375,
378
+ 226,
379
+ 878,
380
+ 164,
381
+ 335,
382
+ 814,
383
+ 881,
384
+ 417,
385
+ 107,
386
+ 839,
387
+ 155,
388
+ 428,
389
+ 280,
390
+ 207,
391
+ 154,
392
+ 568,
393
+ 523,
394
+ 949,
395
+ 383,
396
+ 711,
397
+ 626,
398
+ 496,
399
+ 820,
400
+ 481,
401
+ 712,
402
+ 542,
403
+ 527,
404
+ 356,
405
+ 206,
406
+ 872,
407
+ 600,
408
+ 539,
409
+ 614,
410
+ 259,
411
+ 153,
412
+ 423,
413
+ 63,
414
+ 339,
415
+ 695,
416
+ 723,
417
+ 455,
418
+ 340,
419
+ 632,
420
+ 916,
421
+ 186,
422
+ 859,
423
+ 343,
424
+ 506,
425
+ 869,
426
+ 368,
427
+ 818,
428
+ 275,
429
+ 667,
430
+ 266,
431
+ 955,
432
+ 734,
433
+ 516,
434
+ 22,
435
+ 525,
436
+ 97,
437
+ 295,
438
+ 197,
439
+ 652,
440
+ 261,
441
+ 310,
442
+ 943,
443
+ 160,
444
+ 402,
445
+ 522,
446
+ 176,
447
+ 624,
448
+ 816,
449
+ 490,
450
+ 487,
451
+ 573,
452
+ 297,
453
+ 507,
454
+ 538,
455
+ 449,
456
+ 761,
457
+ 14,
458
+ 732,
459
+ 836,
460
+ 431,
461
+ 666,
462
+ 581,
463
+ 260,
464
+ 328,
465
+ 792,
466
+ 467,
467
+ 395,
468
+ 35,
469
+ 628,
470
+ 822,
471
+ 637,
472
+ 204,
473
+ 98,
474
+ 337,
475
+ 508,
476
+ 532,
477
+ 116,
478
+ 615,
479
+ 407,
480
+ 678,
481
+ 960,
482
+ 179,
483
+ 707,
484
+ 901,
485
+ 418,
486
+ 579,
487
+ 113,
488
+ 605,
489
+ 439,
490
+ 750,
491
+ 446,
492
+ 195,
493
+ 672,
494
+ 359,
495
+ 403,
496
+ 880,
497
+ 136,
498
+ 899,
499
+ 415,
500
+ 376,
501
+ 442,
502
+ 342,
503
+ 639,
504
+ 210,
505
+ 476,
506
+ 400,
507
+ 644,
508
+ 850,
509
+ 377,
510
+ 91,
511
+ 12,
512
+ 211,
513
+ 451,
514
+ 181,
515
+ 870,
516
+ 645,
517
+ 627,
518
+ 893,
519
+ 676,
520
+ 105,
521
+ 763,
522
+ 689,
523
+ 366,
524
+ 599,
525
+ 871,
526
+ 315,
527
+ 42,
528
+ 130,
529
+ 440,
530
+ 151,
531
+ 629,
532
+ 36,
533
+ 152,
534
+ 653,
535
+ 749,
536
+ 582,
537
+ 437,
538
+ 452,
539
+ 757,
540
+ 268,
541
+ 133,
542
+ 341,
543
+ 805,
544
+ 45,
545
+ 283,
546
+ 630,
547
+ 912,
548
+ 52,
549
+ 790,
550
+ 751,
551
+ 941,
552
+ 933,
553
+ 208,
554
+ 351,
555
+ 215,
556
+ 288,
557
+ 278,
558
+ 917,
559
+ 109,
560
+ 118,
561
+ 905,
562
+ 137,
563
+ 106,
564
+ 306,
565
+ 8,
566
+ 891,
567
+ 309,
568
+ 556,
569
+ 864,
570
+ 612,
571
+ 291,
572
+ 378,
573
+ 855,
574
+ 607,
575
+ 357,
576
+ 688,
577
+ 589,
578
+ 518,
579
+ 765,
580
+ 550,
581
+ 75,
582
+ 102,
583
+ 576,
584
+ 923,
585
+ 9,
586
+ 74,
587
+ 594,
588
+ 468,
589
+ 307,
590
+ 134,
591
+ 827,
592
+ 892,
593
+ 244,
594
+ 619,
595
+ 209,
596
+ 267,
597
+ 838,
598
+ 535,
599
+ 578,
600
+ 932,
601
+ 83,
602
+ 40,
603
+ 592,
604
+ 821,
605
+ 583,
606
+ 122,
607
+ 413,
608
+ 240,
609
+ 863,
610
+ 772,
611
+ 190,
612
+ 82,
613
+ 520,
614
+ 906,
615
+ 775,
616
+ 755,
617
+ 514,
618
+ 488,
619
+ 815,
620
+ 649,
621
+ 58,
622
+ 321,
623
+ 668,
624
+ 555,
625
+ 593,
626
+ 745,
627
+ 222,
628
+ 303,
629
+ 662,
630
+ 158,
631
+ 498,
632
+ 569,
633
+ 832,
634
+ 292,
635
+ 465,
636
+ 731,
637
+ 399,
638
+ 2,
639
+ 852,
640
+ 168,
641
+ 846,
642
+ 837,
643
+ 218,
644
+ 492,
645
+ 691,
646
+ 682,
647
+ 170,
648
+ 242,
649
+ 944,
650
+ 15,
651
+ 553,
652
+ 51,
653
+ 829,
654
+ 563,
655
+ 453,
656
+ 77,
657
+ 255,
658
+ 853,
659
+ 693,
660
+ 187,
661
+ 798,
662
+ 143,
663
+ 930,
664
+ 783,
665
+ 681,
666
+ 888,
667
+ 254,
668
+ 111,
669
+ 347,
670
+ 412,
671
+ 62,
672
+ 100,
673
+ 661,
674
+ 669,
675
+ 55,
676
+ 478
677
+ ],
678
+ "val": [
679
+ 420,
680
+ 728,
681
+ 362,
682
+ 441,
683
+ 674,
684
+ 910,
685
+ 96,
686
+ 194,
687
+ 416,
688
+ 953,
689
+ 248,
690
+ 484,
691
+ 590,
692
+ 584,
693
+ 135,
694
+ 726,
695
+ 219,
696
+ 740,
697
+ 685,
698
+ 285,
699
+ 243,
700
+ 526,
701
+ 703,
702
+ 338,
703
+ 840,
704
+ 69,
705
+ 841,
706
+ 60,
707
+ 646,
708
+ 72,
709
+ 7,
710
+ 931,
711
+ 709,
712
+ 235,
713
+ 567,
714
+ 602,
715
+ 21,
716
+ 346,
717
+ 796,
718
+ 230,
719
+ 253,
720
+ 123,
721
+ 462,
722
+ 529,
723
+ 448,
724
+ 952,
725
+ 935,
726
+ 687,
727
+ 812,
728
+ 319,
729
+ 205,
730
+ 706,
731
+ 756,
732
+ 834,
733
+ 433,
734
+ 621,
735
+ 540,
736
+ 823,
737
+ 169,
738
+ 562,
739
+ 486,
740
+ 675,
741
+ 131,
742
+ 128,
743
+ 545,
744
+ 70,
745
+ 497,
746
+ 87,
747
+ 897,
748
+ 833,
749
+ 246,
750
+ 59,
751
+ 245,
752
+ 314,
753
+ 371,
754
+ 778,
755
+ 19,
756
+ 500,
757
+ 331,
758
+ 613,
759
+ 0,
760
+ 543,
761
+ 552,
762
+ 165,
763
+ 382,
764
+ 802,
765
+ 937,
766
+ 860,
767
+ 767,
768
+ 963,
769
+ 305,
770
+ 640,
771
+ 108,
772
+ 519,
773
+ 182,
774
+ 512,
775
+ 817,
776
+ 736,
777
+ 739,
778
+ 3,
779
+ 874,
780
+ 161,
781
+ 445,
782
+ 895,
783
+ 962,
784
+ 919,
785
+ 656,
786
+ 865,
787
+ 768,
788
+ 900,
789
+ 698,
790
+ 117,
791
+ 738,
792
+ 799,
793
+ 11,
794
+ 566,
795
+ 257,
796
+ 541,
797
+ 479,
798
+ 797,
799
+ 390,
800
+ 394,
801
+ 65,
802
+ 610,
803
+ 920,
804
+ 696,
805
+ 922,
806
+ 642,
807
+ 156,
808
+ 112,
809
+ 48,
810
+ 773,
811
+ 93,
812
+ 505,
813
+ 521,
814
+ 141,
815
+ 847,
816
+ 926,
817
+ 408,
818
+ 597,
819
+ 438,
820
+ 598,
821
+ 764,
822
+ 269,
823
+ 551
824
+ ],
825
+ "test": [
826
+ 938,
827
+ 762,
828
+ 252,
829
+ 956,
830
+ 271,
831
+ 146,
832
+ 469,
833
+ 658,
834
+ 405,
835
+ 511,
836
+ 671,
837
+ 217,
838
+ 322,
839
+ 735,
840
+ 580,
841
+ 216,
842
+ 67,
843
+ 274,
844
+ 410,
845
+ 323,
846
+ 824,
847
+ 946,
848
+ 234,
849
+ 57,
850
+ 794,
851
+ 786,
852
+ 332,
853
+ 701,
854
+ 224,
855
+ 570,
856
+ 704,
857
+ 655,
858
+ 276,
859
+ 388,
860
+ 473,
861
+ 167,
862
+ 958,
863
+ 746,
864
+ 546,
865
+ 175,
866
+ 873,
867
+ 623,
868
+ 73,
869
+ 663,
870
+ 699,
871
+ 934,
872
+ 273,
873
+ 686,
874
+ 214,
875
+ 363,
876
+ 379,
877
+ 166,
878
+ 373,
879
+ 854,
880
+ 650,
881
+ 464,
882
+ 918,
883
+ 911,
884
+ 103,
885
+ 942,
886
+ 875,
887
+ 81,
888
+ 296,
889
+ 791,
890
+ 233,
891
+ 677,
892
+ 46,
893
+ 71,
894
+ 721,
895
+ 196,
896
+ 591,
897
+ 370,
898
+ 882,
899
+ 633,
900
+ 643,
901
+ 849,
902
+ 300,
903
+ 565,
904
+ 80,
905
+ 387,
906
+ 127,
907
+ 549,
908
+ 470,
909
+ 747,
910
+ 44,
911
+ 826,
912
+ 270,
913
+ 618,
914
+ 352,
915
+ 867,
916
+ 367,
917
+ 99,
918
+ 389,
919
+ 94,
920
+ 954,
921
+ 344,
922
+ 781,
923
+ 220,
924
+ 159,
925
+ 928,
926
+ 348,
927
+ 947,
928
+ 714,
929
+ 163,
930
+ 825,
931
+ 777,
932
+ 6,
933
+ 890,
934
+ 828,
935
+ 284,
936
+ 603,
937
+ 459,
938
+ 225,
939
+ 429,
940
+ 950,
941
+ 718,
942
+ 665,
943
+ 733,
944
+ 203,
945
+ 574,
946
+ 27,
947
+ 616,
948
+ 517,
949
+ 238,
950
+ 223,
951
+ 95,
952
+ 30,
953
+ 32,
954
+ 432,
955
+ 604,
956
+ 89,
957
+ 558,
958
+ 913,
959
+ 758,
960
+ 692,
961
+ 104,
962
+ 754,
963
+ 142,
964
+ 228,
965
+ 250,
966
+ 281,
967
+ 759,
968
+ 25,
969
+ 114,
970
+ 654
971
+ ]
972
+ }
env_setup/backbone_sanity_check.py ADDED
@@ -0,0 +1,12 @@
import torch
import torchvision.models as models

# Pretrained ResNet-18 with the final FC layer stripped → 512-d output
backbone = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
backbone.fc = torch.nn.Identity()
backbone.eval()

dummy = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    emb = backbone(dummy)

print(f"Embedding shape : {emb.shape}")  # expect torch.Size([1, 512])
env_setup/gpu_sanity_check.py ADDED
@@ -0,0 +1,10 @@
import torch

print(f"PyTorch version : {torch.__version__}")
print(f"CUDA available  : {torch.cuda.is_available()}")
print(f"GPU             : {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'None — using CPU'}")

# Quick tensor op on the selected device (GPU if available, else CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
x = torch.randn(3, 224, 224).to(device)
print(f"Test tensor on  : {x.device} — shape {x.shape}")
env_setup/setup.sh ADDED
@@ -0,0 +1,7 @@
python -m venv venv
source venv/bin/activate        # Linux / macOS
# venv\Scripts\activate         # Windows

pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
pip install timm pillow numpy matplotlib scikit-learn tqdm wandb gradio
pip freeze > requirements.txt
logs/sample_grid.png ADDED
requirements.txt ADDED
@@ -0,0 +1,12 @@
torch>=2.0
torchvision>=0.15
timm
gradio
onnx
onnxruntime-gpu
pillow
numpy
matplotlib
scikit-learn
tqdm
wandb
src/dataset.py ADDED
@@ -0,0 +1,135 @@
from torchvision.datasets import Omniglot
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import json
import os
import random
import torch

root = "../data"
MEAN, STD = [0.9220], [0.2256]  # Omniglot stats (grayscale)

train_transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize((105, 105)),
    transforms.RandomCrop(105, padding=8),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    transforms.Normalize(MEAN, STD),
])

eval_transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize((105, 105)),
    transforms.ToTensor(),
    transforms.Normalize(MEAN, STD),
])


class SiamesePairDataset(Dataset):
    def __init__(self, dataset, allowed_classes, transform=None, num_pairs=10000):
        self.transform = transform
        self.num_pairs = num_pairs
        self.dataset = dataset

        # Group image indices by class
        self.class_to_indices = {}
        allowed = set(allowed_classes)  # O(1) membership tests
        for idx, (_, label) in enumerate(dataset):
            if label not in allowed:
                continue
            self.class_to_indices.setdefault(label, []).append(idx)

        self.classes = list(self.class_to_indices.keys())

    def __len__(self):
        return self.num_pairs

    def __getitem__(self, _):
        is_positive = random.random() > 0.5  # 50/50 positive/negative split

        if is_positive:
            cls = random.choice(self.classes)
            i1, i2 = random.sample(self.class_to_indices[cls], 2)
        else:
            cls1, cls2 = random.sample(self.classes, 2)
            i1 = random.choice(self.class_to_indices[cls1])
            i2 = random.choice(self.class_to_indices[cls2])

        img1, _ = self.dataset[i1]
        img2, _ = self.dataset[i2]

        if self.transform:
            img1 = self.transform(img1)
            img2 = self.transform(img2)

        label = torch.tensor(1.0 if is_positive else 0.0)
        return img1, img2, label


def dl_data():
    basic = transforms.ToTensor()

    bg = Omniglot(root=root, background=True, download=True, transform=basic)
    eval_set = Omniglot(root=root, background=False, download=True, transform=basic)

    print(f"Background split : {len(bg)} images")
    print(f"Evaluation split : {len(eval_set)} images")

    # Quick grid of sample images
    fig, axes = plt.subplots(2, 10, figsize=(16, 4))
    for i, ax in enumerate(axes.flat):
        img, _ = bg[i * 20]
        ax.imshow(img.squeeze(), cmap="gray")
        ax.axis("off")
    plt.tight_layout()
    plt.savefig("../logs/sample_grid.png", dpi=100)
    plt.show()

    # Split classes into train / val / test
    class_split(bg)


def class_split(bg):
    all_classes = list(set(label for _, label in bg))
    random.seed(42)
    random.shuffle(all_classes)

    n = len(all_classes)
    train_classes = all_classes[:int(n * 0.7)]
    val_classes = all_classes[int(n * 0.7):int(n * 0.85)]
    test_classes = all_classes[int(n * 0.85):]  # NEVER touch until Day 5

    split = {"train": train_classes, "val": val_classes, "test": test_classes}
    with open(os.path.join(root, "class_split.json"), "w") as f:
        json.dump(split, f, indent=4)

    print(f"Train: {len(train_classes)} | Val: {len(val_classes)} | Test: {len(test_classes)}")


def validate_dataloader():
    bg = Omniglot(root=root, background=True, download=True, transform=None)
    with open(os.path.join(root, "class_split.json")) as f:
        split = json.load(f)

    train_ds = SiamesePairDataset(bg, split["train"], transform=train_transform, num_pairs=10000)
    val_ds = SiamesePairDataset(bg, split["val"], transform=eval_transform, num_pairs=2000)

    train_loader = DataLoader(train_ds, batch_size=32, shuffle=True, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_ds, batch_size=32, shuffle=False, num_workers=4, pin_memory=True)

    # Sanity check
    img1, img2, labels = next(iter(train_loader))
    print(f"img1 shape : {img1.shape}")  # [32, 1, 105, 105]
    print(f"img2 shape : {img2.shape}")  # [32, 1, 105, 105]
    print(f"labels     : {labels[:8]}")
    print(f"Positive % : {labels.mean().item() * 100:.1f}%")  # should be ~50%
    assert img1.shape == img2.shape == torch.Size([32, 1, 105, 105])
    print("All assertions passed — DataLoader is ready")


if __name__ == "__main__":
    os.makedirs(root, exist_ok=True)
    if not os.listdir(root):
        dl_data()
    validate_dataloader()
src/demo.py ADDED
@@ -0,0 +1,66 @@
import gradio as gr
import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image

from model import SiameseNet

# ── Load model ────────────────────────────────────────────────
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SiameseNet(embedding_dim=128).to(device)
ckpt = torch.load("../checkpoints/best.pt", map_location=device)
model.load_state_dict(ckpt["model_state"])
model.eval()

# ── Transform ─────────────────────────────────────────────────
transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize((105, 105)),
    transforms.ToTensor(),
    transforms.Normalize([0.9220], [0.2256]),
])


def preprocess(img: Image.Image) -> torch.Tensor:
    return transform(img).unsqueeze(0).to(device)  # [1, 1, 105, 105]


# ── Inference ─────────────────────────────────────────────────
def compare_images(img1: Image.Image, img2: Image.Image):
    with torch.no_grad():
        emb1 = model.get_embedding(preprocess(img1))
        emb2 = model.get_embedding(preprocess(img2))
        similarity = F.cosine_similarity(emb1, emb2).item()

    match = similarity > 0.5
    label = "Same class" if match else "Different class"
    conf = f"{similarity * 100:.1f}%"
    colour = "green" if match else "red"

    result = f"""
    <div style='text-align:center; padding: 16px;'>
      <div style='font-size: 28px; font-weight: 600; color: {colour};'>{label}</div>
      <div style='font-size: 16px; color: gray; margin-top: 8px;'>
        Cosine similarity: <strong>{conf}</strong>
      </div>
    </div>
    """
    return result, round(similarity, 4)


# ── UI ────────────────────────────────────────────────────────
with gr.Blocks(title="Siamese Few-Shot Recognition") as demo:
    gr.Markdown("## Siamese Network — Few-Shot Image Similarity")
    gr.Markdown("Upload two images. The model will tell you if they belong to the same class.")

    with gr.Row():
        img1 = gr.Image(type="pil", label="Image 1")
        img2 = gr.Image(type="pil", label="Image 2")

    btn = gr.Button("Compare", variant="primary")

    result_html = gr.HTML()
    result_score = gr.Number(label="Raw similarity score")

    btn.click(fn=compare_images, inputs=[img1, img2],
              outputs=[result_html, result_score])

if __name__ == "__main__":
    demo.launch(share=True)  # share=True gives a public URL
src/eval.py ADDED
@@ -0,0 +1,146 @@
import torch
import torch.nn.functional as F
from torchvision.datasets import Omniglot
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import json
import random
from tqdm import tqdm

from model import SiameseNet


# ── Episode Dataset ───────────────────────────────────────────
class EpisodeDataset(Dataset):
    """
    Each item is one N-way K-shot episode:
      - N classes, K support images each → support set
      - N classes, 1 query image each   → query set
    Returns support images, query images, and the correct query labels.
    """
    def __init__(self, dataset, allowed_classes, transform, n_way=5, k_shot=1, n_episodes=600):
        self.transform = transform
        self.n_way = n_way
        self.k_shot = k_shot
        self.n_episodes = n_episodes
        self.dataset = dataset

        self.class_to_indices = {}
        allowed = set(allowed_classes)
        for idx, (_, label) in enumerate(dataset):
            if label not in allowed:
                continue
            self.class_to_indices.setdefault(label, []).append(idx)

        # Only keep classes with enough samples for K support + 1 query
        self.classes = [c for c, idxs in self.class_to_indices.items()
                        if len(idxs) >= k_shot + 1]

    def __len__(self):
        return self.n_episodes

    def __getitem__(self, _):
        # Sample N classes for this episode
        episode_classes = random.sample(self.classes, self.n_way)

        support_imgs, query_imgs, query_labels = [], [], []

        for label_idx, cls in enumerate(episode_classes):
            indices = random.sample(self.class_to_indices[cls], self.k_shot + 1)
            support_indices = indices[:self.k_shot]
            query_index = indices[self.k_shot]

            for i in support_indices:
                img, _ = self.dataset[i]
                support_imgs.append(self.transform(img))

            img, _ = self.dataset[query_index]
            query_imgs.append(self.transform(img))
            query_labels.append(label_idx)

        # support: [N*K, C, H, W] | query: [N, C, H, W]
        support = torch.stack(support_imgs)
        query = torch.stack(query_imgs)
        labels = torch.tensor(query_labels)
        return support, query, labels


# ── Evaluation function ───────────────────────────────────────
@torch.no_grad()
def evaluate_episodes(model, episode_ds, device, n_way, k_shot):
    model.eval()
    correct, total = 0, 0

    loader = DataLoader(episode_ds, batch_size=1, shuffle=False, num_workers=2)

    for support, query, labels in tqdm(loader, desc=f"{n_way}-way {k_shot}-shot"):
        # Remove batch dim (batch_size=1)
        support = support.squeeze(0).to(device)  # [N*K, C, H, W]
        query = query.squeeze(0).to(device)      # [N, C, H, W]
        labels = labels.squeeze(0).to(device)    # [N]

        # Get embeddings
        support_emb = model.get_embedding(support)  # [N*K, 128]
        query_emb = model.get_embedding(query)      # [N, 128]

        # Compute class prototypes (mean of K support embeddings per class)
        support_emb = support_emb.view(n_way, k_shot, -1).mean(dim=1)  # [N, 128]

        # Cosine similarity: each query vs each class prototype
        sim = F.cosine_similarity(
            query_emb.unsqueeze(1),    # [N, 1, 128]
            support_emb.unsqueeze(0),  # [1, N, 128]
            dim=2                      # → [N, N]
        )

        preds = sim.argmax(dim=1)  # [N]
        correct += (preds == labels).sum().item()
        total += labels.size(0)

    return correct / total


# ── Run all eval configurations ───────────────────────────────
def run_eval(checkpoint_path, data_root, split_path):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load model
    model = SiameseNet(embedding_dim=128).to(device)
    ckpt = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(ckpt["model_state"])
    print(f"Loaded checkpoint from epoch {ckpt['epoch']}")

    eval_transform = transforms.Compose([
        transforms.Grayscale(),
        transforms.Resize((105, 105)),
        transforms.ToTensor(),
        transforms.Normalize([0.9220], [0.2256]),
    ])

    bg = Omniglot(root=data_root, background=True, download=False, transform=None)

    with open(split_path) as f:
        test_classes = json.load(f)["test"]

    print(f"Evaluating on {len(test_classes)} unseen test classes\n")

    results = {}
    for n_way in [5, 10]:
        for k_shot in [1, 5]:
            ep_ds = EpisodeDataset(
                bg, test_classes, eval_transform,
                n_way=n_way, k_shot=k_shot, n_episodes=600
            )
            acc = evaluate_episodes(model, ep_ds, device, n_way, k_shot)
            key = f"{n_way}-way {k_shot}-shot"
            results[key] = acc
            print(f"  {key:18s} → {acc * 100:.2f}%")

    return results


if __name__ == "__main__":
    results = run_eval(
        checkpoint_path="../checkpoints/best.pt",
        data_root="../data",
        split_path="../data/class_split.json",
    )
src/export_onnx.py ADDED
@@ -0,0 +1,32 @@
import torch
import onnxruntime as ort
import numpy as np

from model import SiameseNet

device = torch.device("cpu")  # export on CPU for portability
model = SiameseNet(embedding_dim=128)
ckpt = torch.load("../checkpoints/best.pt", map_location=device)
model.load_state_dict(ckpt["model_state"])
model.eval()

# Export the embedding net only (that's all you need at inference)
dummy = torch.randn(1, 1, 105, 105)

torch.onnx.export(
    model.embedding_net,
    dummy,
    "../checkpoints/siamese_embedding.onnx",
    input_names=["image"],
    output_names=["embedding"],
    dynamic_axes={"image": {0: "batch"}, "embedding": {0: "batch"}},
    opset_version=17,
)
print("ONNX model exported → checkpoints/siamese_embedding.onnx")

# ── Verify with onnxruntime ───────────────────────────────────
sess = ort.InferenceSession("../checkpoints/siamese_embedding.onnx")
out = sess.run(None, {"image": dummy.numpy()})
print(f"ONNX output shape : {out[0].shape}")                # (1, 128)
print(f"ONNX output norm  : {np.linalg.norm(out[0]):.4f}")  # ~1.0
print("ONNX verification passed")
src/fp_sanity_check.py ADDED
@@ -0,0 +1,24 @@
# Quick sanity check — run this before writing a single line of the training loop.
import torch
from model import SiameseNet
from loss import ContrastiveLoss

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = SiameseNet(embedding_dim=128).to(device)
criterion = ContrastiveLoss(margin=1.0)

# Fake a batch matching the DataLoader output shape
img1 = torch.randn(32, 1, 105, 105).to(device)
img2 = torch.randn(32, 1, 105, 105).to(device)
labels = torch.randint(0, 2, (32,)).float().to(device)

emb1, emb2 = model(img1, img2)
loss, dist = criterion(emb1, emb2, labels)

print(f"emb1 shape : {emb1.shape}")                   # [32, 128]
print(f"emb2 shape : {emb2.shape}")                   # [32, 128]
print(f"emb1 norm  : {emb1.norm(dim=1).mean():.4f}")  # should be ~1.0
print(f"loss       : {loss.item():.4f}")
print(f"dist range : {dist.min():.3f} – {dist.max():.3f}")
print("Sanity check passed")
src/loss.py ADDED
@@ -0,0 +1,19 @@
+ # src/loss.py
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ 
+ 
+ class ContrastiveLoss(nn.Module):
+     def __init__(self, margin=1.0):
+         super().__init__()
+         self.margin = margin
+ 
+     def forward(self, emb1, emb2, label):
+         # Euclidean distance between embedding pairs
+         dist = F.pairwise_distance(emb1, emb2)
+ 
+         # label=1 → same class (pull together), label=0 → different class (push apart)
+         loss = label * dist.pow(2) + \
+                (1 - label) * F.relu(self.margin - dist).pow(2)
+ 
+         return loss.mean(), dist
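Plugging numbers into this loss makes the margin behaviour concrete. A small numpy sketch (the `contrastive_loss` helper mirrors the module above for illustration; it is not part of the repo):

```python
import numpy as np

def contrastive_loss(e1, e2, label, margin=1.0):
    # Per-pair Euclidean distance, mirroring F.pairwise_distance
    d = np.linalg.norm(e1 - e2, axis=1)
    # label=1 penalises distance; label=0 penalises only pairs inside the margin
    loss = label * d**2 + (1 - label) * np.maximum(margin - d, 0.0)**2
    return loss.mean(), d

# Pair 0: identical same-class embeddings → d = 0, zero loss.
# Pair 1: orthogonal different-class unit vectors → d = sqrt(2) > margin,
#         so the hinge term is also zero.
e1 = np.array([[1.0, 0.0], [1.0, 0.0]])
e2 = np.array([[1.0, 0.0], [0.0, 1.0]])
labels = np.array([1.0, 0.0])
loss, d = contrastive_loss(e1, e2, labels)
print(loss)  # 0.0
```

Note that on unit-sphere embeddings the Euclidean distance is bounded by 2, so negatives only contribute loss while they sit closer than the margin of 1.0.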
src/model.py ADDED
@@ -0,0 +1,46 @@
+ # src/model.py
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torchvision.models as models
+ 
+ 
+ class EmbeddingNet(nn.Module):
+     def __init__(self, embedding_dim=128):
+         super().__init__()
+ 
+         # Pretrained ResNet-18, strip the final FC layer
+         backbone = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
+         self.backbone = nn.Sequential(*list(backbone.children())[:-1])  # → [B, 512, 1, 1]
+ 
+         # Embedding head: 512 → 256 → 128, L2-normalised output
+         self.head = nn.Sequential(
+             nn.Linear(512, 256),
+             nn.BatchNorm1d(256),
+             nn.ReLU(inplace=True),
+             nn.Linear(256, embedding_dim),
+         )
+ 
+     def forward(self, x):
+         # Omniglot is grayscale — replicate channel to fake RGB for ResNet
+         if x.shape[1] == 1:
+             x = x.repeat(1, 3, 1, 1)        # [B, 1, H, W] → [B, 3, H, W]
+ 
+         x = self.backbone(x)                # [B, 512, 1, 1]
+         x = x.view(x.size(0), -1)           # [B, 512]
+         x = self.head(x)                    # [B, 128]
+         x = F.normalize(x, p=2, dim=1)      # L2 normalise → unit sphere
+         return x
+ 
+ 
+ class SiameseNet(nn.Module):
+     def __init__(self, embedding_dim=128):
+         super().__init__()
+         self.embedding_net = EmbeddingNet(embedding_dim)
+ 
+     def forward(self, img1, img2):
+         emb1 = self.embedding_net(img1)
+         emb2 = self.embedding_net(img2)
+         return emb1, emb2
+ 
+     def get_embedding(self, img):
+         return self.embedding_net(img)
src/run_training.py ADDED
@@ -0,0 +1,108 @@
+ import json
+ import os
+ 
+ import torch
+ import wandb
+ from torch.utils.data import DataLoader
+ from torchvision import transforms
+ from torchvision.datasets import Omniglot
+ 
+ from model import SiameseNet
+ from loss import ContrastiveLoss
+ from dataset import SiamesePairDataset
+ from train import train_one_epoch, validate, save_checkpoint
+ 
+ # ── Config ────────────────────────────────────────────────────
+ CFG = {
+     "epochs"         : 30,
+     "batch_size"     : 32,
+     "lr"             : 1e-3,
+     "embedding_dim"  : 128,
+     "margin"         : 1.0,
+     "num_workers"    : 4,
+     "num_pairs_train": 10000,
+     "num_pairs_val"  : 2000,
+     "data_root"      : "../data",
+     "ckpt_dir"       : "../checkpoints",
+ }
+ 
+ # ── WandB ─────────────────────────────────────────────────────
+ wandb.init(project="siamese-few-shot", name="run-01", config=CFG)
+ 
+ # ── Data ──────────────────────────────────────────────────────
+ MEAN, STD = [0.9220], [0.2256]
+ 
+ train_transform = transforms.Compose([
+     transforms.Grayscale(),
+     transforms.Resize((105, 105)),
+     transforms.RandomCrop(105, padding=8),
+     # NB: horizontal flips can change a character's identity (mirrored
+     # glyphs); watch val accuracy if you keep this augmentation.
+     transforms.RandomHorizontalFlip(),
+     transforms.ColorJitter(brightness=0.2, contrast=0.2),
+     transforms.ToTensor(),
+     transforms.Normalize(MEAN, STD),
+ ])
+ 
+ eval_transform = transforms.Compose([
+     transforms.Grayscale(),
+     transforms.Resize((105, 105)),
+     transforms.ToTensor(),
+     transforms.Normalize(MEAN, STD),
+ ])
+ 
+ bg = Omniglot(root=CFG["data_root"], background=True, download=True, transform=None)
+ 
+ with open(os.path.join(CFG["data_root"], "class_split.json")) as f:
+     split = json.load(f)
+ 
+ train_ds = SiamesePairDataset(bg, split["train"], transform=train_transform,
+                               num_pairs=CFG["num_pairs_train"])
+ val_ds   = SiamesePairDataset(bg, split["val"], transform=eval_transform,
+                               num_pairs=CFG["num_pairs_val"])
+ 
+ train_loader = DataLoader(train_ds, batch_size=CFG["batch_size"], shuffle=True,
+                           num_workers=CFG["num_workers"], pin_memory=True)
+ val_loader   = DataLoader(val_ds, batch_size=CFG["batch_size"], shuffle=False,
+                           num_workers=CFG["num_workers"], pin_memory=True)
+ 
+ # ── Model / Loss / Optimiser ──────────────────────────────────
+ device    = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ model     = SiameseNet(embedding_dim=CFG["embedding_dim"]).to(device)
+ criterion = ContrastiveLoss(margin=CFG["margin"])
+ optimizer = torch.optim.Adam(model.parameters(), lr=CFG["lr"])
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=CFG["epochs"])
+ 
+ print(f"Training on : {device}")
+ print(f"Train pairs : {len(train_ds)} | Val pairs: {len(val_ds)}")
+ 
+ # ── Training loop ─────────────────────────────────────────────
+ best_val_loss = float("inf")
+ 
+ for epoch in range(1, CFG["epochs"] + 1):
+     train_loss, train_acc = train_one_epoch(model, train_loader, criterion,
+                                             optimizer, device, epoch)
+     val_loss, val_acc = validate(model, val_loader, criterion, device, epoch)
+     scheduler.step()
+ 
+     print(f"Epoch {epoch:02d} | "
+           f"train loss {train_loss:.4f} acc {train_acc*100:.1f}% | "
+           f"val loss {val_loss:.4f} acc {val_acc*100:.1f}%")
+ 
+     wandb.log({
+         "epoch"      : epoch,
+         "train/loss" : train_loss,
+         "train/acc"  : train_acc,
+         "val/loss"   : val_loss,
+         "val/acc"    : val_acc,
+         "lr"         : scheduler.get_last_lr()[0],
+     })
+ 
+     # Save best checkpoint
+     if val_loss < best_val_loss:
+         best_val_loss = val_loss
+         save_checkpoint(model, optimizer, epoch, val_loss,
+                         f"{CFG['ckpt_dir']}/best.pt")
+ 
+ # Save final checkpoint regardless
+ save_checkpoint(model, optimizer, CFG["epochs"], val_loss,
+                 f"{CFG['ckpt_dir']}/final.pt")
+ 
+ wandb.finish()
+ print("Training complete.")
src/train.py ADDED
@@ -0,0 +1,75 @@
+ import os
+ 
+ import torch
+ from tqdm import tqdm
+ 
+ 
+ def train_one_epoch(model, loader, criterion, optimizer, device, epoch):
+     model.train()
+     total_loss, correct, total = 0.0, 0, 0
+ 
+     loop = tqdm(loader, desc=f"Epoch {epoch} [train]", leave=False)
+     for img1, img2, labels in loop:
+         img1, img2, labels = img1.to(device), img2.to(device), labels.to(device)
+ 
+         optimizer.zero_grad()
+         emb1, emb2 = model(img1, img2)
+         loss, dist = criterion(emb1, emb2, labels)
+         loss.backward()
+ 
+         # Gradient clipping — prevents exploding gradients early in training
+         torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
+ 
+         optimizer.step()
+ 
+         # Accuracy: predict same-class if distance < 0.5
+         preds = (dist < 0.5).float()
+         correct += (preds == labels).sum().item()
+         total += labels.size(0)
+         total_loss += loss.item()
+ 
+         loop.set_postfix(loss=f"{loss.item():.4f}")
+ 
+     return total_loss / len(loader), correct / total
+ 
+ 
+ @torch.no_grad()
+ def validate(model, loader, criterion, device, epoch):
+     model.eval()
+     total_loss, correct, total = 0.0, 0, 0
+ 
+     loop = tqdm(loader, desc=f"Epoch {epoch} [val]  ", leave=False)
+     for img1, img2, labels in loop:
+         img1, img2, labels = img1.to(device), img2.to(device), labels.to(device)
+ 
+         emb1, emb2 = model(img1, img2)
+         loss, dist = criterion(emb1, emb2, labels)
+ 
+         preds = (dist < 0.5).float()
+         correct += (preds == labels).sum().item()
+         total += labels.size(0)
+         total_loss += loss.item()
+ 
+     return total_loss / len(loader), correct / total
+ 
+ 
+ def save_checkpoint(model, optimizer, epoch, val_loss, path):
+     os.makedirs(os.path.dirname(path), exist_ok=True)
+     torch.save({
+         "epoch"      : epoch,
+         "model_state": model.state_dict(),
+         "optim_state": optimizer.state_dict(),
+         "val_loss"   : val_loss,
+     }, path)
+     print(f"  Checkpoint saved → {path}")
+ 
+ 
+ def load_checkpoint(path, model, optimizer=None):
+     ckpt = torch.load(path, map_location="cpu")  # load on CPU, move to device afterwards
+     model.load_state_dict(ckpt["model_state"])
+     if optimizer:
+         optimizer.load_state_dict(ckpt["optim_state"])
+     print(f"  Resumed from epoch {ckpt['epoch']} (val_loss={ckpt['val_loss']:.4f})")
+     return ckpt["epoch"]
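The pair accuracy tracked in both loops above is just a fixed-threshold decision on the embedding distance. A minimal numpy sketch of that metric (the `pair_accuracy` helper and the toy distances are illustrative, not part of the repo):

```python
import numpy as np

def pair_accuracy(dist, labels, thresh=0.5):
    # Predict "same class" (label 1) when embeddings are closer than thresh
    preds = (dist < thresh).astype(float)
    return float((preds == labels).mean())

dist   = np.array([0.1, 0.9, 0.4, 1.3])
labels = np.array([1.0, 0.0, 0.0, 0.0])  # pair at index 2 is a hard negative
print(pair_accuracy(dist, labels))  # 0.75
```

The 0.5 threshold is a heuristic midpoint for unit-sphere embeddings with margin 1.0; it only monitors training progress, since the episodic evaluation in `eval.py` never uses it.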