lexandstuff commited on
Commit
d3d9554
·
verified ·
1 Parent(s): cce74da

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. README.md +11 -2
README.md CHANGED
@@ -27,29 +27,38 @@ pip install mlx-image
27
  Here is how to use this model for image classification:
28
 
29
  ```python
 
30
  from mlxim.model import create_model
31
  from mlxim.io import read_rgb
32
  from mlxim.transform import ImageNetTransform
 
33
 
34
  transform = ImageNetTransform(train=False, img_size=224)
35
- x = transform(read_rgb("cat.png"))
 
36
  x = mx.expand_dims(x, 0)
37
 
38
  model = create_model("efficientnet_b0")
39
  model.eval()
40
 
41
  logits = model(x)
 
 
 
 
42
  ```
43
 
44
  You can also use the embeds from layer before head:
45
 
46
  ```python
 
47
  from mlxim.model import create_model
48
  from mlxim.io import read_rgb
49
  from mlxim.transform import ImageNetTransform
50
 
51
  transform = ImageNetTransform(train=False, img_size=224)
52
- x = transform(read_rgb("cat.png"))
 
53
  x = mx.expand_dims(x, 0)
54
 
55
  # first option
 
27
  Here is how to use this model for image classification:
28
 
29
  ```python
30
+ import mlx.core as mx
31
  from mlxim.model import create_model
32
  from mlxim.io import read_rgb
33
  from mlxim.transform import ImageNetTransform
34
+ from mlxim.utils.imagenet import IMAGENET2012_CLASSES
35
 
36
  transform = ImageNetTransform(train=False, img_size=224)
37
+ x = transform(read_rgb("cat.jpg"))
38
+ x = mx.array(x)
39
  x = mx.expand_dims(x, 0)
40
 
41
  model = create_model("efficientnet_b0")
42
  model.eval()
43
 
44
  logits = model(x)
45
+ predicted_idx = mx.argmax(logits, axis=-1).item()
46
+ predicted_class = list(IMAGENET2012_CLASSES.values())[predicted_idx]
47
+
48
+ print(f"Predicted class: {predicted_class}")
49
  ```
50
 
51
  You can also use the embeds from layer before head:
52
 
53
  ```python
54
+ import mlx.core as mx
55
  from mlxim.model import create_model
56
  from mlxim.io import read_rgb
57
  from mlxim.transform import ImageNetTransform
58
 
59
  transform = ImageNetTransform(train=False, img_size=224)
60
+ x = transform(read_rgb("cat.jpg"))
61
+ x = mx.array(x)
62
  x = mx.expand_dims(x, 0)
63
 
64
  # first option