---
library_name: mlx
base_model: facebook/sam3.1
tags:
- mlx
- sam3
- sam3.1
- segmentation
- detection
- tracking
- object-multiplex
---

# sam3.1-bf16

This model was converted to MLX format from [`facebook/sam3.1`](https://huggingface.co/facebook/sam3.1) using mlx-vlm version **0.4.3**.

Open-vocabulary **object detection**, **instance segmentation**, and **video tracking** with **Object Multiplex** on Apple Silicon (~873M parameters).

SAM 3.1 extends SAM 3 with:

- **MultiplexMaskDecoder**: processes 16 objects simultaneously (2.4x to 4.0x faster tracking in the benchmarks below)
- **TriViTDetNeck**: 3 parallel FPN heads (detection, interactive, propagation)
- **DecoupledMemoryAttention**: image cross-attention with RoPE
- Higher detection confidence (0.90 vs 0.87 for the top "a cat" detection; see Benchmarks)

## Quick Start

```bash
# Quote the requirement so the shell does not treat ">=" as an output redirect
pip install "mlx-vlm>=0.4.3"
```

```python
from PIL import Image
from mlx_vlm.utils import load_model, get_model_path
from mlx_vlm.models.sam3.generate import Sam3Predictor
from mlx_vlm.models.sam3_1.processing_sam3_1 import Sam31Processor

# Locate (downloading if needed) the converted weights, then build model and processor
model_path = get_model_path("mlx-community/sam3.1-bf16")
model = load_model(model_path)
processor = Sam31Processor.from_pretrained(str(model_path))

# Filter out low-confidence detections (threshold 0.3)
predictor = Sam3Predictor(model, processor, score_threshold=0.3)
```

## Object Detection

```python
# Convert to RGB so images with an alpha channel also work with the overlay code below
image = Image.open("photo.jpg").convert("RGB")
result = predictor.predict(image, text_prompt="a dog")

for i in range(len(result.scores)):
    x1, y1, x2, y2 = result.boxes[i]
    print(f"[{result.scores[i]:.2f}] box=({x1:.0f}, {y1:.0f}, {x2:.0f}, {y2:.0f})")
```

## Instance Segmentation

```python
import numpy as np

result = predictor.predict(image, text_prompt="a person")
# result.boxes  -> (N, 4) xyxy bounding boxes
# result.masks  -> (N, H, W) binary segmentation masks
# result.scores -> (N,) confidence scores

overlay = np.array(image).copy()
W, H = image.size
for i in range(len(result.scores)):
    mask = np.asarray(result.masks[i])
    # Resize the mask to the image size if needed (PIL's resize takes (width, height))
    if mask.shape != (H, W):
        mask = np.array(Image.fromarray(mask.astype(np.float32)).resize((W, H)))
    binary = mask > 0
    # Blend a 50% red tint over the masked pixels
    overlay[binary] = (overlay[binary] * 0.5 + np.array([255, 0, 0]) * 0.5).astype(np.uint8)

Image.fromarray(overlay).save("overlay.png")
```

## Multi-Prompt Detection

```python
from mlx_vlm.models.sam3_1.generate import predict_multi

result = predict_multi(predictor, image, ["a cat", "a remote control"])
for i in range(len(result.scores)):
    x1, y1, x2, y2 = result.boxes[i]
    print(f"[{result.scores[i]:.2f}] {result.labels[i]} box=({x1:.0f}, {y1:.0f}, {x2:.0f}, {y2:.0f})")
```

## Box-Guided Detection

```python
import numpy as np

boxes = np.array([[100, 50, 400, 350]])  # xyxy pixel coords
result = predictor.predict(image, text_prompt="a cat", boxes=boxes)
```

## CLI

```bash
# Object detection
python -m mlx_vlm.models.sam3_1.generate --task detect --image photo.jpg --prompt "a cat" --model mlx-community/sam3.1-bf16

# Instance segmentation (default task)
python -m mlx_vlm.models.sam3_1.generate --image photo.jpg --prompt "a cat" --model mlx-community/sam3.1-bf16

# Video tracking
python -m mlx_vlm.models.sam3_1.generate --task track --video input.mp4 --prompt "a car" --model mlx-community/sam3.1-bf16

# Real-time webcam (optimized: backbone caching + tracker propagation)
python -m mlx_vlm.models.sam3_1.generate --task realtime --prompt "a person" --model mlx-community/sam3.1-bf16 --resolution 224
```

| Flag | Default | Description |
|------|---------|-------------|
| `--task` | `segment` | `detect`, `segment`, `track`, `realtime` |
| `--prompt` | *(required)* | Text prompt(s); multiple prompts are supported |
| `--resolution` | `1008` | Input resolution in pixels (use 224 for faster realtime) |
| `--detect-every` | `15` | Re-run full detection every N frames |
| `--backbone-every` | `30` | Re-run the ViT backbone every N frames |

## Benchmarks (M3 Max, bf16)

### Detection Accuracy

Per-instance confidence scores, one value per detected object.

| Prompt | SAM 3 | SAM 3.1 |
|--------|-------|---------|
| "a cat" (2 cats) | 0.87, 0.82 | **0.90, 0.86** |
| "a remote control" | 0.95, 0.94 | 0.94, 0.94 |

### Tracker Multiplex Speed

| Objects | SAM 3 | SAM 3.1 | Speedup |
|---------|-------|---------|---------|
| 3 | 547 ms/frame | 227 ms/frame | **2.4x** |
| 4 | 608 ms/frame | 203 ms/frame | **3.0x** |
| 5 | 766 ms/frame | 190 ms/frame | **4.0x** |

### Optimized Realtime (224px)

| Metric | Value |
|--------|-------|
| Cached frame | 38 ms (26 FPS) |
| Sustained average | ~40 ms (25 FPS) |
| Baseline (no optimization) | ~212 ms (5 FPS) |
| **Total speedup** | **4.6x** |

## Original Model

[facebook/sam3.1](https://huggingface.co/facebook/sam3.1) · [Code](https://github.com/facebookresearch/sam3)

## License

The original SAM 3.1 model weights are released by Meta under the [SAM License](https://huggingface.co/facebook/sam3.1/blob/main/LICENSE), a custom permissive license for commercial and research use.