---
license: apache-2.0
tags:
- vision
- image-classification
---

# Kindwise Router Classifier

This model classifies images by content, acting as a "router" that directs
requests to the correct Kindwise service. It automatically detects whether an
image contains a human, an insect, a mushroom, or a plant.

The model is intended to be the first step in an image processing pipeline.
Instead of having each specialized service (e.g., insect or plant classification)
analyze every image, this model quickly determines the image's category. This
reduces latency and saves system resources.

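
As a rough illustration of the routing step, the sketch below maps per-class scores to a downstream service. The `SERVICE_BY_CLASS` table, the service names, the `route` helper, and the 0.5 threshold are all hypothetical placeholders chosen for this example; they are not part of the model or of any Kindwise API.

```python
from typing import Optional

# Hypothetical mapping from router classes to downstream services.
# The service names are placeholders, not real Kindwise endpoints.
SERVICE_BY_CLASS = {
    'plant': 'plant-service',
    'insect': 'insect-service',
    'mushroom': 'mushroom-service',
}


def route(scores: dict[str, float], threshold: float = 0.5) -> Optional[str]:
    """Return the service for the best-scoring routable class, or None."""
    # Keep only classes that have a downstream service (e.g. 'human' has none here).
    routable = {name: s for name, s in scores.items() if name in SERVICE_BY_CLASS}
    if not routable:
        return None
    best = max(routable, key=routable.get)
    # Refuse to route when even the best class is below the (assumed) threshold.
    if routable[best] < threshold:
        return None
    return SERVICE_BY_CLASS[best]


scores = {'plant': 0.913, 'insect': 0.004, 'human': 0.001, 'mushroom': 0.0}
print(route(scores))  # plant-service
```

Returning `None` for images that match no routable class lets the caller reject the request early instead of forwarding it to a specialized service.
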
## Technical Details and Formats

The model is available in two optimized formats for easy deployment:

- **TorchScript**: Optimized for production environments and server-side applications where performance and low latency are critical.
- **TensorFlow Lite**: Suited to mobile devices and edge computing, where efficiency and minimal model size are key.

## Usage

The following examples show how to classify an image into the basic classes:

### PyTorch

```python
from huggingface_hub import hf_hub_download
import cv2
import numpy as np
import PIL.Image
import torch
import torchvision

DEVICE_NAME = 'cuda:0'  # use 'cpu' if no GPU is available
MODEL_PATH = hf_hub_download('kindwise/router.small', 'model.traced.pt')
CLASSES_PATH = hf_hub_download('kindwise/router.small', 'classes.txt')
IMAGE_PATH = '/tmp/photo.jpg'

with open(CLASSES_PATH) as f:
    CLASSES = [line.strip() for line in f]
MODEL = torch.jit.load(MODEL_PATH).eval().to(DEVICE_NAME)


def resize_crop(image_data: np.ndarray, target_size: int = 480) -> np.ndarray:
    height, width, _ = image_data.shape
    # Determine the size of the square crop
    crop_size = min(height, width)
    # Calculate coordinates for center crop
    start_x = (width - crop_size) // 2
    start_y = (height - crop_size) // 2
    # Perform center crop
    cropped_img = image_data[
        start_y : start_y + crop_size,
        start_x : start_x + crop_size
    ]
    # Resize cropped image to target size
    return cv2.resize(
        cropped_img,
        (target_size, target_size),
        interpolation=cv2.INTER_AREA,
    )


with torch.no_grad():
    # Convert to RGB so grayscale or RGBA files still yield a (H, W, 3) array
    image_array = np.array(PIL.Image.open(IMAGE_PATH).convert('RGB'))
    image_array_resized = resize_crop(image_array)
    image_tensor = torchvision.transforms.functional.to_tensor(image_array_resized).to(DEVICE_NAME)
    prediction = MODEL(image_tensor.unsqueeze(0)).squeeze(0).cpu().numpy()

# Print classes from highest to lowest score
for i in (-prediction).argsort():
    print(f'{CLASSES[i]:>10}: {100 * prediction[i]:.1f}%')
```

Output:
```
     plant: 91.3%
unhealthy_plant: 53.3%
      crop: 16.2%
    insect: 0.4%
     human: 0.1%
  mushroom: 0.0%
```

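
The scores in the sample output add up to well over 100%, which suggests each class is scored independently and several classes can apply to one image. A minimal sketch of turning such scores into a set of labels follows; the `labels_above` helper and the 0.5 threshold are assumptions made for illustration, not values documented for this model.

```python
def labels_above(scores: dict[str, float], threshold: float = 0.5) -> list[str]:
    """Return class names whose score reaches the threshold, best first."""
    kept = [name for name, score in scores.items() if score >= threshold]
    return sorted(kept, key=lambda name: scores[name], reverse=True)


# Scores copied from the sample output, expressed as fractions
sample = {
    'plant': 0.913,
    'unhealthy_plant': 0.533,
    'crop': 0.162,
    'insect': 0.004,
    'human': 0.001,
    'mushroom': 0.0,
}
print(labels_above(sample))  # ['plant', 'unhealthy_plant']
```
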
### TensorFlow Lite

```python
from huggingface_hub import hf_hub_download
import numpy as np
import tensorflow as tf

MODEL_PATH = hf_hub_download('kindwise/router.small', 'model.tflite')  # or model.optimized.tflite
CLASSES_PATH = hf_hub_download('kindwise/router.small', 'classes.txt')

with open(CLASSES_PATH) as f:
    CLASSES = [line.strip() for line in f]
INTERPRETER = tf.lite.Interpreter(model_path=MODEL_PATH)
INTERPRETER.allocate_tensors()

image_array_resized = ...  # see the previous example
tf_input = np.expand_dims(  # add batch dimension
    (image_array_resized / 255).astype(np.float32),  # image values in [0..1]
    0,
)
input_details = INTERPRETER.get_input_details()
output_details = INTERPRETER.get_output_details()
INTERPRETER.set_tensor(
    input_details[0]['index'],
    tf_input,
)
INTERPRETER.invoke()
# The TFLite model returns raw logits, so apply a sigmoid to get per-class scores
logits = INTERPRETER.get_tensor(output_details[0]['index'])[0]
prediction = tf.nn.sigmoid(logits).numpy()
for i in (-prediction).argsort():
    print(f'{CLASSES[i]:>10}: {100 * prediction[i]:.1f}%')
```

Output:
```
     plant: 91.3%
unhealthy_plant: 53.3%
      crop: 16.2%
    insect: 0.4%
     human: 0.1%
  mushroom: 0.0%
```
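
The TensorFlow Lite example applies a sigmoid to the raw logits with `tf.nn.sigmoid`. For cases where only NumPy is convenient for post-processing, the same step can be sketched as below. The `sigmoid` helper is a hypothetical stand-in written for this example, not something shipped with the model, and it uses the identity sigmoid(x) = 0.5 * (1 + tanh(x / 2)) to avoid overflow for large-magnitude logits.

```python
import numpy as np


def sigmoid(logits: np.ndarray) -> np.ndarray:
    """Element-wise logistic sigmoid, computed via tanh for numerical stability."""
    logits = np.asarray(logits, dtype=np.float64)
    return 0.5 * (1.0 + np.tanh(0.5 * logits))


# A logit of 0 maps to a score of exactly 0.5
print(sigmoid(np.array([0.0])))  # [0.5]
```
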