|
|
package examples |
|
|
|
|
|
import ( |
|
|
"fmt" |
|
|
"github.com/getcharzp/go-vision/sam2" |
|
|
"github.com/up-zero/gotool/imageutil" |
|
|
_ "image/jpeg" |
|
|
"testing" |
|
|
) |
|
|
|
|
|
func TestSAM2Refactored(t *testing.T) { |
|
|
config := sam2.Config{ |
|
|
OnnxRuntimeLibPath: "../lib/onnxruntime.dll", |
|
|
EncodeModelPath: "../sam2_weights/vision_encoder.onnx", |
|
|
DecodeModelPath: "../sam2_weights/prompt_encoder_mask_decoder.onnx", |
|
|
} |
|
|
|
|
|
engine, err := sam2.NewEngine(config) |
|
|
if err != nil { |
|
|
t.Fatalf("初始化引擎失败: %v", err) |
|
|
} |
|
|
defer engine.Destroy() |
|
|
|
|
|
img, _ := imageutil.Open("./test.png") |
|
|
imgCtx, err := engine.EncodeImage(img) |
|
|
if err != nil { |
|
|
t.Fatalf("图片 Encode 失败: %v", err) |
|
|
} |
|
|
defer imgCtx.Destroy() |
|
|
|
|
|
points := []sam2.Point{ |
|
|
{X: 367, Y: 168, Label: sam2.LabelBoxTopLeft}, |
|
|
{X: 441, Y: 349, Label: sam2.LabelBoxBotRight}, |
|
|
} |
|
|
imgResult, score, err := imgCtx.Decode(points) |
|
|
if err != nil { |
|
|
t.Fatalf("Mask Decode 失败: %v", err) |
|
|
} |
|
|
|
|
|
fmt.Printf("Mask generated, score: %.4f\n", score) |
|
|
imageutil.Save("output_mask.png", imgResult, 100) |
|
|
} |
|
|
|