| """
|
| SAM3 Inference Script for Lesion Segmentation (Text Prompt)
|
| ===========================================================
|
|
|
| This script runs inference using the SAM3 model with TEXT PROMPTS.
|
| It is designed to batch process a folder of images.
|
|
|
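Usage:

- Set `TEXT_PROMPT` (and, if needed, `INPUT_DIR` / `OUTPUT_DIR`) in the configuration constants below (e.g., "orange spot", "lesion").

- The model segments every region matching the text description; all matched instances are merged into a single mask per image.

- Output is one BINARY PNG mask per input image (0 = background, 255 = target), written to `OUTPUT_DIR` under the input's base name with a `.png` extension.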
|
|
|
Requirements:

- torch

- Pillow (PIL)

- numpy

- tqdm

- sam3 (custom library)

- huggingface_hub
|
|
|
| Prerequisites (SAM3 Specific):
|
| * **Hugging Face Access**: You must have a Hugging Face account and valid access permissions for the SAM3 model.
|
| * **Login**: You must be logged in locally to download/access the model weights.
|
| Run the following command in your terminal and paste your access token:
|
| `huggingface-cli login`
|
| """
|
|
|
| import os
|
| from PIL import Image
|
| import numpy as np
|
| from tqdm import tqdm
|
| import torch
|
|
|
|
|
| try:
|
| from huggingface_hub import login
|
| except ImportError:
|
| pass
|
|
|
|
|
| try:
|
| from sam3.model_builder import build_sam3_image_model
|
| from sam3.model.sam3_image_processor import Sam3Processor
|
| except ImportError:
|
| print("Error: Could not import 'sam3'. Ensure the submodule is present.")
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------

# Directory scanned (non-recursively) for input images.
INPUT_DIR = "./data/images"

# Directory receiving one binary PNG mask per processed image.
OUTPUT_DIR = "./data/inference_sam3"

# Free-text description of what SAM3 should segment.
TEXT_PROMPT = "orange spot"

# File-name suffixes accepted as images; matched against the lower-cased name.
VALID_EXT = (
    ".jpg",
    ".jpeg",
    ".png",
    ".tif",
    ".tiff",
)

# Run on the GPU whenever PyTorch reports one is usable.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
|
|
|
|
|
def _list_images(input_dir, valid_ext):
    """Return the sorted names of regular files in *input_dir* with an accepted extension.

    Args:
        input_dir: Directory to scan (not recursive).
        valid_ext: Tuple of lower-case suffixes, e.g. ``(".png", ".jpg")``.

    Returns:
        A sorted list of bare file names (not full paths).
    """
    # Sorting makes the processing order reproducible; os.listdir order is
    # arbitrary. The isfile check skips directories named like "x.png".
    return sorted(
        name
        for name in os.listdir(input_dir)
        if name.lower().endswith(valid_ext)
        and os.path.isfile(os.path.join(input_dir, name))
    )


def _combine_masks(masks, width, height):
    """Merge all predicted instance masks into one binary ``uint8`` array.

    Args:
        masks: The ``"masks"`` entry of the processor output (an iterable of
            torch tensors, one per detected instance), or ``None``.
        width: Width of the source image, used when nothing was detected.
        height: Height of the source image, used when nothing was detected.

    Returns:
        A 2-D ``uint8`` array holding 1 where any instance mask is positive
        and 0 elsewhere; all zeros with shape ``(height, width)`` when there
        are no masks.
    """
    if masks is None or len(masks) == 0:
        return np.zeros((height, width), dtype=np.uint8)

    combined = None
    for mask in masks:
        # NOTE(review): squeeze assumes each mask is (H, W) up to singleton
        # dims such as (1, H, W) -- confirm against the SAM3 processor output.
        positive = np.squeeze(mask.detach().cpu().numpy()) > 0
        combined = positive if combined is None else (combined | positive)
    return combined.astype(np.uint8)


def run_inference():
    """Segment every image in ``INPUT_DIR`` with SAM3 using ``TEXT_PROMPT``.

    For each readable image, all instance masks returned for the prompt are
    merged and saved to ``OUTPUT_DIR`` as a single-channel PNG (0 = background,
    255 = target) named after the input file with a ``.png`` extension.
    Unreadable files are skipped and counted. A missing input directory or a
    model-initialisation failure is reported on stdout and aborts the run.

    Returns:
        None.
    """
    print("=== SAM3 Text-Prompt Inference ===")
    print(f"Device: {DEVICE}")
    print(f"Prompt: '{TEXT_PROMPT}'")

    # isdir (not exists): a plain file at INPUT_DIR would make listdir crash.
    if not os.path.isdir(INPUT_DIR):
        print(f"Error: Input directory not found: {INPUT_DIR}")
        return

    os.makedirs(OUTPUT_DIR, exist_ok=True)

    print("Loading SAM3 model (might download weights if first time)...")
    try:
        model = build_sam3_image_model(device=DEVICE)
        processor = Sam3Processor(model)
    except Exception as e:
        # Broad on purpose: import, authentication and download failures
        # all surface here and should end the run with a hint, not a traceback.
        print(f"Error initializing SAM3: {e}")
        print("Tip: Make sure you have run 'huggingface-cli login'")
        return

    image_files = _list_images(INPUT_DIR, VALID_EXT)
    print(f"Found {len(image_files)} images to process.")

    processed_count = 0
    skipped_count = 0
    # Output name -> input name, so two inputs sharing a stem (a.jpg, a.png)
    # do not silently overwrite each other's mask.
    written = {}

    # Pure inference: no autograd graph is needed, which saves memory.
    with torch.no_grad():
        for fname in tqdm(image_files, desc="Segmenting"):
            img_path = os.path.join(INPUT_DIR, fname)

            try:
                # The context manager releases the file handle; convert()
                # returns an independent, fully loaded copy.
                with Image.open(img_path) as src:
                    image = src.convert("RGB")
            except Exception as e:
                # Best-effort batch: skip anything PIL cannot decode.
                # tqdm.write keeps the progress bar intact.
                tqdm.write(f"Skipping unreadable file {fname}: {e}")
                skipped_count += 1
                continue

            state = processor.set_image(image)
            output = processor.set_text_prompt(state=state, prompt=TEXT_PROMPT)

            width, height = image.size
            combined = _combine_masks(output.get("masks", None), width, height)

            out_name = os.path.splitext(fname)[0] + ".png"
            if out_name in written:
                tqdm.write(
                    f"Warning: {fname} and {written[out_name]} both map to "
                    f"{out_name}; overwriting the earlier mask."
                )
            written[out_name] = fname

            # Scale {0, 1} to {0, 255} so the mask is viewable as a PNG.
            Image.fromarray(combined * 255).save(os.path.join(OUTPUT_DIR, out_name))
            processed_count += 1

    print(f"✅ Inference complete. Processed {processed_count} images.")
    if skipped_count:
        print(f"Skipped {skipped_count} unreadable file(s).")
    print(f"Results saved to: {OUTPUT_DIR}")
|
|
|
if __name__ == "__main__":
    # Start the batch job only when executed as a script; importing this
    # module (e.g. to reuse its helpers) has no side effects.
    run_inference()
|
|
|