Image Classification
Transformers
Safetensors
English
siglip
Sketch-126-DomainNet
prithivMLmods committed on
Commit
e4f15b4
·
verified ·
1 Parent(s): 197059a

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +16 -9
README.md CHANGED
@@ -290,8 +290,7 @@ The model categorizes images into the following 126 classes:
290
 
291
  ```python
292
  import gradio as gr
293
- from transformers import AutoImageProcessor
294
- from transformers import SiglipForImageClassification
295
  from transformers.image_utils import load_image
296
  from PIL import Image
297
  import torch
@@ -302,14 +301,21 @@ model = SiglipForImageClassification.from_pretrained(model_name)
302
  processor = AutoImageProcessor.from_pretrained(model_name)
303
 
304
  def sketch_classification(image):
305
- \"\"\"Predicts the sketch category for an input image.\"\"\"\n image = Image.fromarray(image).convert(\"RGB\")
306
- inputs = processor(images=image, return_tensors=\"pt\")
 
307
 
 
 
 
 
308
  with torch.no_grad():
309
  outputs = model(**inputs)
310
  logits = outputs.logits
 
311
  probs = torch.nn.functional.softmax(logits, dim=1).squeeze().tolist()
312
 
 
313
  labels = {
314
  "0": "aircraft_carrier", "1": "alarm_clock", "2": "ant", "3": "anvil", "4": "asparagus",
315
  "5": "axe", "6": "banana", "7": "basket", "8": "bathtub", "9": "bear",
@@ -339,20 +345,21 @@ def sketch_classification(image):
339
  "122": "vase", "123": "watermelon", "124": "whale", "125": "zebra"
340
  }
341
 
 
342
  predictions = {labels[str(i)]: round(probs[i], 3) for i in range(len(probs))}
343
  return predictions
344
 
345
  # Create Gradio interface
346
  iface = gr.Interface(
347
  fn=sketch_classification,
348
- inputs=gr.Image(type=\"numpy\"),
349
- outputs=gr.Label(label=\"Prediction Scores\"),
350
- title=\"Sketch-126-DomainNet Classification\",
351
- description=\"Upload a sketch to classify it into one of 126 categories.\"
352
  )
353
 
354
  # Launch the app
355
- if __name__ == \"__main__\":
356
  iface.launch()
357
  ```
358
 
 
290
 
291
  ```python
292
  import gradio as gr
293
+ from transformers import AutoImageProcessor, SiglipForImageClassification
 
294
  from transformers.image_utils import load_image
295
  from PIL import Image
296
  import torch
 
301
  processor = AutoImageProcessor.from_pretrained(model_name)
302
 
303
  def sketch_classification(image):
304
+ """Predicts the sketch category for an input image."""
305
+ # Convert the input numpy array to a PIL Image and ensure it has 3 channels (RGB)
306
+ image = Image.fromarray(image).convert("RGB")
307
 
308
+ # Process the image and prepare it for the model
309
+ inputs = processor(images=image, return_tensors="pt")
310
+
311
+ # Perform inference without gradient calculation
312
  with torch.no_grad():
313
  outputs = model(**inputs)
314
  logits = outputs.logits
315
+ # Convert logits to probabilities using softmax
316
  probs = torch.nn.functional.softmax(logits, dim=1).squeeze().tolist()
317
 
318
+ # Mapping from indices to corresponding sketch category labels
319
  labels = {
320
  "0": "aircraft_carrier", "1": "alarm_clock", "2": "ant", "3": "anvil", "4": "asparagus",
321
  "5": "axe", "6": "banana", "7": "basket", "8": "bathtub", "9": "bear",
 
345
  "122": "vase", "123": "watermelon", "124": "whale", "125": "zebra"
346
  }
347
 
348
+ # Create a dictionary mapping each label to its predicted probability (rounded)
349
  predictions = {labels[str(i)]: round(probs[i], 3) for i in range(len(probs))}
350
  return predictions
351
 
352
  # Create Gradio interface
353
  iface = gr.Interface(
354
  fn=sketch_classification,
355
+ inputs=gr.Image(type="numpy"),
356
+ outputs=gr.Label(label="Prediction Scores"),
357
+ title="Sketch-126-DomainNet Classification",
358
+ description="Upload a sketch to classify it into one of 126 categories."
359
  )
360
 
361
  # Launch the app
362
+ if __name__ == "__main__":
363
  iface.launch()
364
  ```
365