import torch
import os
import logging
import urllib.request
from diffusers import FluxFillPipeline, FluxTransformer2DModel
from segment_anything import sam_model_registry, SamPredictor, build_sam
import spaces
from huggingface_hub import hf_hub_download

logger = logging.getLogger(__name__)

# Fetch the SAM ViT-H checkpoint from the Hub once at import time.
sam_ckpt = hf_hub_download(
    repo_id="SnapwearAI/sam_model",
    filename="sam_vit_h_4b8939.pth",
)


def get_sam_predictor():
    """Build a SamPredictor from the downloaded ViT-H checkpoint."""
    sam = build_sam(checkpoint=sam_ckpt)
    sam.to(device=device)
    predictor = SamPredictor(sam)
    return predictor
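

# Illustrative sketch, not part of the original app logic: how the predictor
# returned above is typically queried with the standard segment-anything API
# (set_image once, then predict from a box prompt). The helper name and the
# single-box flow are assumptions for demonstration only.
def segment_with_box(predictor, image_np, box_xyxy):
    """Return the best SAM mask for one xyxy box (example usage only)."""
    predictor.set_image(image_np)          # HWC uint8 RGB numpy array
    masks, _, _ = predictor.predict(
        box=box_xyxy,                      # e.g. np.array([x0, y0, x1, y1])
        multimask_output=False,            # single best mask
    )
    return masks[0]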


def get_flux_pipeline():
    """Load FLUX.1-Fill-dev with the fine-tuned background transformer."""
    transformer = FluxTransformer2DModel.from_pretrained(
        "SnapwearAI/bg-transformer",
        subfolder="transformer",  # tell HF to look inside transformer/
        torch_dtype=torch.bfloat16,
    )
    pipe_flux = FluxFillPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-Fill-dev",
        transformer=transformer,
        torch_dtype=torch.bfloat16,
    ).to("cuda")
    return pipe_flux
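

# Illustrative sketch, not part of the original wiring: a minimal inpainting
# call against the FluxFillPipeline returned above, assuming PIL image/mask
# inputs and the standard diffusers fill-pipeline call signature. The helper
# name and default values are assumptions for demonstration only.
def run_flux_fill(pipe_flux, image, mask, prompt, steps=50, guidance=30.0):
    """Repaint the masked region of `image` according to `prompt` (example usage)."""
    result = pipe_flux(
        prompt=prompt,
        image=image,                 # PIL.Image source photo
        mask_image=mask,             # white pixels = region to regenerate
        guidance_scale=guidance,
        num_inference_steps=steps,
    )
    return result.images[0]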


OUTPUT_DIR = "outputs"
device = 'cuda' if torch.cuda.is_available() else 'cpu'
sam_predictor = None


def download_models():
    """Download models from official sources."""
    # Create models directory
    os.makedirs("models", exist_ok=True)

    # ───────────────────────────────────────────────────────────────
    # SAM MODEL - Download from Facebook's official release
    # ───────────────────────────────────────────────────────────────
    sam_path = "models/sam_vit_h_4b8939.pth"
    if not os.path.exists(sam_path):
        logger.info("Downloading SAM model from Facebook (2.6GB)...")
        try:
            urllib.request.urlretrieve(
                "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
                sam_path
            )
            logger.info(f"SAM model downloaded to: {sam_path}")
        except Exception as e:
            logger.error(f"Failed to download SAM model: {e}")
            raise
    else:
        logger.info(f"SAM model already exists: {sam_path}")

    # ───────────────────────────────────────────────────────────────
    # GROUNDING DINO MODEL - Download from GitHub releases
    # ───────────────────────────────────────────────────────────────
    grounding_path = "models/groundingdino_swint_ogc.pth"
    if not os.path.exists(grounding_path):
        logger.info("Downloading GroundingDINO model from GitHub (694MB)...")
        try:
            urllib.request.urlretrieve(
                "https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth",
                grounding_path
            )
            logger.info(f"GroundingDINO model downloaded to: {grounding_path}")
        except Exception as e:
            logger.error(f"Failed to download GroundingDINO model: {e}")
            raise
    else:
        logger.info(f"GroundingDINO model already exists: {grounding_path}")

    # ───────────────────────────────────────────────────────────────
    # CONFIG FILE - Create if it doesn't exist
    # ───────────────────────────────────────────────────────────────
    config_path = "models/GroundingDINO_SwinT_OGC.py"
    if not os.path.exists(config_path):
        logger.info("Creating GroundingDINO config file...")
        # Minimal config that works with GroundingDINO
        config_content = '''import os.path as osp
import sys

# Add current directory to path for imports
sys.path.insert(0, osp.dirname(__file__))

# Model configuration
model = dict(
    type='GroundingDINO',
    num_queries=900,
    with_box_refine=True,
    as_two_stage=True,
    data_preprocessor=dict(
        type='DetDataPreprocessor',
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        bgr_to_rgb=True,
        pad_mask=False,
    ),
    language_model=dict(
        type='BertModel',
        name='bert-base-uncased',
        max_tokens=256,
        pad_to_max=False,
        use_sub_sentence_represent=True,
        special_tokens_list=["[CLS]", "[SEP]", ".", "?"],
        add_pooling_layer=False,
    ),
    backbone=dict(
        type='SwinTransformer',
        pretrain_img_size=384,
        embed_dims=96,
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24],
        window_size=12,
        mlp_ratio=4,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.,
        attn_drop_rate=0.,
        drop_path_rate=0.2,
        patch_norm=True,
        out_indices=(1, 2, 3),
        with_cp=True,
        convert_weights=True,
    ),
    neck=dict(
        type='ChannelMapper',
        in_channels=[192, 384, 768],
        kernel_size=1,
        out_channels=256,
        act_cfg=None,
        norm_cfg=dict(type='GN', num_groups=32),
        num_outs=4),
    encoder=dict(
        type='DetrTransformerEncoder',
        num_layers=6,
        transformerlayers=dict(
            type='BaseTransformerLayer',
            attn_cfgs=dict(
                type='MultiScaleDeformableAttention',
                embed_dims=256,
                num_heads=8,
                num_levels=4,
                num_points=4,
                im2col_step=64,
                dropout=0.0,
                batch_first=False,
                norm_cfg=None,
                init_cfg=None),
            feedforward_channels=2048,
            ffn_dropout=0.0,
            operation_order=('self_attn', 'norm', 'ffn', 'norm'))),
    decoder=dict(
        type='GroundingDINOTransformerDecoder',
        num_layers=6,
        return_intermediate=True,
        transformerlayers=dict(
            type='GroundingDINOTransformerDecoderLayer',
            attn_cfgs=[
                dict(
                    type='MultiheadAttention',
                    embed_dims=256,
                    num_heads=8,
                    dropout=0.0,
                    batch_first=False),
                dict(
                    type='MultiScaleDeformableAttention',
                    embed_dims=256,
                    num_heads=8,
                    num_levels=4,
                    num_points=4,
                    im2col_step=64,
                    dropout=0.0,
                    batch_first=False,
                    norm_cfg=None,
                    init_cfg=None)
            ],
            feedforward_channels=2048,
            ffn_dropout=0.0,
            operation_order=('self_attn', 'norm', 'cross_attn', 'norm',
                             'ffn', 'norm'))),
    positional_encoding=dict(
        type='SinePositionalEncoding',
        num_feats=128,
        normalize=True,
        offset=-0.5),
    bbox_head=dict(
        type='GroundingDINOHead',
        num_queries=900,
        num_classes=256,
        in_channels=2048,
        sync_cls_avg_factor=True,
        as_two_stage=True,
        with_box_refine=True,
        dn_cfg=dict(
            type='CdnQueryGenerator',
            noise_scale=dict(label=0.5, box=1.0),
            group_cfg=dict(dynamic=True, num_groups=None, num_dn_queries=100)),
        transformer=dict(
            type='GroundingDINOTransformer',
            embed_dims=256,
            num_feature_levels=4,
            encoder=dict(
                type='DetrTransformerEncoder',
                num_layers=6,
                transformerlayers=dict(
                    type='BaseTransformerLayer',
                    attn_cfgs=dict(
                        type='MultiScaleDeformableAttention',
                        embed_dims=256,
                        num_heads=8,
                        num_levels=4,
                        num_points=4,
                        im2col_step=64,
                        dropout=0.0,
                        batch_first=False,
                        norm_cfg=None,
                        init_cfg=None),
                    feedforward_channels=2048,
                    ffn_dropout=0.0,
                    operation_order=('self_attn', 'norm', 'ffn', 'norm'))),
            decoder=dict(
                type='GroundingDINOTransformerDecoder',
                num_layers=6,
                return_intermediate=True,
                transformerlayers=dict(
                    type='GroundingDINOTransformerDecoderLayer',
                    attn_cfgs=[
                        dict(
                            type='MultiheadAttention',
                            embed_dims=256,
                            num_heads=8,
                            dropout=0.0,
                            batch_first=False),
                        dict(
                            type='MultiScaleDeformableAttention',
                            embed_dims=256,
                            num_heads=8,
                            num_levels=4,
                            num_points=4,
                            im2col_step=64,
                            dropout=0.0,
                            batch_first=False,
                            norm_cfg=None,
                            init_cfg=None)
                    ],
                    feedforward_channels=2048,
                    ffn_dropout=0.0,
                    operation_order=('self_attn', 'norm', 'cross_attn', 'norm',
                                     'ffn', 'norm'))),
            positional_encoding=dict(
                type='SinePositionalEncoding',
                num_feats=128,
                normalize=True,
                offset=-0.5)),
        loss_cls=dict(
            type='FocalLoss',
            use_sigmoid=True,
            gamma=2.0,
            alpha=0.25,
            loss_weight=1.0),
        loss_bbox=dict(type='L1Loss', loss_weight=5.0),
        loss_iou=dict(type='GIoULoss', loss_weight=2.0)),
    dn_cfg=dict(
        label_noise_scale=0.5,
        box_noise_scale=1.0,
        group_cfg=dict(dynamic=True, num_groups=None, num_dn_queries=100)),
    # training and testing settings
    train_cfg=dict(
        assigner=dict(
            type='HungarianAssigner',
            cls_cost=dict(type='FocalLossCost', weight=2.0),
            reg_cost=dict(type='BBoxL1Cost', weight=5.0, box_format='xywh'),
            iou_cost=dict(type='IoUCost', iou_mode='giou', weight=2.0))),
    test_cfg=dict(max_per_img=300))

# Dataset settings
dataset_type = 'CocoDataset'
data_root = 'data/coco/'
'''
        with open(config_path, 'w') as f:
            f.write(config_content)
        logger.info(f"Config file created at: {config_path}")
    else:
        logger.info(f"Config already exists: {config_path}")

    return sam_path, grounding_path, config_path


def initialize_pipeline():
    """Initialize GroundingDINO and SAM models."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Using device: {device}")

    try:
        # Download models
        sam_path, grounding_path, config_path = download_models()

        # ───────────────────────────────────────────────────────────────
        # Initialize GroundingDINO
        # ───────────────────────────────────────────────────────────────
        logger.info("Loading GroundingDINO model...")

        # Import here to avoid issues if groundingdino is not installed
        from groundingdino.util.inference import Model

        grounding_dino_model = Model(
            model_config_path=config_path,
            model_checkpoint_path=grounding_path,
            device=device
        )
        logger.info("GroundingDINO loaded successfully!")

        # ───────────────────────────────────────────────────────────────
        # Initialize SAM
        # ───────────────────────────────────────────────────────────────
        logger.info("Loading SAM model...")
        sam = sam_model_registry["vit_h"](checkpoint=sam_path)
        sam.to(device=device)
        sam_predictor = SamPredictor(sam)
        logger.info("SAM loaded successfully!")

        logger.info("All models initialized successfully!")
        return {
            "grounding_dino": grounding_dino_model,
            "sam_predictor": sam_predictor,
            "device": device
        }
    except Exception as e:
        logger.error(f"Failed to initialize models: {e}")
        import traceback
        logger.error(f"Full traceback: {traceback.format_exc()}")
        raise
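

# Minimal smoke test, an assumption rather than part of the original Space:
# when the module is run directly, load everything once and report success.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    models = initialize_pipeline()        # GroundingDINO + SAM
    flux_pipe = get_flux_pipeline()       # FLUX.1-Fill with the custom transformer
    logger.info("Initialized components: %s", ", ".join(models.keys()))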