dchen0 commited on
Commit
2e97025
·
verified ·
1 Parent(s): 61f32ea

Add merged model + processor

Browse files
Files changed (2) hide show
  1. handler.py +25 -1
  2. requirements.txt +1 -4
handler.py CHANGED
@@ -5,10 +5,34 @@ import io
5
  from typing import Any, Dict
6
 
7
  import torch
 
8
  from PIL import Image
9
  from transformers import AutoImageProcessor, Dinov2ForImageClassification
10
 
11
- from train_model import get_inference_transform
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
 
14
  class EndpointHandler:
 
5
  from typing import Any, Dict
6
 
7
  import torch
8
+ import torchvision.transforms as T
9
  from PIL import Image
10
  from transformers import AutoImageProcessor, Dinov2ForImageClassification
11
 
12
+
13
+ def get_inference_transform(processor: AutoImageProcessor, size: int):
14
+ """Get the raw validation transform for direct inference on PIL images."""
15
+ normalize = T.Normalize(mean=processor.image_mean, std=processor.image_std)
16
+
17
+ to_rgb = T.Lambda(lambda img: img.convert('RGB'))
18
+
19
+ def pad_to_square(img):
20
+ w, h = img.size
21
+ max_size = max(w, h)
22
+ pad_w = (max_size - w) // 2
23
+ pad_h = (max_size - h) // 2
24
+ padding = (pad_w, pad_h, max_size - w - pad_w, max_size - h - pad_h)
25
+ return T.Pad(padding, fill=0)(img)
26
+
27
+ aug = T.Compose([
28
+ to_rgb,
29
+ pad_to_square,
30
+ T.Resize(size),
31
+ T.ToTensor(),
32
+ normalize
33
+ ])
34
+
35
+ return aug
36
 
37
 
38
  class EndpointHandler:
requirements.txt CHANGED
@@ -1,5 +1,2 @@
1
  torchvision>=0.19
2
- Pillow>=10
3
- datasets>=2.19
4
- peft>=0.10
5
- safetensors>=0.4
 
1
  torchvision>=0.19
2
+ Pillow>=10