How to use from the sam2 library
# Use SAM2 with images
import torch
from sam2.sam2_image_predictor import SAM2ImagePredictor

predictor = SAM2ImagePredictor.from_pretrained("mlx-community/EdgeTAM-fp16")

with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
    predictor.set_image(<your_image>)
    masks, _, _ = predictor.predict(<input_prompts>)
# Use SAM2 with videos
import torch
from sam2.sam2_video_predictor import SAM2VideoPredictor

predictor = SAM2VideoPredictor.from_pretrained("mlx-community/EdgeTAM-fp16")

with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
    state = predictor.init_state(<your_video>)

    # add new prompts and instantly get the output on the same frame
    frame_idx, object_ids, masks = predictor.add_new_points(state, <your_prompts>)

    # propagate the prompts to get masklets throughout the video
    for frame_idx, object_ids, masks in predictor.propagate_in_video(state):
        ...

mlx-community/EdgeTAM-fp16

EdgeTAM (on-device SAM 2 for promptable segmentation and video tracking), converted to Apple MLX (-fp16) for the mlx-edgetam-swift Swift package (MLXEngine promptSegment + trackObject ModelPackage). 22× faster than SAM 2, 16 FPS on iPhone 15 Pro Max.
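The -fp16 suffix indicates half-precision weights. As a rough illustration of what that storage format costs in precision, here is a minimal sketch using only Python's standard struct module (format code "e" is IEEE 754 binary16). The helper names are hypothetical and the sample values are made up; this is not the conversion script used for this repo.

```python
import struct


def to_fp16_and_back(x):
    """Round a Python float to IEEE 754 half precision and return it as a float."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def fp16_relative_error(x):
    """Relative rounding error from storing x as fp16 (x must be nonzero)."""
    return abs(to_fp16_and_back(x) - x) / abs(x)


if __name__ == "__main__":
    # Made-up sample "weights"; some are exactly representable, some are not.
    for weight in (1.0, 0.1, 0.003, -2.5):
        print(weight, to_fp16_and_back(weight), fp16_relative_error(weight))
```

For values in fp16's normal range, round-to-nearest keeps the relative error at or below 2^-11 (about 4.9e-4).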

From-scratch MLX-Swift architecture port. Image mode (point/box → mask): RepViT-M1 encoder + FPN + SAM prompt encoder + two-way mask decoder, parity-locked vs the PyTorch oracle on the CPU stream (image_embed 9.7e-6, mask logits 8.9e-5; end-to-end mask IoU 0.99 vs PyTorch). Video mode (trackObject, click on one frame → per-frame masklet): adds the video memory stack, PerceiverResampler + MemoryEncoder + MemoryAttention (RoPE-2D) + the SAM2 memory-bank state machine, with every op parity-locked vs the oracle; full masklet propagation min-IoU 0.92. This single -fp16 file carries both (874 tensors).
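The parity figures above are of two kinds: a difference between tensors and an intersection-over-union (IoU) between masks. A minimal sketch of how such metrics are commonly computed, in plain Python on flattened sequences. The function names are illustrative and not part of the port or its test suite. Two things here are assumptions: that the tensor difference is a maximum absolute elementwise difference, and that logits are binarized at 0 (the usual SAM convention).

```python
def max_abs_diff(a, b):
    """Largest elementwise |a - b| between two equal-length flat sequences."""
    if len(a) != len(b):
        raise ValueError("sequences must have the same length")
    return max((abs(x - y) for x, y in zip(a, b)), default=0.0)


def binarize(logits, threshold=0.0):
    """Turn mask logits into a binary mask (threshold 0.0 is an assumption)."""
    return [v > threshold for v in logits]


def mask_iou(pred, ref):
    """IoU of two equal-length flat binary masks; 1.0 when both are empty."""
    if len(pred) != len(ref):
        raise ValueError("masks must have the same length")
    inter = sum(1 for p, r in zip(pred, ref) if p and r)
    union = sum(1 for p, r in zip(pred, ref) if p or r)
    return inter / union if union else 1.0


def min_iou(pred_frames, ref_frames):
    """Worst per-frame IoU over a masklet (one flat mask per frame)."""
    return min(mask_iou(p, r) for p, r in zip(pred_frames, ref_frames))


if __name__ == "__main__":
    # Toy logits, not real model output.
    ref_logits = [2.0, 1.5, -0.5, -3.0]
    port_logits = [1.9, 1.6, 0.1, -2.9]
    print(max_abs_diff(port_logits, ref_logits))
    print(mask_iou(binarize(port_logits), binarize(ref_logits)))
```

Reporting the minimum IoU over frames, rather than the mean, makes a single badly tracked frame visible in the summary number.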

Use

// Package.swift → .package(url: "https://github.com/xocialize/mlx-edgetam-swift", from: "0.1.0")
import EdgeTAM
// Image: click → object mask
let p = try EdgeTAMPredictor.fromPretrained(weightsPath, dtype: .float16)
p.setImage(sourceCGImage)
let (mask, score, _, _) = p.predict(point: (500, 375))
// Video: click on a frame → per-frame masklet
let vp = try EdgeTAMVideoPredictor.fromPretrained(weightsPath, dtype: .float16)
let track = vp.track(frames: cgImages, clickFrame: 0, points: [[210, 350]], labels: [1])

Or as an MLXEngine ModelPackage (MLXEdgeTAM.EdgeTAMPackage), which exposes promptSegment (image) and trackObject (video) surfaces and resolves this repo via the Hub.

Weights: Apache-2.0 (facebookresearch/EdgeTAM). Port code: MIT.

Safetensors · Model size: 13.9M params · Tensor type: F16 · MLX