kannanmohan committed on
Commit 080763c · 1 Parent(s): b4d9017

feat: add two-phase classifier

app.py CHANGED
@@ -1,4 +1,3 @@
-
 from PIL import Image
 import torch
 from timeit import default_timer as timer
@@ -7,38 +6,52 @@ from transformers import AutoModelForImageClassification, AutoImageProcessor
 import gradio as gr
 from pathlib import Path
 
-
-# TODO - check if we go without setting this
-# AutoImageProcessor.register("CustomVisionImageProcessorV2", CustomVisionImageProcessorV2)
-
-model = AutoModelForImageClassification.from_pretrained("models")
-processor = AutoImageProcessor.from_pretrained("models", use_fast=False, trust_remote_code=True)
+DEVICE = 'cuda' if torch.cuda.is_available() else ('mps' if torch.backends.mps.is_available() and torch.backends.mps.is_built() else 'cpu')
+BINARY_PATH = "./models/binary_class"
+MULTICLASS_PATH = "./models/multi_class"
 
-def predict(image:Image, model:torch.nn.Module=model, pre_processor=processor)-> Tuple[Dict, float]:
+processor = AutoImageProcessor.from_pretrained("./models/", use_fast=False, trust_remote_code=True)
+binary_model = AutoModelForImageClassification.from_pretrained(BINARY_PATH).to(DEVICE)
+multi_class_model = AutoModelForImageClassification.from_pretrained(MULTICLASS_PATH).to(DEVICE)
+
+def predict(image: Image) -> Tuple[Dict, float]:
     """
-    Predict the class of an image using a trained model.
+    1. Binary pass      -> food / non-food gate
+    2. Multi-class pass -> runs only if the binary pass predicts food
+    Returns: (label_dict | "Not a food item", total_time)
     """
+    if image is None:
+        return "no image provided", 0.0
     start_time = timer()
-    image_trans = processor(image)
-
-    model.eval()
+    inputs = processor(images=image, return_tensors="pt").to(DEVICE)
+    binary_model.eval()
+    with torch.inference_mode():
+        logits = binary_model(**inputs).logits.squeeze()
+        pred_probs = torch.softmax(logits, dim=-1)
+        pred_food_prob = pred_probs[binary_model.config.label2id.get("food", 1)].item()
+    if pred_food_prob < 0.5:
+        elapsed = round(timer() - start_time, 4)
+        return "Not a food item", elapsed
+
+    multi_class_model.eval()
     with torch.inference_mode():
-        logits = model(image_trans['pixel_values'].unsqueeze(0)).logits
-        pred_probs = torch.softmax(logits, dim=1)
-        pred_label_prob_map = {model.config.id2label[i]: pred_probs[0][i].item() for i in range(len(pred_probs[0]))}
+        logits = multi_class_model(**inputs).logits
+        pred_probs = torch.softmax(logits, dim=-1)
+        pred_label_prob_map = {multi_class_model.config.id2label[i]: pred_probs[0][i].item() for i in range(len(pred_probs[0]))}
     elapsed_time = round(timer() - start_time, 4)
     return pred_label_prob_map, elapsed_time
 
+# get the example images from the assets folder
 example_images = list(Path("assets").glob("*.jpg"))
 
 demo = gr.Interface(fn=predict,
                     inputs=gr.Image(type="pil"),
                     outputs=[
                         gr.Label(num_top_classes=5, label="Predicted class"),
-                        gr.Number(label="Prediction time"),
+                        gr.Number(label="Prediction time (sec)"),
                     ],
                     examples=example_images,
                     title="Image Classification App",
                     description="Upload an image to predict the class of the image")
 
-demo.launch()
+demo.launch()
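
Reviewer note: the two-phase gate above can be smoke-tested through the Gradio API once the Space is running. A minimal sketch, not part of the commit, assuming python app.py is serving on the default http://127.0.0.1:7860 and using one of the example images added below:

# Smoke test via gradio_client (ships alongside gradio); a sketch, not part of the commit.
from gradio_client import Client, handle_file

client = Client("http://127.0.0.1:7860/")
label, seconds = client.predict(
    handle_file("assets/img_0_0.jpg"),  # example image added in this commit
    api_name="/predict",                # default endpoint name for gr.Interface(fn=predict)
)
print(label, seconds)  # label/confidence payload for food images, "Not a food item" otherwise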
assets/croque_madame_0.jpg DELETED
Binary file (46.6 kB)
 
assets/img_0_0.jpg ADDED
assets/img_0_1.jpg ADDED
assets/img_1_2.jpg ADDED
assets/pancakes_1.jpg DELETED
Binary file (45 kB)
 
assets/shrimp_and_grits_2.jpg DELETED
Binary file (54.5 kB)
 
models/binary_class/config.json ADDED
@@ -0,0 +1,36 @@
+{
+  "architectures": [
+    "ViTForImageClassification"
+  ],
+  "attention_probs_dropout_prob": 0.0,
+  "auto_map": {
+    "AutoModelForImageClassification": "transformers.ViTForImageClassification"
+  },
+  "encoder_stride": 16,
+  "hidden_act": "gelu",
+  "hidden_dropout_prob": 0.0,
+  "hidden_size": 768,
+  "id2label": {
+    "0": "non-food",
+    "1": "food"
+  },
+  "image_size": 224,
+  "initializer_range": 0.02,
+  "intermediate_size": 3072,
+  "label2id": {
+    "food": 1,
+    "non-food": 0
+  },
+  "layer_norm_eps": 1e-12,
+  "model_type": "vit",
+  "num_attention_heads": 12,
+  "num_channels": 3,
+  "num_hidden_layers": 12,
+  "patch_size": 16,
+  "pooler_act": "tanh",
+  "pooler_output_size": 768,
+  "problem_type": "single_label_classification",
+  "qkv_bias": true,
+  "torch_dtype": "float32",
+  "transformers_version": "4.55.2"
+}
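
Reviewer note: the id2label / label2id maps in this config are what the binary gate in app.py keys on. A quick sketch of the lookup, assuming the config above is saved under ./models/binary_class:

# How predict() resolves the "food" index from this config; a sketch.
from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("./models/binary_class")
food_idx = cfg.label2id.get("food", 1)   # same lookup and fallback as app.py
assert cfg.id2label[food_idx] == "food"  # id2label keys are parsed back to ints on load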
models/{training_args.bin → binary_class/training_args.bin} RENAMED
Binary files a/models/training_args.bin and b/models/binary_class/training_args.bin differ
 
models/{config.json → multi_class/config.json} RENAMED
File without changes
models/multi_class/training_args.bin ADDED
Binary file (5.78 kB)
 
models/pre_processor.py DELETED
@@ -1,111 +0,0 @@
-
-from typing import Any, Dict, List
-import torch
-from PIL import Image as PILImage
-from torchvision.transforms import v2
-from transformers import BaseImageProcessor
-from collections.abc import Mapping
-
-class CustomVisionImageProcessorV2(BaseImageProcessor):
-    """
-    ViT-B/16-224 preprocessing for Huggingface datasets.
-    Works with:
-      - dataset.map(..., batched=True/False)
-      - dataset.set_transform
-    """
-    # run-time hint used by the Hugging-Face Trainer / pipeline / collator to decide which fields to feed into the model's forward pass
-    model_input_names = ["pixel_values"]
-
-    def __init__(self, size: int = 224, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), train: bool = True, **kwargs):
-        super().__init__(**kwargs)
-        self.size = size
-        self.mean = mean
-        self.std = std
-        self.train = train
-        # self.auto_map = {  # add an entry in the processor config file for auto registring the processor
-        #     "AutoImageProcessor": "image_processor.CustomVisionImageProcessorV2",
-        # }
-    # Create transform lazily so it is NOT serialised when we save using ImageProcessor.save_pretrained
-    @property
-    def transform(self):
-        if not hasattr(self, "_transform"):
-            if self.train:
-                self._transform = v2.Compose(
-                    [
-                        v2.ToImage(),
-                        v2.RandomResizedCrop(
-                            size=(self.size, self.size),
-                            scale=(0.8, 1.0),
-                            antialias=True,
-                        ),
-                        v2.RandomHorizontalFlip(p=0.5),
-                        v2.ToDtype(torch.float32, scale=True),
-                        v2.Normalize(mean=self.mean, std=self.std),
-                    ]
-                )
-            else:
-                self._transform = v2.Compose(
-                    [
-                        v2.ToImage(),
-                        v2.Resize(size=(self.size, self.size), antialias=True),
-                        v2.CenterCrop(size=(self.size, self.size)),
-                        v2.ToDtype(torch.float32, scale=True),
-                        v2.Normalize(mean=self.mean, std=self.std),
-                    ]
-                )
-        return self._transform
-
-    def __call__(self,
-                 examples: PILImage.Image | List[PILImage.Image] | Dict[str, Any],
-                 image_feature_name: str = "image",
-                 label_feature_name: str = "label") -> Dict[str, torch.Tensor]:
-        """
-        Accepts:
-          * a single PIL.Image
-          * a list[PIL.Image]
-          * a dict with keys:
-                'image' : PIL.Image or list[PIL.Image]
-                'label' : single int or list[int] (optional)
-
-        Returns:
-          dict with:
-            'pixel_values' : tensor (C, H, W) or (N, C, H, W)
-            'labels'       : tensor (long) or (N,)
-        """
-        if isinstance(examples, PILImage.Image):  # single PIL image → wrap into list
-            images, labels = [examples], None
-        elif isinstance(examples, list):  # list of PIL images
-            images, labels = examples, None
-        elif isinstance(examples, Mapping):  # dict (single example or LazyBatch)
-            images = examples[image_feature_name]
-            labels = examples.get(label_feature_name)
-            if isinstance(images, PILImage.Image):  # single example
-                images = [images]
-                labels = [labels] if labels is not None else None
-            # else images is already a list (LazyBatch)
-        else:
-            raise TypeError(f"Expected PIL.Image, list[PIL.Image] or dict, got {type(examples)}")
-
-        pixel_values = torch.stack([self.transform(img) for img in images])
-        if pixel_values.shape[0] == 1:  # squeeze singleton batch dimension when we only processed one image
-            pixel_values = pixel_values.squeeze(0)
-
-        out = {"pixel_values": pixel_values}
-        if labels is not None:
-            labels_tensor = torch.tensor(labels, dtype=torch.long)
-            if labels_tensor.shape[0] == 1:  # squeeze singleton batch dimension when we only processed one label
-                labels_tensor = labels_tensor.squeeze(0)
-            out["labels"] = labels_tensor
-        return out
-
-    def to_dict(self) -> Dict[str, Any]:
-        cfg = super().to_dict()
-        cfg.update(
-            dict(
-                auto_map={
-                    "AutoImageProcessor": "pre_processor.CustomVisionImageProcessorV2",
-                    # "AutoImageProcessor_fast": "pre_processor.CustomVisionImageProcessorV3Fast",
-                },
-            )
-        )
-        return cfg
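
Reviewer note: with this custom processor deleted, inference-time preprocessing moves to the stock ViTImageProcessor configured in preprocessor_config.json below. For the eval path the two should match: Resize to 224x224 followed by CenterCrop(224) is a no-op crop, and rescale plus Normalize(mean=std=0.5) mirrors ToDtype(scale=True) plus Normalize. A sketch of the replacement, using a hypothetical in-memory test image:

# Stock-processor equivalent of the deleted eval-time transform; a sketch.
from PIL import Image
from transformers import ViTImageProcessor

processor = ViTImageProcessor(
    size={"height": 224, "width": 224},
    image_mean=[0.5, 0.5, 0.5],
    image_std=[0.5, 0.5, 0.5],
)
batch = processor(images=Image.new("RGB", (320, 240)), return_tensors="pt")
print(batch["pixel_values"].shape)  # torch.Size([1, 3, 224, 224])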
models/preprocessor_config.json CHANGED
@@ -1,18 +1,23 @@
 {
-  "auto_map": {
-    "AutoImageProcessor": "pre_processor.CustomVisionImageProcessorV2"
-  },
-  "image_processor_type": "CustomVisionImageProcessorV2",
-  "mean": [
+  "do_convert_rgb": null,
+  "do_normalize": true,
+  "do_rescale": true,
+  "do_resize": true,
+  "image_mean": [
     0.5,
     0.5,
     0.5
   ],
-  "size": 224,
-  "std": [
+  "image_processor_type": "ViTImageProcessor",
+  "image_std": [
     0.5,
     0.5,
     0.5
   ],
-  "train": false
+  "resample": 2,
+  "rescale_factor": 0.00392156862745098,
+  "size": {
+    "height": 224,
+    "width": 224
+  }
 }
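
Reviewer note: since image_processor_type is now the stock ViTImageProcessor, the trust_remote_code flag in app.py should no longer be strictly needed; rescale_factor is just 1/255, and resample 2 is PIL bilinear. A round-trip check, assuming the repo layout above:

# Verify the saved config resolves to the stock processor (a sketch).
from transformers import AutoImageProcessor

p = AutoImageProcessor.from_pretrained("./models/", use_fast=False)
print(type(p).__name__)             # ViTImageProcessor
print(p.rescale_factor == 1 / 255)  # True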
requirements.txt CHANGED
@@ -1,5 +1,4 @@
 torch==2.8.0
-torchvision==0.23.0
 gradio==5.42.0
 transformers==4.55.2
 evaluate==0.4.5