Saks-backend-new / models_init.py
BoooomNing's picture
Upload folder using huggingface_hub
aa3311b verified
import torch
import os
import logging
import urllib.request
from diffusers import FluxFillPipeline, FluxTransformer2DModel
from segment_anything import sam_model_registry, SamPredictor
from segment_anything import build_sam
import spaces
from huggingface_hub import hf_hub_download
logger = logging.getLogger(__name__)

# Resolve the SAM ViT-H checkpoint from the Hugging Face Hub when this module
# is imported; hf_hub_download returns a local path to the downloaded file.
sam_ckpt = hf_hub_download(
    filename="sam_vit_h_4b8939.pth",
    repo_id="SnapwearAI/sam_model",
)
@spaces.GPU
def get_sam_predictor():
    """Create a SAM predictor backed by the hub-downloaded checkpoint.

    Builds the SAM model from the module-level ``sam_ckpt`` path, moves it to
    the module-level ``device``, and returns a new ``SamPredictor`` wrapping it.
    """
    model = build_sam(checkpoint=sam_ckpt)
    model.to(device=device)
    return SamPredictor(model)
@spaces.GPU
def get_flux_pipeline():
    """Load the FLUX.1 Fill pipeline with the custom background transformer.

    The transformer weights come from the ``transformer/`` subfolder of the
    ``SnapwearAI/bg-transformer`` repo; both components are loaded in bfloat16
    and the assembled pipeline is moved to CUDA before being returned.
    """
    dtype = torch.bfloat16
    fill_transformer = FluxTransformer2DModel.from_pretrained(
        "SnapwearAI/bg-transformer",
        subfolder="transformer",  # weights live inside transformer/ in the repo
        torch_dtype=dtype,
    )
    pipeline = FluxFillPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-Fill-dev",
        transformer=fill_transformer,
        torch_dtype=dtype,
    )
    return pipeline.to("cuda")
# Output directory name (not referenced within this section of the file).
OUTPUT_DIR = "outputs"

# Pick the GPU when PyTorch reports CUDA as available, else fall back to CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Module-level predictor slot, initially empty.
# NOTE(review): initialize_pipeline() builds its own local predictor and does
# not rebind this global.
sam_predictor = None
def _download_if_missing(url, dest, name, source):
    """Download ``url`` to ``dest`` unless ``dest`` already exists.

    The payload is written to ``dest + ".part"`` and only renamed onto
    ``dest`` once ``urlretrieve`` has returned. Because callers treat mere
    existence of ``dest`` as "already downloaded", writing directly to
    ``dest`` would let an interrupted download leave a truncated checkpoint
    that every later run silently reuses.

    Args:
        url: Source URL of the file.
        dest: Final local path.
        name: Human-readable model name used in log messages.
        source: Origin/size description used in the "Downloading" log message.

    Raises:
        Whatever ``urllib.request.urlretrieve`` / ``os.replace`` raise; the
        error is logged and re-raised after the partial file is removed.
    """
    if os.path.exists(dest):
        logger.info(f"βœ… {name} already exists: {dest}")
        return

    logger.info(f"πŸ“₯ Downloading {name} {source}...")
    tmp_path = dest + ".part"
    try:
        urllib.request.urlretrieve(url, tmp_path)
        # Atomic rename: dest only ever appears fully written.
        os.replace(tmp_path, dest)
    except Exception as e:
        logger.error(f"❌ Failed to download {name}: {e}")
        # Best-effort cleanup so the next attempt starts from scratch.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
    logger.info(f"βœ… {name} downloaded to: {dest}")


