Spaces · Runtime error

Commit · 080763c
1 parent: b4d9017

feat: add two-phase classifier

Files changed:
- app.py (+29 -16)
- assets/croque_madame_0.jpg (+0 -0)
- assets/img_0_0.jpg (+0 -0)
- assets/img_0_1.jpg (+0 -0)
- assets/img_1_2.jpg (+0 -0)
- assets/pancakes_1.jpg (+0 -0)
- assets/shrimp_and_grits_2.jpg (+0 -0)
- models/binary_class/config.json (+36 -0)
- models/{training_args.bin → binary_class/training_args.bin} (+0 -0)
- models/{config.json → multi_class/config.json} (+0 -0)
- models/multi_class/training_args.bin (+0 -0)
- models/pre_processor.py (+0 -111)
- models/preprocessor_config.json (+13 -8)
- requirements.txt (+0 -1)
app.py
CHANGED
@@ -1,4 +1,3 @@
-
 from PIL import Image
 import torch
 from timeit import default_timer as timer
@@ -7,38 +6,52 @@ from transformers import AutoModelForImageClassification, AutoImageProcessor
 import gradio as gr
 from pathlib import Path
 
+DEVICE = 'cuda' if torch.cuda.is_available() else ('mps' if torch.backends.mps.is_available() and torch.backends.mps.is_built() else 'cpu')
+BINARY_PATH = "./models/binary_class"
+MULTICLASS_PATH = "./models/multi_class"
 
-
-
-
-model = AutoModelForImageClassification.from_pretrained("models")
-processor = AutoImageProcessor.from_pretrained("models", use_fast=False, trust_remote_code=True)
+processor = AutoImageProcessor.from_pretrained("./models/", use_fast=False, trust_remote_code=True)
+binary_model = AutoModelForImageClassification.from_pretrained(BINARY_PATH).to(DEVICE)
+multi_class_model = AutoModelForImageClassification.from_pretrained(MULTICLASS_PATH).to(DEVICE)
 
-def predict(image:Image
+def predict(image:Image)-> Tuple[Dict, float]:
     """
-
+    1. Binary pass -> yes / no
+    2. Multi pass -> only if binary == yes
+    Returns: (label_dict | "Not a food item", total_time)
     """
+    if image is None:
+        return "no image provided", 0.0
     start_time = timer()
-
-
-
+    input = processor(images=image, return_tensors="pt").to(DEVICE)
+    binary_model.eval()
+    with torch.inference_mode():
+        logits = binary_model(**input).logits.squeeze()
+    pred_probs = torch.softmax(logits, dim=-1)
+    pred_food_prob = pred_probs[binary_model.config.label2id.get("food", 1)].item()
+    if pred_food_prob < 0.5:
+        elapsed = round(timer() - start_time, 4)
+        return "Not a food item", elapsed
+
+    multi_class_model.eval()
     with torch.inference_mode():
-        logits =
-    pred_probs = torch.softmax(logits, dim=1)
-    pred_label_prob_map = {
+        logits = multi_class_model(**input).logits
+    pred_probs = torch.softmax(logits, dim=-1)
+    pred_label_prob_map = {multi_class_model.config.id2label[i]: pred_probs[0][i].item() for i in range(len(pred_probs[0]))}
     elapsed_time = round(timer() - start_time, 4)
     return pred_label_prob_map, elapsed_time
 
+# get the example images from assets folder
 example_images = list(Path("assets").glob("*.jpg"))
 
 demo = gr.Interface(fn=predict,
                     inputs=gr.Image(type="pil"),
                     outputs=[
                         gr.Label(num_top_classes=5, label="Predicted class"),
-                        gr.Number(label="Prediction time"),
+                        gr.Number(label="Prediction time (sec)"),
                     ],
                     examples=example_images,
                     title="Image Classification App",
                     description="Upload an image to predict the class of the image")
 
-demo.launch()
+demo.launch()
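One detail the hunk does not show is a typing import for the new Tuple[Dict, float] annotation; unless it already sits in the unchanged import lines the hunk skips, app.py needs one. The following is a minimal, self-contained sketch of the same two-phase flow, not the committed code: it assumes the ./models layout introduced in this commit, and the function name and example image path are only illustrative.

# Sketch of the two-phase gate implemented in app.py above (assumptions noted in the lead-in).
from typing import Dict, Tuple          # the import the Tuple[Dict, float] annotation relies on
from timeit import default_timer as timer

import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification

processor = AutoImageProcessor.from_pretrained("./models/", use_fast=False)
binary_model = AutoModelForImageClassification.from_pretrained("./models/binary_class").eval()
multi_model = AutoModelForImageClassification.from_pretrained("./models/multi_class").eval()

def two_phase_predict(image: Image.Image) -> Tuple[Dict, float]:
    """Phase 1: binary food/non-food gate. Phase 2: per-class probabilities, only if phase 1 passes."""
    start = timer()
    inputs = processor(images=image, return_tensors="pt")
    with torch.inference_mode():
        food_idx = binary_model.config.label2id.get("food", 1)
        food_prob = torch.softmax(binary_model(**inputs).logits, dim=-1)[0, food_idx].item()
        if food_prob < 0.5:                                  # reject non-food before the expensive pass
            return "Not a food item", round(timer() - start, 4)
        probs = torch.softmax(multi_model(**inputs).logits, dim=-1)[0]
    labels = {multi_model.config.id2label[i]: p.item() for i, p in enumerate(probs)}
    return labels, round(timer() - start, 4)

# one of the example images added in this commit
print(two_phase_predict(Image.open("assets/img_0_0.jpg")))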
assets/croque_madame_0.jpg
DELETED (binary file, 46.6 kB)

assets/img_0_0.jpg
ADDED

assets/img_0_1.jpg
ADDED

assets/img_1_2.jpg
ADDED

assets/pancakes_1.jpg
DELETED (binary file, 45 kB)

assets/shrimp_and_grits_2.jpg
DELETED (binary file, 54.5 kB)
models/binary_class/config.json
ADDED
@@ -0,0 +1,36 @@
{
  "architectures": [
    "ViTForImageClassification"
  ],
  "attention_probs_dropout_prob": 0.0,
  "auto_map": {
    "AutoModelForImageClassification": "transformers.ViTForImageClassification"
  },
  "encoder_stride": 16,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.0,
  "hidden_size": 768,
  "id2label": {
    "0": "non-food",
    "1": "food"
  },
  "image_size": 224,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "label2id": {
    "food": 1,
    "non-food": 0
  },
  "layer_norm_eps": 1e-12,
  "model_type": "vit",
  "num_attention_heads": 12,
  "num_channels": 3,
  "num_hidden_layers": 12,
  "patch_size": 16,
  "pooler_act": "tanh",
  "pooler_output_size": 768,
  "problem_type": "single_label_classification",
  "qkv_bias": true,
  "torch_dtype": "float32",
  "transformers_version": "4.55.2"
}
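The id2label/label2id block above is what the new predict() relies on when it reads the binary softmax, falling back to index 1 via label2id.get("food", 1). A quick check that the mapping loads as expected, assuming the file is saved under ./models/binary_class as in this commit:

# confirm the binary head's label mapping from the config above
from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("./models/binary_class")
print(cfg.label2id)     # {'food': 1, 'non-food': 0}
print(cfg.id2label[1])  # 'food' -- the index predict() reads the softmax at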
models/{training_args.bin → binary_class/training_args.bin}
RENAMED
Binary files a/models/training_args.bin and b/models/binary_class/training_args.bin differ

models/{config.json → multi_class/config.json}
RENAMED
File without changes

models/multi_class/training_args.bin
ADDED
Binary file (5.78 kB)
models/pre_processor.py
DELETED
@@ -1,111 +0,0 @@
from typing import Any, Dict, List
import torch
from PIL import Image as PILImage
from torchvision.transforms import v2
from transformers import BaseImageProcessor
from collections.abc import Mapping

class CustomVisionImageProcessorV2(BaseImageProcessor):
    """
    ViT-B/16-224 preprocessing for Huggingface datasets.
    Works with:
        - dataset.map(..., batched=True/False)
        - dataset.set_transform
    """
    # run-time hint used by the Hugging-Face Trainer / pipeline / collator to decide which fields to feed into the model’s forward pass
    model_input_names = ["pixel_values"]

    def __init__(self, size: int = 224, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), train: bool = True, **kwargs):
        super().__init__(**kwargs)
        self.size = size
        self.mean = mean
        self.std = std
        self.train = train
        # self.auto_map = {  # add an entry in the processor config file for auto registring the processor
        #     "AutoImageProcessor": "image_processor.CustomVisionImageProcessorV2",
        # }

    # Create transform lazily so it is NOT serialised when we save using ImageProcessor.save_pretrained
    @property
    def transform(self):
        if not hasattr(self, "_transform"):
            if self.train:
                self._transform = v2.Compose(
                    [
                        v2.ToImage(),
                        v2.RandomResizedCrop(
                            size=(self.size, self.size),
                            scale=(0.8, 1.0),
                            antialias=True,
                        ),
                        v2.RandomHorizontalFlip(p=0.5),
                        v2.ToDtype(torch.float32, scale=True),
                        v2.Normalize(mean=self.mean, std=self.std),
                    ]
                )
            else:
                self._transform = v2.Compose(
                    [
                        v2.ToImage(),
                        v2.Resize(size=(self.size, self.size), antialias=True),
                        v2.CenterCrop(size=(self.size, self.size)),
                        v2.ToDtype(torch.float32, scale=True),
                        v2.Normalize(mean=self.mean, std=self.std),
                    ]
                )
        return self._transform

    def __call__(self,
                 examples: PILImage.Image | List[PILImage.Image] | Dict[str, Any],
                 image_feature_name: str = "image",
                 label_feature_name: str = "label") -> Dict[str, torch.Tensor]:
        """
        Accepts:
            * a single PIL.Image
            * a list[PIL.Image]
            * a dict with keys:
                'image' : PIL.Image or list[PIL.Image]
                'label' : single int or list[int] (optional)

        Returns:
            dict with:
                'pixel_values' : tensor (C, H, W) or (N, C, H, W)
                'labels'       : tensor (long) or (N,)
        """
        if isinstance(examples, PILImage.Image):    # single PIL image → wrap into list
            images, labels = [examples], None
        elif isinstance(examples, list):            # list of PIL images
            images, labels = examples, None
        elif isinstance(examples, Mapping):         # dict (single example or LazyBatch)
            images = examples[image_feature_name]
            labels = examples.get(label_feature_name)
            if isinstance(images, PILImage.Image):  # single example
                images = [images]
                labels = [labels] if labels is not None else None
            # else images is already a list (LazyBatch)
        else:
            raise TypeError(f"Expected PIL.Image, list[PIL.Image] or dict, got {type(examples)}")

        pixel_values = torch.stack([self.transform(img) for img in images])
        if pixel_values.shape[0] == 1:  # squeeze singleton batch dimension when we only processed one image
            pixel_values = pixel_values.squeeze(0)

        out = {"pixel_values": pixel_values}
        if labels is not None:
            labels_tensor = torch.tensor(labels, dtype=torch.long)
            if labels_tensor.shape[0] == 1:  # squeeze singleton batch dimension when we only processed one label
                labels_tensor = labels_tensor.squeeze(0)
            out["labels"] = labels_tensor
        return out

    def to_dict(self) -> Dict[str, Any]:
        cfg = super().to_dict()
        cfg.update(
            dict(
                auto_map={
                    "AutoImageProcessor": "pre_processor.CustomVisionImageProcessorV2",
                    # "AutoImageProcessor_fast": "pre_processor.CustomVisionImageProcessorV3Fast",
                },
            )
        )
        return cfg
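With this custom torchvision-based processor deleted, preprocessing is handled by the stock ViTImageProcessor described in the updated models/preprocessor_config.json below, which is also why torchvision drops out of requirements.txt. A sketch of the replacement inference path, assuming that config is in place and using one of the new example images:

# the stock processor that takes over from the deleted class
from PIL import Image
from transformers import ViTImageProcessor

proc = ViTImageProcessor.from_pretrained("./models")  # reads models/preprocessor_config.json
batch = proc(images=Image.open("assets/img_0_0.jpg"), return_tensors="pt")
print(batch["pixel_values"].shape)  # torch.Size([1, 3, 224, 224]): resized, rescaled, normalized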
models/preprocessor_config.json
CHANGED
@@ -1,18 +1,23 @@
 {
-"
-
-
-"
-"
+  "do_convert_rgb": null,
+  "do_normalize": true,
+  "do_rescale": true,
+  "do_resize": true,
+  "image_mean": [
     0.5,
     0.5,
     0.5
   ],
-"
-"
+  "image_processor_type": "ViTImageProcessor",
+  "image_std": [
     0.5,
     0.5,
     0.5
   ],
-"
+  "resample": 2,
+  "rescale_factor": 0.00392156862745098,
+  "size": {
+    "height": 224,
+    "width": 224
+  }
 }
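Two values in the new config are easy to misread: "resample": 2 is PIL's bilinear filter, and "rescale_factor" is 1/255, the same scaling the deleted transform applied via ToDtype(..., scale=True). A quick sanity check, assuming a recent Pillow:

# sanity-check the two numeric settings in the new preprocessor config
from PIL import Image

assert int(Image.Resampling.BILINEAR) == 2           # "resample": 2 -> bilinear interpolation
assert abs(0.00392156862745098 - 1 / 255) < 1e-15    # "rescale_factor" is 1/255
print("config values check out")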
requirements.txt
CHANGED
@@ -1,5 +1,4 @@
 torch==2.8.0
-torchvision==0.23.0
 gradio==5.42.0
 transformers==4.55.2
 evaluate==0.4.5