Transformers
jespark committed on
Commit
7ebc30a
·
verified ·
1 Parent(s): 227bf21

Upload dataprocessor_hf.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. dataprocessor_hf.py +53 -0
dataprocessor_hf.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dataloader as dl
2
+ import torch
3
+ import argparse
4
+ import transformers
5
+ import PIL.Image as Image
6
+ from typing import Union, List
7
+
8
+ from transformers.image_processing_utils import BaseImageProcessor
9
+ from transformers.utils import PushToHubMixin
10
+
11
class CommForImageProcessor(BaseImageProcessor, PushToHubMixin):
    """
    Image processor for Community Forensics VIT model. Processes PIL images and returns PyTorch tensors.

    The transform pipeline itself is built by ``get_transform`` from the
    ``dataloader`` module (imported as ``dl``); this class only assembles the
    configuration namespace that function expects and applies the resulting
    transform to each image.

    Args:
        size: Model input size. Only 224 and 384 are supported.
        **kwargs: Forwarded unchanged to ``BaseImageProcessor.__init__``.

    Raises:
        ValueError: If ``size`` is not one of the supported sizes.
    """
    image_processor_type = "commfor_image_processor"
    model_input_names = ["pixel_values"]

    def __init__(self, size=384, **kwargs):
        super().__init__(**kwargs)
        # Raise explicitly instead of using `assert`: assertions are removed
        # when Python runs with -O, which would let a bad size through.
        if size not in (224, 384):
            raise ValueError(f"Unsupported size: {size}. Supported sizes are 224 and 384.")
        self.size = size

    def preprocess(
        self,
        images: Union[Image.Image, List[Image.Image]],
        mode: str = "test",
        **kwargs
    ) -> dict:
        """
        Preprocess the input images to PyTorch tensors.

        Args:
            images: A single PIL image or a non-empty list of PIL images.
            mode: Either ``"test"`` or ``"train"``; passed through to
                ``dl.get_transform`` to select the transform pipeline.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            A dict with a single key ``"pixel_values"``. For exactly one image
            the value is the transform's output for that image (no batch
            dimension is added); for several images the outputs are combined
            with ``torch.stack`` along a new leading dimension.

        Raises:
            ValueError: If ``mode`` is unsupported or ``images`` is an empty list.
            TypeError: If ``images`` is neither a PIL image nor a list.
        """
        # Explicit raises (not `assert`) so validation survives `python -O`.
        if mode not in ("test", "train"):
            raise ValueError(f"Unsupported mode: {mode}. Supported modes are 'test' and 'train'.")
        if isinstance(images, Image.Image):
            images = [images]
        elif not isinstance(images, list):
            raise TypeError("Input must be a PIL Image or a list of PIL Images.")
        if not images:
            # Fail early with a clear message; otherwise torch.stack would
            # reject the empty list further down with a less helpful error.
            raise ValueError("Input must contain at least one image, but an empty list was given.")

        # dl.get_transform expects an argparse-style namespace. The rsa_*
        # values are kept as strings exactly as the original code set them.
        # NOTE(review): presumably these configure random augmentations used
        # in "train" mode -- confirm against dataloader.get_transform.
        args = argparse.Namespace(
            input_size=self.size,
            rsa_ops="JPEGinMemory,RandomResizeWithRandomIntpl,RandomCrop,RandomHorizontalFlip,RandomVerticalFlip,RRCWithRandomIntpl,RandomRotation,RandomTranslate,RandomShear,RandomPadding,RandomCutout",
            rsa_min_num_ops='0',
            rsa_max_num_ops='2',
        )

        transform = dl.get_transform(args, mode=mode)

        processed_images = [transform(image) for image in images]  # the output would be tensors
        if len(processed_images) == 1:
            return {"pixel_values": processed_images[0]}
        return {"pixel_values": torch.stack(processed_images)}
50
+
51
+
52
+
53
+