def _ensure_grounding_config(config_path):
    """Write the GroundingDINO Swin-T OGC config to ``config_path`` if missing.

    ``groundingdino.util.inference.Model`` loads its config with
    ``SLConfig.fromfile`` and passes it to ``build_model``, which reads flat
    module-level variables (``modelname``, ``backbone``, ``hidden_dim``, ...).
    The values below mirror the upstream ``GroundingDINO_SwinT_OGC.py`` that
    matches the ``groundingdino_swint_ogc.pth`` checkpoint.

    An existing file is left untouched; delete a config produced by older
    versions of this module (mmdet-style ``model = dict(...)``) so that it is
    regenerated in the correct format.
    """
    if os.path.exists(config_path):
        logger.info(f"βœ… Config already exists: {config_path}")
        return

    logger.info("πŸ“ Creating GroundingDINO config file...")
    config_content = '''# GroundingDINO Swin-T OGC configuration (flat variables read by SLConfig).
batch_size = 1
modelname = "groundingdino"
backbone = "swin_T_224_1k"
position_embedding = "sine"
pe_temperatureH = 20
pe_temperatureW = 20
return_interm_indices = [1, 2, 3]
backbone_freeze_keywords = None
enc_layers = 6
dec_layers = 6
pre_norm = False
dim_feedforward = 2048
hidden_dim = 256
dropout = 0.0
nheads = 8
num_queries = 900
query_dim = 4
num_patterns = 0
num_feature_levels = 4
enc_n_points = 4
dec_n_points = 4
two_stage_type = "standard"
two_stage_bbox_embed_share = False
two_stage_class_embed_share = False
transformer_activation = "relu"
dec_pred_bbox_embed_share = True
dn_box_noise_scale = 1.0
dn_label_noise_ratio = 0.5
dn_label_coef = 1.0
dn_bbox_coef = 1.0
embed_init_tgt = True
dn_labelbook_size = 2000
max_text_len = 256
text_encoder_type = "bert-base-uncased"
use_text_enhancer = True
use_fusion_layer = True
use_checkpoint = True
use_transformer_ckpt = True
use_text_cross_attention = True
text_dropout = 0.0
fusion_dropout = 0.0
fusion_droppath = 0.1
sub_sentence_present = True
'''
    with open(config_path, "w", encoding="utf-8") as f:
        f.write(config_content)
    logger.info(f"βœ… Config file created at: {config_path}")


def download_models():
    """Download models from official sources.

    Ensures that the following files exist under ``models/``, relative to the
    current working directory:

    * the SAM ViT-H checkpoint (from Facebook's public release),
    * the GroundingDINO Swin-T OGC checkpoint (from GitHub releases),
    * a GroundingDINO config file matching that checkpoint.

    Files that already exist are reused without validation.

    Returns:
        tuple[str, str, str]: ``(sam_path, grounding_path, config_path)``.

    Raises:
        Any error raised while downloading a checkpoint (logged, then re-raised).
    """
    os.makedirs("models", exist_ok=True)

    # SAM MODEL - Facebook's official release
    sam_path = "models/sam_vit_h_4b8939.pth"
    _download_if_missing(
        "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
        sam_path,
        "SAM model",
        "from Facebook (2.6GB)",
    )

    # GROUNDING DINO MODEL - GitHub releases
    grounding_path = "models/groundingdino_swint_ogc.pth"
    _download_if_missing(
        "https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth",
        grounding_path,
        "GroundingDINO model",
        "from GitHub (694MB)",
    )

    # CONFIG FILE - created if it doesn't exist
    config_path = "models/GroundingDINO_SwinT_OGC.py"
    _ensure_grounding_config(config_path)

    return sam_path, grounding_path, config_path
def initialize_pipeline():
    """Initialize GroundingDINO and SAM models.

    Calls ``download_models()`` to make sure checkpoints and config exist,
    then loads GroundingDINO (via ``groundingdino.util.inference.Model``) and
    a ViT-H SAM wrapped in a ``SamPredictor``, both on the same device.

    Returns:
        dict: ``{"grounding_dino": Model, "sam_predictor": SamPredictor,
        "device": torch.device}``.

    Raises:
        Any exception raised while downloading or loading; it is logged with
        its traceback and re-raised unchanged.
    """
    # Local torch.device; shadows the module-level ``device`` string.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"πŸš€ Using device: {device}")

    try:
        sam_path, grounding_path, config_path = download_models()

        # ── GroundingDINO ────────────────────────────────────────────────
        logger.info("πŸ”§ Loading GroundingDINO model...")
        # Deferred import: groundingdino is only needed here, so a missing
        # install surfaces as a logged failure of this call, not at import.
        from groundingdino.util.inference import Model

        grounding_dino_model = Model(
            model_config_path=config_path,
            model_checkpoint_path=grounding_path,
            device=device,
        )
        logger.info("βœ… GroundingDINO loaded successfully!")

        # ── SAM ──────────────────────────────────────────────────────────
        logger.info("πŸ”§ Loading SAM model...")
        sam = sam_model_registry["vit_h"](checkpoint=sam_path)
        sam.to(device=device)
        # NOTE(review): this binds a local name; the module-level
        # ``sam_predictor`` global is not updated here.
        sam_predictor = SamPredictor(sam)
        logger.info("βœ… SAM loaded successfully!")
    except Exception as e:
        # logger.exception logs at ERROR level and appends the active
        # traceback, replacing the manual traceback.format_exc() record.
        logger.exception(f"❌ Failed to initialize models: {e}")
        raise

    logger.info("πŸŽ‰ All models initialized successfully!")
    return {
        "grounding_dino": grounding_dino_model,
        "sam_predictor": sam_predictor,
        "device": device,
    }