Update README.md
Browse files
README.md
CHANGED
|
@@ -290,8 +290,7 @@ The model categorizes images into the following 126 classes:
|
|
| 290 |
|
| 291 |
```python
|
| 292 |
import gradio as gr
|
| 293 |
-
from transformers import AutoImageProcessor
|
| 294 |
-
from transformers import SiglipForImageClassification
|
| 295 |
from transformers.image_utils import load_image
|
| 296 |
from PIL import Image
|
| 297 |
import torch
|
|
@@ -302,14 +301,21 @@ model = SiglipForImageClassification.from_pretrained(model_name)
|
|
| 302 |
processor = AutoImageProcessor.from_pretrained(model_name)
|
| 303 |
|
| 304 |
def sketch_classification(image):
|
| 305 |
-
|
| 306 |
-
|
|
|
|
| 307 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 308 |
with torch.no_grad():
|
| 309 |
outputs = model(**inputs)
|
| 310 |
logits = outputs.logits
|
|
|
|
| 311 |
probs = torch.nn.functional.softmax(logits, dim=1).squeeze().tolist()
|
| 312 |
|
|
|
|
| 313 |
labels = {
|
| 314 |
"0": "aircraft_carrier", "1": "alarm_clock", "2": "ant", "3": "anvil", "4": "asparagus",
|
| 315 |
"5": "axe", "6": "banana", "7": "basket", "8": "bathtub", "9": "bear",
|
|
@@ -339,20 +345,21 @@ def sketch_classification(image):
|
|
| 339 |
"122": "vase", "123": "watermelon", "124": "whale", "125": "zebra"
|
| 340 |
}
|
| 341 |
|
|
|
|
| 342 |
predictions = {labels[str(i)]: round(probs[i], 3) for i in range(len(probs))}
|
| 343 |
return predictions
|
| 344 |
|
| 345 |
# Create Gradio interface
|
| 346 |
iface = gr.Interface(
|
| 347 |
fn=sketch_classification,
|
| 348 |
-
inputs=gr.Image(type
|
| 349 |
-
outputs=gr.Label(label
|
| 350 |
-
title
|
| 351 |
-
description
|
| 352 |
)
|
| 353 |
|
| 354 |
# Launch the app
|
| 355 |
-
if __name__ ==
|
| 356 |
iface.launch()
|
| 357 |
```
|
| 358 |
|
|
|
|
| 290 |
|
| 291 |
```python
|
| 292 |
import gradio as gr
|
| 293 |
+
from transformers import AutoImageProcessor, SiglipForImageClassification
|
|
|
|
| 294 |
from transformers.image_utils import load_image
|
| 295 |
from PIL import Image
|
| 296 |
import torch
|
|
|
|
| 301 |
processor = AutoImageProcessor.from_pretrained(model_name)
|
| 302 |
|
| 303 |
def sketch_classification(image):
|
| 304 |
+
"""Predicts the sketch category for an input image."""
|
| 305 |
+
# Convert the input numpy array to a PIL Image and ensure it has 3 channels (RGB)
|
| 306 |
+
image = Image.fromarray(image).convert("RGB")
|
| 307 |
|
| 308 |
+
# Process the image and prepare it for the model
|
| 309 |
+
inputs = processor(images=image, return_tensors="pt")
|
| 310 |
+
|
| 311 |
+
# Perform inference without gradient calculation
|
| 312 |
with torch.no_grad():
|
| 313 |
outputs = model(**inputs)
|
| 314 |
logits = outputs.logits
|
| 315 |
+
# Convert logits to probabilities using softmax
|
| 316 |
probs = torch.nn.functional.softmax(logits, dim=1).squeeze().tolist()
|
| 317 |
|
| 318 |
+
# Mapping from indices to corresponding sketch category labels
|
| 319 |
labels = {
|
| 320 |
"0": "aircraft_carrier", "1": "alarm_clock", "2": "ant", "3": "anvil", "4": "asparagus",
|
| 321 |
"5": "axe", "6": "banana", "7": "basket", "8": "bathtub", "9": "bear",
|
|
|
|
| 345 |
"122": "vase", "123": "watermelon", "124": "whale", "125": "zebra"
|
| 346 |
}
|
| 347 |
|
| 348 |
+
# Create a dictionary mapping each label to its predicted probability (rounded)
|
| 349 |
predictions = {labels[str(i)]: round(probs[i], 3) for i in range(len(probs))}
|
| 350 |
return predictions
|
| 351 |
|
| 352 |
# Create Gradio interface
|
| 353 |
iface = gr.Interface(
|
| 354 |
fn=sketch_classification,
|
| 355 |
+
inputs=gr.Image(type="numpy"),
|
| 356 |
+
outputs=gr.Label(label="Prediction Scores"),
|
| 357 |
+
title="Sketch-126-DomainNet Classification",
|
| 358 |
+
description="Upload a sketch to classify it into one of 126 categories."
|
| 359 |
)
|
| 360 |
|
| 361 |
# Launch the app
|
| 362 |
+
if __name__ == "__main__":
|
| 363 |
iface.launch()
|
| 364 |
```
|
| 365 |
|