# SeeMODEL / architecture_detector.py
# priyadip's picture
# Add app.py and fix short_description YAML metadata
# 297ccd3
"""
🧬 Architecture Detector β€” Identifies 150+ model families from state_dict keys.
Each entry: (pattern_in_keys, architecture_name, category, description)
"""
from collections import defaultdict, OrderedDict
import re
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# MASTER PATTERN DATABASE
# Format: (key_pattern_regex, model_family, category, description)
# Categories: NLP, Vision, Audio, Multimodal, Generative, RL, Science, Other
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
ARCH_PATTERNS = [
    # Each entry: (key_pattern_regex, model_family, category, description).
    # Patterns are applied with re.IGNORECASE by detect_architectures(), so
    # every bare-word alternative must be unambiguous under case folding.
    # ─────────────────── NLP: Encoder Models ───────────────────
    (r"bert\.encoder\.layer\.\d+\.attention", "BERT", "NLP β€” Encoder", "Bidirectional Encoder Representations from Transformers"),
    (r"bert\.embeddings\.word_embeddings", "BERT", "NLP β€” Encoder", "BERT base/large"),
    (r"roberta\.encoder\.layer\.\d+", "RoBERTa", "NLP β€” Encoder", "Robustly Optimized BERT"),
    (r"distilbert\.transformer\.layer\.\d+", "DistilBERT", "NLP β€” Encoder", "Distilled BERT (6-layer)"),
    (r"albert\.encoder\.albert_layer_groups", "ALBERT", "NLP β€” Encoder", "A Lite BERT with parameter sharing"),
    (r"electra\.encoder\.layer\.\d+", "ELECTRA", "NLP β€” Encoder", "Efficiently Learning an Encoder that Classifies Token Replacements"),
    (r"deberta\.encoder\.layer\.\d+\.attention", "DeBERTa", "NLP β€” Encoder", "Decoding-enhanced BERT with disentangled attention"),
    (r"deberta_v2\.encoder|debertav2", "DeBERTa-v2", "NLP β€” Encoder", "DeBERTa version 2"),
    (r"xlm_roberta\.encoder|xlm-roberta", "XLM-RoBERTa", "NLP β€” Encoder", "Cross-lingual RoBERTa"),
    (r"camembert\.encoder", "CamemBERT", "NLP β€” Encoder", "French language BERT"),
    (r"flaubert\.encoder", "FlauBERT", "NLP β€” Encoder", "French language model"),
    (r"funnel\.encoder|funnel_transformer", "Funnel Transformer", "NLP β€” Encoder", "Funnel-shaped encoder"),
    (r"layoutlm\.encoder|layoutlmv2|layoutlmv3", "LayoutLM", "NLP β€” Encoder", "Document understanding model"),
    (r"longformer\.encoder\.layer\.\d+", "Longformer", "NLP β€” Encoder", "Long document transformer"),
    (r"bigbird\.encoder|block_sparse", "BigBird", "NLP β€” Encoder", "Big Bird sparse attention"),
    (r"ernie\.encoder", "ERNIE", "NLP β€” Encoder", "Enhanced Representation from kNowledge IntEgration"),
    (r"canine\.encoder", "CANINE", "NLP β€” Encoder", "Character-level encoder"),
    (r"roformer\.encoder", "RoFormer", "NLP β€” Encoder", "Rotary position embedding transformer"),
    (r"megatron\.encoder|megatron_bert", "Megatron-BERT", "NLP β€” Encoder", "NVIDIA Megatron BERT"),
    (r"luke\.encoder", "LUKE", "NLP β€” Encoder", "Language Understanding with Knowledge-based Embeddings"),
    (r"markuplm\.encoder", "MarkupLM", "NLP β€” Encoder", "Markup language model"),
    (r"splinter\.encoder", "Splinter", "NLP β€” Encoder", "Few-shot QA model"),
    (r"squeezebert\.encoder", "SqueezeBERT", "NLP β€” Encoder", "Efficient BERT variant"),
    (r"mpnet\.encoder", "MPNet", "NLP β€” Encoder", "Masked and Permuted Pre-training"),
    (r"convbert\.encoder", "ConvBERT", "NLP β€” Encoder", "BERT with span-based dynamic convolution"),
    (r"ibert\.encoder", "I-BERT", "NLP β€” Encoder", "Integer-only BERT quantization"),
    (r"rembert\.encoder", "RemBERT", "NLP β€” Encoder", "Rethinking Embedding coupling in BERT"),
    (r"tapas\.encoder", "TAPAS", "NLP β€” Encoder", "Table parsing transformer"),
    (r"mobilebert\.encoder", "MobileBERT", "NLP β€” Encoder", "Compact BERT for mobile"),
    (r"nystromformer\.encoder", "NystrΓΆmformer", "NLP β€” Encoder", "NystrΓΆm-based approximate attention"),
    # ─────────────────── NLP: Sentence Transformers / Embedding ───────────────────
    (r"sentence_bert|sbert|sentence_transformers", "Sentence-BERT", "NLP β€” Embedding", "Sentence-level embeddings"),
    (r"0\.auto_model\.|1\.pooling|2_dense|2_Normalize", "Sentence Transformers", "NLP β€” Embedding", "HuggingFace Sentence Transformers pipeline"),
    (r"e5_model|e5\.encoder|intfloat", "E5", "NLP β€” Embedding", "Text embedding model"),
    (r"bge\.encoder|bge_model", "BGE", "NLP β€” Embedding", "BAAI General Embedding"),
    (r"gte\.encoder|gte_model", "GTE", "NLP β€” Embedding", "General Text Embedding by Alibaba"),
    (r"instructor\.encoder|instructor_model", "Instructor", "NLP β€” Embedding", "Instruction-finetuned embedding"),
    (r"jina\.encoder|jina_model", "Jina Embeddings", "NLP β€” Embedding", "Jina AI embedding model"),
    # ─────────────────── NLP: Decoder / Causal LM ───────────────────
    (r"transformer\.h\.\d+\.attn\.c_attn", "GPT-2", "NLP β€” Causal LM", "OpenAI GPT-2"),
    (r"transformer\.h\.\d+\.mlp\.c_fc", "GPT-2", "NLP β€” Causal LM", "GPT-2 style architecture"),
    (r"gpt_neox\.layers\.\d+", "GPT-NeoX", "NLP β€” Causal LM", "EleutherAI GPT-NeoX"),
    (r"gpt_neo\.transformer", "GPT-Neo", "NLP β€” Causal LM", "EleutherAI GPT-Neo"),
    (r"gptj\.transformer|gpt_j", "GPT-J", "NLP β€” Causal LM", "EleutherAI GPT-J 6B"),
    (r"model\.layers\.\d+\.self_attn\.q_proj.*model\.layers\.\d+\.self_attn\.k_proj", "LLaMA / LLaMA-2 / LLaMA-3", "NLP β€” Causal LM", "Meta LLaMA family"),
    (r"model\.layers\.\d+\.self_attn\.rotary_emb", "LLaMA-style (RoPE)", "NLP β€” Causal LM", "LLaMA architecture with rotary embeddings"),
    (r"model\.layers\.\d+\.mlp\.gate_proj", "LLaMA-style (Gated MLP)", "NLP β€” Causal LM", "LLaMA / Mistral / Qwen style gated FFN"),
    (r"mistral\.layers|model\.layers.*mistral", "Mistral", "NLP β€” Causal LM", "Mistral AI model"),
    (r"mixtral\.layers|experts\.\d+\.w1", "Mixtral (MoE)", "NLP β€” Causal LM", "Mistral Mixture of Experts"),
    (r"block_sparse_moe|\.experts\.\d+\.", "Mixture of Experts", "NLP β€” Causal LM", "Sparse MoE architecture"),
    (r"phi\.layers|phi3|phi-2|phi-1", "Phi", "NLP β€” Causal LM", "Microsoft Phi series"),
    (r"gemma\.layers|model\.layers.*gemma", "Gemma", "NLP β€” Causal LM", "Google Gemma"),
    (r"qwen\.layers|qwen2|qwen_model", "Qwen", "NLP β€” Causal LM", "Alibaba Qwen series"),
    (r"internlm\.layers|internlm2", "InternLM", "NLP β€” Causal LM", "Shanghai AI Lab InternLM"),
    (r"baichuan\.layers|baichuan2", "Baichuan", "NLP β€” Causal LM", "Baichuan Inc model"),
    (r"yi\.layers|yi_model", "Yi", "NLP β€” Causal LM", "01.AI Yi series"),
    (r"deepseek\.layers|deepseek_v2", "DeepSeek", "NLP β€” Causal LM", "DeepSeek AI model"),
    (r"chatglm\.layers|chatglm2|chatglm3", "ChatGLM", "NLP β€” Causal LM", "Zhipu ChatGLM"),
    (r"bloom\.transformer\.h\.\d+", "BLOOM", "NLP β€” Causal LM", "BigScience BLOOM"),
    (r"opt\.decoder\.layers\.\d+", "OPT", "NLP β€” Causal LM", "Meta OPT"),
    (r"codegen\.transformer", "CodeGen", "NLP β€” Causal LM", "Salesforce code generation"),
    (r"starcoder\.transformer|starcoder2", "StarCoder", "NLP β€” Causal LM", "BigCode StarCoder"),
    (r"codellama|code_llama", "Code Llama", "NLP β€” Causal LM", "Meta Code Llama"),
    (r"falcon\.transformer\.h|falcon\.layers", "Falcon", "NLP β€” Causal LM", "TII Falcon"),
    (r"persimmon\.layers", "Persimmon", "NLP β€” Causal LM", "Adept Persimmon"),
    (r"mpt\.blocks\.\d+", "MPT", "NLP β€” Causal LM", "MosaicML MPT"),
    (r"rwkv\.blocks|rwkv\.layers", "RWKV", "NLP β€” Causal LM", "Receptance Weighted Key Value (RNN-style)"),
    (r"mamba\.layers|mamba_block|mixer\.A_log", "Mamba", "NLP β€” Causal LM", "State Space Model"),
    (r"state_space|s4\.kernel|ssm\.A_log", "State Space Model (SSM)", "NLP β€” Causal LM", "S4/Mamba/Hyena SSM variant"),
    (r"recurrent_gemma|griffin", "RecurrentGemma / Griffin", "NLP β€” Causal LM", "Recurrent variant of Gemma"),
    (r"cohere\.layers|command_r", "Cohere Command", "NLP β€” Causal LM", "Cohere Command-R"),
    (r"olmo\.layers|olmo_model", "OLMo", "NLP β€” Causal LM", "AI2 Open Language Model"),
    (r"stablelm\.layers|stablelm_model", "StableLM", "NLP β€” Causal LM", "Stability AI StableLM"),
    (r"dbrx\.layers|dbrx_model", "DBRX", "NLP β€” Causal LM", "Databricks DBRX"),
    (r"jamba\.layers|jamba_model", "Jamba", "NLP β€” Causal LM", "AI21 Jamba (SSM+Attention hybrid)"),
    (r"arctic\.layers|arctic_model", "Arctic", "NLP β€” Causal LM", "Snowflake Arctic"),
    (r"xverse\.layers|xverse_model", "XVERSE", "NLP β€” Causal LM", "XVERSE language model"),
    (r"orion\.layers|orion_model", "Orion", "NLP β€” Causal LM", "OrionStar AI model"),
    (r"pythia\.gpt_neox", "Pythia", "NLP β€” Causal LM", "EleutherAI Pythia suite"),
    (r"cerebras\.layers|cerebras_gpt", "Cerebras-GPT", "NLP β€” Causal LM", "Cerebras GPT"),
    (r"tinyllama|tiny_llama", "TinyLlama", "NLP β€” Causal LM", "Small LLaMA variant"),
    (r"openelm\.layers", "OpenELM", "NLP β€” Causal LM", "Apple OpenELM"),
    # ─────────────────── NLP: Encoder-Decoder / Seq2Seq ───────────────────
    (r"encoder\.block\.\d+\.layer.*decoder\.block\.\d+\.layer", "T5", "NLP β€” Seq2Seq", "Text-to-Text Transfer Transformer"),
    (r"t5\.encoder|t5_model|shared\.weight.*encoder\.block", "T5", "NLP β€” Seq2Seq", "T5 / Flan-T5 / mT5"),
    (r"model\.encoder\.layers.*model\.decoder\.layers.*model\.decoder\.embed_tokens", "BART", "NLP β€” Seq2Seq", "Bidirectional and Auto-Regressive Transformer"),
    (r"mbart\.encoder|mbart_model", "mBART", "NLP β€” Seq2Seq", "Multilingual BART"),
    (r"pegasus\.encoder", "Pegasus", "NLP β€” Seq2Seq", "Pre-training with Extracted Gap-sentences"),
    (r"led\.encoder\.layers|longformer_encoder_decoder", "LED", "NLP β€” Seq2Seq", "Longformer Encoder-Decoder"),
    (r"marian\.encoder|marian_model", "MarianMT", "NLP β€” Seq2Seq", "Marian machine translation"),
    (r"opus_mt|helsinki_nlp", "OPUS-MT", "NLP β€” Seq2Seq", "OPUS machine translation"),
    (r"nllb\.encoder|nllb_model", "NLLB", "NLP β€” Seq2Seq", "No Language Left Behind"),
    (r"switch_transformer\.encoder.*experts", "Switch Transformer", "NLP β€” Seq2Seq", "Switch Transformer (MoE T5)"),
    (r"ul2\.encoder|ul2_model", "UL2", "NLP β€” Seq2Seq", "Unifying Language Learning Paradigms"),
    (r"flan\.encoder|flan_model", "Flan-T5 / Flan-UL2", "NLP β€” Seq2Seq", "Instruction-tuned T5/UL2"),
    (r"longt5\.encoder", "LongT5", "NLP β€” Seq2Seq", "Long-context T5"),
    (r"blenderbot\.encoder", "BlenderBot", "NLP β€” Seq2Seq", "Meta BlenderBot dialogue"),
    (r"prophetnet\.encoder", "ProphetNet", "NLP β€” Seq2Seq", "Microsoft ProphetNet"),
    (r"plbart\.encoder", "PLBART", "NLP β€” Seq2Seq", "Program and Language BART"),
    (r"codet5\.encoder|codet5p", "CodeT5", "NLP β€” Seq2Seq", "Salesforce CodeT5"),
    (r"xlm\.encoder.*xlm\.decoder", "XLM", "NLP β€” Seq2Seq", "Cross-lingual Language Model"),
    # ─────────────────── Vision: Image Classification ───────────────────
    (r"vit\.encoder\.layer\.\d+|\.cls_token|\.pos_embed", "ViT", "Vision β€” Classification", "Vision Transformer"),
    (r"deit\.encoder|deit_model", "DeiT", "Vision β€” Classification", "Data-efficient Image Transformer"),
    (r"beit\.encoder|beit_model", "BEiT", "Vision β€” Classification", "BERT Pre-Training of Image Transformers"),
    (r"swin\.encoder|swin_model|swin\.layers", "Swin Transformer", "Vision β€” Classification", "Shifted Window Transformer"),
    (r"swinv2\.encoder", "Swin Transformer v2", "Vision β€” Classification", "Swin v2"),
    (r"convnext\.stages|convnext\.encoder", "ConvNeXt", "Vision β€” Classification", "A ConvNet for the 2020s"),
    (r"resnet\.layer[1-4]|resnet\.conv1|\.downsample", "ResNet", "Vision β€” Classification", "Residual Network"),
    (r"resnext|res2net", "ResNeXt / Res2Net", "Vision β€” Classification", "ResNeXt / Res2Net variant"),
    (r"efficientnet\.features|efficientnet_b[0-9]", "EfficientNet", "Vision β€” Classification", "EfficientNet scaling"),
    (r"efficientnetv2", "EfficientNet-v2", "Vision β€” Classification", "EfficientNet v2"),
    (r"mobilenet_v2|mobilenetv2", "MobileNetV2", "Vision β€” Classification", "Mobile-optimized ConvNet"),
    (r"mobilenet_v3|mobilenetv3", "MobileNetV3", "Vision β€” Classification", "MobileNetV3"),
    (r"densenet\.features|denseblock", "DenseNet", "Vision β€” Classification", "Densely Connected ConvNet"),
    (r"inception\.Mixed|InceptionV3|inception_v3", "Inception", "Vision β€” Classification", "Google Inception"),
    (r"vgg\.(features|classifier)", "VGG", "Vision β€” Classification", "VGG Network"),
    (r"squeezenet\.features", "SqueezeNet", "Vision β€” Classification", "SqueezeNet"),
    (r"shufflenet|shuffle_net", "ShuffleNet", "Vision β€” Classification", "ShuffleNet"),
    (r"maxvit\.encoder|maxvit\.stem", "MaxViT", "Vision β€” Classification", "Multi-Axis Vision Transformer"),
    # FIX: "pool_forking" was a typo — no checkpoint uses that key; PoolFormer
    # checkpoints use "poolformer"/"pool_former" style names.
    (r"poolformer\.encoder|pool_former", "PoolFormer", "Vision β€” Classification", "MetaFormer with pooling"),
    (r"dinov2\.encoder|dinov2_model", "DINOv2", "Vision β€” Classification", "Meta DINOv2 self-supervised ViT"),
    (r"eva\.encoder|eva02", "EVA / EVA-02", "Vision β€” Classification", "EVA vision model"),
    (r"regnet\.stem|regnet\.stages", "RegNet", "Vision β€” Classification", "Designing Network Design Spaces"),
    (r"coat\.encoder|coat_model", "CoaT", "Vision β€” Classification", "Co-Scale Conv-Attentional Transformer"),
    (r"pvt\.encoder|pvt_v2", "PVT / PVTv2", "Vision β€” Classification", "Pyramid Vision Transformer"),
    (r"nat\.encoder|nat_model", "NAT", "Vision β€” Classification", "Neighborhood Attention Transformer"),
    (r"dinat\.encoder", "DiNAT", "Vision β€” Classification", "Dilated Neighborhood Attention"),
    (r"levit\.encoder|levit_model", "LeViT", "Vision β€” Classification", "LeViT fast inference ViT"),
    (r"cait\.encoder|cait_model", "CaiT", "Vision β€” Classification", "Class-Attention in Image Transformers"),
    (r"crossvit\.encoder", "CrossViT", "Vision β€” Classification", "Cross-Attention Multi-Scale ViT"),
    (r"xcit\.encoder", "XCiT", "Vision β€” Classification", "Cross-Covariance Image Transformer"),
    (r"mlp_mixer\.stem|mixer\.layers", "MLP-Mixer", "Vision β€” Classification", "All-MLP Architecture"),
    # ─────────────────── Vision: Object Detection ───────────────────
    (r"detr\.encoder|detr\.decoder|detr\.model", "DETR", "Vision β€” Detection", "Detection Transformer"),
    (r"deformable_detr\.encoder", "Deformable DETR", "Vision β€” Detection", "Deformable DETR"),
    (r"conditional_detr", "Conditional DETR", "Vision β€” Detection", "Conditional DETR"),
    (r"deta\.encoder|deta_model", "DETA", "Vision β€” Detection", "Detection Transformers with Assignment"),
    (r"yolos\.encoder|yolos_model", "YOLOS", "Vision β€” Detection", "You Only Look at One Sequence"),
    (r"rt_detr\.encoder|rtdetr", "RT-DETR", "Vision β€” Detection", "Real-Time DETR"),
    (r"grounding_dino\.encoder", "Grounding DINO", "Vision β€” Detection", "Open-set object detection"),
    (r"faster_rcnn|fasterrcnn", "Faster R-CNN", "Vision β€” Detection", "Faster R-CNN"),
    (r"mask_rcnn|maskrcnn", "Mask R-CNN", "Vision β€” Detection", "Mask R-CNN (instance segmentation)"),
    (r"retinanet\.backbone", "RetinaNet", "Vision β€” Detection", "Focal Loss detector"),
    (r"yolo[v]?\d|ultralytics", "YOLO", "Vision β€” Detection", "You Only Look Once"),
    (r"owlvit\.vision|owlv2", "OWL-ViT", "Vision β€” Detection", "Open-World Localization ViT"),
    # ─────────────────── Vision: Segmentation ───────────────────
    (r"segformer\.encoder", "SegFormer", "Vision β€” Segmentation", "Semantic segmentation transformer"),
    (r"mask2former\.encoder|mask2former\.pixel", "Mask2Former", "Vision β€” Segmentation", "Masked-attention Mask Transformer"),
    (r"maskformer\.encoder", "MaskFormer", "Vision β€” Segmentation", "Per-Pixel Classification is Not All You Need"),
    (r"sam\.image_encoder|sam\.mask_decoder|segment_anything", "SAM", "Vision β€” Segmentation", "Segment Anything Model"),
    (r"sam2\.image_encoder|sam_hq", "SAM 2 / SAM-HQ", "Vision β€” Segmentation", "Segment Anything 2"),
    (r"oneformer\.encoder", "OneFormer", "Vision β€” Segmentation", "One Transformer to Rule Universal Segmentation"),
    (r"upernet\.backbone", "UPerNet", "Vision β€” Segmentation", "Unified Perceptual Parsing"),
    (r"deeplab|deeplabv3", "DeepLabV3", "Vision β€” Segmentation", "DeepLab semantic segmentation"),
    (r"fcn\.backbone", "FCN", "Vision β€” Segmentation", "Fully Convolutional Network"),
    # ─────────────────── Vision: Depth / 3D ───────────────────
    (r"dpt\.encoder|dpt_model", "DPT", "Vision β€” Depth", "Dense Prediction Transformer"),
    (r"depth_anything|depth\.encoder", "Depth Anything", "Vision β€” Depth", "Depth Anything model"),
    (r"glpn\.encoder", "GLPN", "Vision β€” Depth", "Global-Local Path Networks"),
    (r"zoedepth\.encoder", "ZoeDepth", "Vision β€” Depth", "Zero-shot monocular depth"),
    (r"midas\.encoder|midas_model", "MiDaS", "Vision β€” Depth", "Monocular depth estimation"),
    (r"nerf\.encoder|nerf_model", "NeRF", "Vision β€” 3D", "Neural Radiance Fields"),
    (r"point_cloud|pointnet|pointnet2", "PointNet", "Vision β€” 3D", "Point cloud processing"),
    # ─────────────────── Generative: Diffusion Models ───────────────────
    (r"model\.diffusion_model\.(input_blocks|middle_block|output_blocks)", "Stable Diffusion (UNet)", "Generative β€” Diffusion", "Stability AI Stable Diffusion UNet"),
    (r"unet\.(down_blocks|mid_block|up_blocks)", "Diffusers UNet2D", "Generative β€” Diffusion", "HuggingFace Diffusers UNet"),
    (r"transformer_blocks.*\.attn1|transformer_blocks.*\.attn2", "Diffusion Transformer Block", "Generative β€” Diffusion", "DiT-style attention"),
    # FIX: the bare alternative "DiT" is matched with re.IGNORECASE, so it used
    # to fire on the substring "dit" inside keys like "conditioner" / "edit";
    # word boundaries restrict it to a standalone token.
    (r"dit\.blocks|dit_model|\bDiT\b", "DiT", "Generative β€” Diffusion", "Diffusion Transformer"),
    (r"pixart\.transformer|pixart_model", "PixArt", "Generative β€” Diffusion", "PixArt image generation"),
    (r"sdxl\.|sd_xl|xl_base", "SDXL", "Generative β€” Diffusion", "Stable Diffusion XL"),
    (r"sd3\.transformer|sd3_model|stable_diffusion_3", "SD3", "Generative β€” Diffusion", "Stable Diffusion 3"),
    (r"flux\.transformer|flux_model|flux\.double_blocks", "FLUX", "Generative β€” Diffusion", "Black Forest Labs FLUX"),
    (r"kandinsky\.unet|kandinsky_model", "Kandinsky", "Generative β€” Diffusion", "Kandinsky image generation"),
    (r"wuerstchen|paella|stage_c|stage_b", "WΓΌrstchen", "Generative β€” Diffusion", "Efficient diffusion"),
    (r"playground\.transformer|playground_model", "Playground", "Generative β€” Diffusion", "PlaygroundAI model"),
    (r"imagen\.unet|imagen_model", "Imagen", "Generative β€” Diffusion", "Google Imagen"),
    (r"consistency_model|consistency_decoder", "Consistency Model", "Generative β€” Diffusion", "Consistency model"),
    (r"latent_diffusion\.model|ldm\.", "Latent Diffusion", "Generative β€” Diffusion", "Latent Diffusion Model"),
    (r"controlnet\.input_hint_block|controlnet\.zero_convs", "ControlNet", "Generative β€” Diffusion", "ControlNet conditioning"),
    (r"ip_adapter\.image_proj|ip_adapter_model", "IP-Adapter", "Generative β€” Diffusion", "Image Prompt Adapter"),
    (r"lora_down|lora_up|lora\.weight", "LoRA Weights", "Generative β€” Fine-tune", "Low-Rank Adaptation weights"),
    (r"text_model\.encoder.*text_projection", "CLIP Text Encoder", "Generative β€” Diffusion", "CLIP text encoder for diffusion"),
    # ─────────────────── Generative: VAE ───────────────────
    (r"first_stage_model\.(encoder|decoder)", "SD VAE (first_stage)", "Generative β€” VAE", "Stable Diffusion VAE"),
    (r"vae\.(encoder|decoder)\..*\.conv", "VAE (Conv)", "Generative β€” VAE", "Convolutional VAE"),
    (r"vq_vae|vqvae|vector_quantizer", "VQ-VAE", "Generative β€” VAE", "Vector-Quantized VAE"),
    (r"vqgan\.encoder|vqgan\.decoder", "VQGAN", "Generative β€” VAE", "Vector-Quantized GAN"),
    # ─────────────────── Generative: GAN ───────────────────
    (r"generator\.(conv|blocks|layers)|\.gen\.", "GAN Generator", "Generative β€” GAN", "GAN generator network"),
    (r"discriminator\.(conv|blocks|layers)|\.disc\.", "GAN Discriminator", "Generative β€” GAN", "GAN discriminator network"),
    (r"stylegan\.synthesis|style_mapping", "StyleGAN", "Generative β€” GAN", "Style-based GAN"),
    (r"pix2pix\.generator|pix2pix_model", "Pix2Pix", "Generative β€” GAN", "Image-to-image translation"),
    (r"cyclegan\.gen|cycle_gan", "CycleGAN", "Generative β€” GAN", "Unpaired image-to-image"),
    (r"esrgan\.model|rrdb|real_esrgan", "ESRGAN / Real-ESRGAN", "Generative β€” GAN", "Enhanced super-resolution GAN"),
    (r"gfpgan\.model|gfpgan", "GFPGAN", "Generative β€” GAN", "Face restoration GAN"),
    # ─────────────────── Generative: Image ───────────────────
    (r"maskgit\.encoder|maskgit_model", "MaskGiT", "Generative β€” Image", "Masked Generative Image Transformer"),
    (r"parti\.encoder|parti_model", "Parti", "Generative β€” Image", "Pathways Autoregressive Text-to-Image"),
    # ─────────────────── Audio / Speech ───────────────────
    (r"whisper\.encoder\.layers|whisper\.decoder", "Whisper", "Audio β€” ASR", "OpenAI Whisper speech recognition"),
    (r"wav2vec2\.encoder\.layers", "Wav2Vec2", "Audio β€” ASR", "Facebook Wav2Vec 2.0"),
    (r"hubert\.encoder\.layers", "HuBERT", "Audio β€” ASR", "Hidden-Unit BERT for speech"),
    (r"wavlm\.encoder\.layers", "WavLM", "Audio β€” ASR", "Wave Language Model"),
    (r"data2vec_audio\.encoder", "Data2Vec Audio", "Audio β€” ASR", "Data2Vec for audio"),
    (r"unispeech\.encoder", "UniSpeech", "Audio β€” ASR", "Unified Speech representation"),
    (r"sew\.encoder|sew_d\.encoder", "SEW / SEW-D", "Audio β€” ASR", "Squeezed and Efficient Wav2Vec"),
    (r"speech_t5\.encoder|speecht5", "SpeechT5", "Audio β€” ASR/TTS", "Unified speech-text model"),
    (r"seamless_m4t\.encoder|seamless_model", "SeamlessM4T", "Audio β€” Translation", "Meta Seamless multilingual"),
    (r"encodec\.encoder|encodec_model", "EnCodec", "Audio β€” Codec", "Meta neural audio codec"),
    (r"bark\.semantic|bark\.coarse|bark\.fine", "Bark", "Audio β€” TTS", "Suno Bark text-to-speech"),
    (r"vits\.encoder|vits_model", "VITS", "Audio β€” TTS", "VITS text-to-speech"),
    (r"speedy_speech|fastspeech", "FastSpeech", "Audio β€” TTS", "FastSpeech TTS"),
    (r"tacotron\.encoder|tacotron2", "Tacotron", "Audio β€” TTS", "Tacotron text-to-speech"),
    (r"musicgen\.encoder|musicgen_model", "MusicGen", "Audio β€” Music", "Meta music generation"),
    (r"audioldm\.unet|audioldm_model", "AudioLDM", "Audio β€” Generation", "Audio latent diffusion"),
    (r"clap\.audio_encoder|clap\.text_encoder", "CLAP", "Audio β€” Multimodal", "Contrastive Language-Audio Pretraining"),
    (r"audio_spectrogram_transformer|ast\.encoder", "AST", "Audio β€” Classification", "Audio Spectrogram Transformer"),
    (r"beats\.encoder|beats_model", "BEATs", "Audio β€” Classification", "Audio pre-training"),
    # ─────────────────── Multimodal ───────────────────
    (r"clip\.visual\.transformer|clip\.text\.transformer|visual_projection|text_projection", "CLIP", "Multimodal", "Contrastive Language-Image Pre-training"),
    (r"openclip\.visual|open_clip", "OpenCLIP", "Multimodal", "Open-source CLIP"),
    (r"siglip\.vision|siglip\.text", "SigLIP", "Multimodal", "Sigmoid Loss for Language-Image Pre-training"),
    (r"blip\.vision_model|blip\.text_encoder", "BLIP", "Multimodal", "Bootstrapping Language-Image Pre-training"),
    (r"blip2\.vision_model|blip2\.qformer", "BLIP-2", "Multimodal", "BLIP-2 with Q-Former"),
    (r"instructblip\.vision|instructblip\.qformer", "InstructBLIP", "Multimodal", "Instruction-tuned BLIP"),
    (r"llava\.vision_tower|llava\.mm_projector|llava\.language_model", "LLaVA", "Multimodal β€” VLM", "Large Language and Vision Assistant"),
    (r"llava_next\.vision|llava_next_model", "LLaVA-NeXT", "Multimodal β€” VLM", "LLaVA-NeXT improved"),
    (r"idefics\.vision|idefics2", "IDEFICS", "Multimodal β€” VLM", "HuggingFace IDEFICS"),
    (r"fuyu\.vision|fuyu_model", "Fuyu", "Multimodal β€” VLM", "Adept Fuyu multimodal"),
    (r"paligemma\.vision|paligemma_model", "PaliGemma", "Multimodal β€” VLM", "Google PaliGemma"),
    (r"cogvlm\.vision|cogvlm_model", "CogVLM", "Multimodal β€” VLM", "Tsinghua CogVLM"),
    (r"qwen_vl\.visual|qwen_vl_model", "Qwen-VL", "Multimodal β€” VLM", "Alibaba Qwen Visual"),
    (r"internvl\.vision|internvl_model", "InternVL", "Multimodal β€” VLM", "InternVL vision-language"),
    (r"florence\.vision|florence_model", "Florence", "Multimodal", "Microsoft Florence"),
    (r"kosmos\.vision|kosmos_model", "Kosmos", "Multimodal", "Microsoft Kosmos"),
    (r"align\.vision|align\.text", "ALIGN", "Multimodal", "Google ALIGN"),
    (r"flava\.vision|flava\.text", "FLAVA", "Multimodal", "Foundational Language And Vision Alignment"),
    (r"bridgetower\.vision|bridgetower\.text", "BridgeTower", "Multimodal", "Bridge visual and text"),
    (r"chinese_clip\.vision|chinese_clip\.text", "Chinese CLIP", "Multimodal", "Chinese language CLIP"),
    (r"git\.vision|git_model\.image_encoder", "GIT", "Multimodal", "Generative Image-to-Text"),
    # ─────────────────── Video ───────────────────
    (r"videomae\.encoder", "VideoMAE", "Video", "Video Masked Autoencoder"),
    (r"vivit\.encoder", "ViViT", "Video", "Video Vision Transformer"),
    (r"timesformer\.encoder", "TimeSformer", "Video", "Time-Space Transformer"),
    (r"x_clip\.vision|xclip_model", "X-CLIP", "Video", "Video-language CLIP"),
    (r"cogvideo\.transformer", "CogVideo", "Video β€” Generation", "Video generation model"),
    (r"animatediff\.motion|animatediff_model", "AnimateDiff", "Video β€” Generation", "Animated image diffusion"),
    (r"svd\.unet|stable_video_diffusion", "SVD", "Video β€” Generation", "Stable Video Diffusion"),
    # ─────────────────── Reinforcement Learning ───────────────────
    (r"decision_transformer\.encoder", "Decision Transformer", "RL", "Decision Transformer"),
    (r"policy_network|actor\.layers|actor_critic", "Actor-Critic / Policy Net", "RL", "RL policy network"),
    (r"q_network|critic\.layers|dqn", "Q-Network / DQN", "RL", "RL Q-value network"),
    (r"ppo\.policy|ppo_model", "PPO Model", "RL", "Proximal Policy Optimization"),
    (r"sac\.actor|sac\.critic", "SAC Model", "RL", "Soft Actor-Critic"),
    # ─────────────────── Science / Domain-Specific ───────────────────
    (r"esm\.encoder\.layer|esm2\.encoder|esm1b", "ESM / ESM-2", "Science β€” Protein", "Evolutionary Scale Modeling (protein)"),
    (r"esmfold\.encoder", "ESMFold", "Science β€” Protein", "ESM protein structure prediction"),
    (r"prot_bert|protbert", "ProtBERT", "Science β€” Protein", "Protein BERT"),
    (r"alphafold\.structure|evoformer", "AlphaFold-style", "Science β€” Protein", "Protein structure prediction"),
    (r"schnet\.interactions|dimenet|painn", "SchNet / DimeNet / PaiNN", "Science β€” Molecular", "Molecular property prediction"),
    (r"graphormer\.encoder|graph_transformer", "Graphormer", "Science β€” Graph", "Graph Transformer"),
    (r"gin\.layers|gat\.layers|gcn\.layers|gnn\.", "GNN (GCN/GAT/GIN)", "Science β€” Graph", "Graph Neural Network"),
    (r"megamolbart|molbart", "MolBART", "Science β€” Chemistry", "Molecular BART"),
    (r"mat\.encoder|molecular_attention", "MAT", "Science β€” Chemistry", "Molecular Attention Transformer"),
    # ─────────────────── Time Series ───────────────────
    (r"patchtst\.encoder", "PatchTST", "Time Series", "Patch Time Series Transformer"),
    (r"informer\.encoder|informer\.decoder", "Informer", "Time Series", "Efficient transformer for long sequences"),
    (r"autoformer\.encoder", "Autoformer", "Time Series", "Auto-Correlation transformer"),
    (r"time_series_transformer\.encoder", "Time Series Transformer", "Time Series", "Generic time series model"),
    (r"timesnet\.encoder|timesnet_model", "TimesNet", "Time Series", "Temporal 2D-variation model"),
    # ─────────────────── Document / OCR ───────────────────
    (r"donut\.encoder|donut\.decoder", "Donut", "Document β€” OCR", "Document Understanding Transformer"),
    (r"trocr\.encoder|trocr\.decoder", "TrOCR", "Document β€” OCR", "Transformer OCR"),
    (r"nougat\.encoder|nougat\.decoder", "Nougat", "Document β€” OCR", "Academic document understanding"),
    (r"pix2struct\.encoder", "Pix2Struct", "Document", "Screenshot to structured data"),
    (r"table_transformer\.encoder", "Table Transformer", "Document", "Table detection/recognition"),
    # ─────────────────── Recommendation / Retrieval ───────────────────
    (r"retribert\.encoder", "RetriBERT", "Retrieval", "Retrieval BERT"),
    (r"dpr\.question_encoder|dpr\.ctx_encoder", "DPR", "Retrieval", "Dense Passage Retrieval"),
    (r"colbert\.encoder|colbert_model", "ColBERT", "Retrieval", "Contextualized Late Interaction BERT"),
    (r"splade\.encoder|splade_model", "SPLADE", "Retrieval", "Sparse Lexical and Expansion"),
    # ─────────────────── Adapters / PEFT ───────────────────
    (r"lora_A|lora_B|lora_embedding", "LoRA Adapter", "PEFT", "Low-Rank Adaptation"),
    (r"adapter_down|adapter_up|adapter\.weight", "Bottleneck Adapter", "PEFT", "Adapter layers"),
    (r"prefix_encoder|prefix_tuning", "Prefix Tuning", "PEFT", "Prefix tuning parameters"),
    (r"prompt_encoder|prompt_embeddings|soft_prompt", "Prompt Tuning", "PEFT", "Soft prompt parameters"),
    (r"ia3_l|ia3\.weight", "IAΒ³", "PEFT", "Infused Adapter by Inhibiting and Amplifying"),
    (r"qlora|quantized.*lora", "QLoRA", "PEFT", "Quantized LoRA"),
    # ─────────────────── Quantization ───────────────────
    (r"\.qweight|\.qzeros|\.scales.*gptq", "GPTQ Quantized", "Quantization", "GPTQ post-training quantization"),
    (r"quant_state|absmax|bnb_quantized", "BitsAndBytes (bnb)", "Quantization", "bitsandbytes quantization"),
    (r"awq\.qweight|awq_model", "AWQ Quantized", "Quantization", "Activation-aware Weight Quantization"),
    (r"gguf|ggml", "GGUF/GGML", "Quantization", "llama.cpp quantization format"),
    # ─────────────────── Misc / Catch-all Patterns ───────────────────
    (r"\.self_attn\.(q|k|v)_proj", "Generic Transformer (Q/K/V projections)", "Architecture Pattern", "Standard transformer self-attention"),
    (r"\.cross_attn\.", "Cross-Attention Module", "Architecture Pattern", "Cross-attention between modalities"),
    (r"rope_|rotary_emb|\.cos_cached|\.sin_cached", "Rotary Position Embeddings (RoPE)", "Architecture Pattern", "Rotary position encoding"),
    (r"alibi|attention_bias", "ALiBi Position", "Architecture Pattern", "Attention with Linear Biases"),
    (r"group_query|num_key_value", "Grouped-Query Attention (GQA)", "Architecture Pattern", "GQA attention"),
    (r"flash_attn|flash_attention", "Flash Attention", "Architecture Pattern", "IO-aware attention"),
    (r"rmsnorm|rms_norm", "RMSNorm", "Architecture Pattern", "Root Mean Square Layer Normalization"),
    (r"swiglu|silu.*gate|gate.*silu", "SwiGLU Activation", "Architecture Pattern", "SwiGLU gated activation"),
]
def detect_architectures(state_dict: OrderedDict) -> list:
    """
    Detect all matching architecture patterns from a state_dict.

    Every regex in ARCH_PATTERNS is tried (case-insensitively) against the
    concatenation of all key names; only the first hit per model family is
    kept.  Confidence is derived from the fraction of individual keys that
    match the pattern.

    Returns list of dicts:
        {family, category, description, confidence, hit_count,
         matched_keys_count, hit_ratio, sample_keys}
    sorted HIGH β†’ MEDIUM β†’ LOW, ties broken by descending hit_count.
    """
    all_keys = " ".join(state_dict.keys())
    # Flattened copy (dots/underscores -> spaces) gives separator-free
    # patterns a second chance; patterns containing literal '\.' simply
    # cannot match this half of the string.
    all_keys_flat = all_keys.replace(".", " ").replace("_", " ")
    combined = all_keys + " " + all_keys_flat
    matches = []
    seen_families = set()
    for pattern, family, category, description in ARCH_PATTERNS:
        # Compile once per pattern: the same regex is used against the
        # combined string AND against every individual key below, so
        # hoisting avoids a per-key cache lookup in re.
        rx = re.compile(pattern, re.IGNORECASE)
        found = rx.findall(combined)
        if found and family not in seen_families:
            # Count how many unique keys matched.  NOTE: multi-key patterns
            # (those spanning '.*' across several keys) can match `combined`
            # while matching no single key, leaving this list empty and the
            # confidence LOW by construction.
            matched_keys = [k for k in state_dict.keys() if rx.search(k)]
            hit_ratio = len(matched_keys) / max(len(state_dict), 1)
            if hit_ratio > 0.3:
                confidence = "🟒 HIGH"
            elif hit_ratio > 0.05:
                confidence = "🟑 MEDIUM"
            else:
                confidence = "🟠 LOW"
            matches.append({
                "family": family,
                "category": category,
                "description": description,
                "confidence": confidence,
                "hit_count": len(found),
                "matched_keys_count": len(matched_keys),
                "hit_ratio": hit_ratio,
                "sample_keys": matched_keys[:5],
            })
            seen_families.add(family)
    # Sort by confidence bucket first, then by raw hit count (descending).
    priority = {"🟒 HIGH": 0, "🟑 MEDIUM": 1, "🟠 LOW": 2}
    matches.sort(key=lambda m: (priority.get(m["confidence"], 3), -m["hit_count"]))
    return matches
def infer_model_config(state_dict: OrderedDict, detections: list) -> dict:
    """
    Infer model configuration details from the state_dict:
    hidden size, num layers, num heads, vocab size, max seq length, etc.

    All values are name/shape heuristics — treat them as best-effort
    guesses.  ``detections`` is kept for interface compatibility but is
    not consulted by the current heuristics.

    Returns a dict that may contain: likely_hidden_size,
    common_dimensions, num_layers, vocab_size, embedding_dim,
    num_attention_heads, head_dim, intermediate_size,
    max_position_embeddings, num_classes_or_vocab_output,
    grouped_query_attention, num_kv_heads.
    """
    config = {}
    keys = list(state_dict.keys())
    key_str = " ".join(keys)
    _infer_hidden_size(state_dict, config)
    _infer_num_layers(key_str, config)
    _infer_vocab(state_dict, config)
    _infer_attention_heads(state_dict, config)
    _infer_intermediate_size(state_dict, config)
    _infer_max_positions(state_dict, config)
    _infer_output_size(state_dict, keys, config)
    _infer_gqa(state_dict, config)
    return config


def _infer_hidden_size(state_dict, config):
    """Most frequent dim of 2-D non-embedding weight matrices ≈ hidden size."""
    from collections import Counter
    dims = []
    for key, tensor in state_dict.items():
        if tensor.ndim == 2 and ("weight" in key) and ("embed" not in key.lower()):
            dims.extend([tensor.shape[0], tensor.shape[1]])
    if dims:
        top = Counter(dims).most_common(5)
        config["likely_hidden_size"] = top[0][0]
        config["common_dimensions"] = [(d, c) for d, c in top]


def _infer_num_layers(key_str, config):
    """Highest numeric layer index seen in any common layer-naming scheme."""
    layer_indices = set()
    for pattern in (r"\.(\d+)\.", r"layer\.(\d+)", r"layers\.(\d+)", r"block\.(\d+)",
                    r"blocks\.(\d+)", r"h\.(\d+)", r"transformer\.h\.(\d+)"):
        for match in re.finditer(pattern, key_str):
            layer_indices.add(int(match.group(1)))
    if layer_indices:
        # Indices are 0-based, so the layer count is max + 1.
        config["num_layers"] = max(layer_indices) + 1


def _infer_vocab(state_dict, config):
    """First 2-D token-embedding matrix gives vocab size and embedding dim."""
    for key, tensor in state_dict.items():
        lower_k = key.lower()
        if any(x in lower_k for x in ("embed_tokens", "word_embed", "wte", "token_embedding",
                                      "embed.weight", "embeddings.word")):
            if tensor.ndim == 2:
                config["vocab_size"] = tensor.shape[0]
                config["embedding_dim"] = tensor.shape[1]
                break


def _infer_attention_heads(state_dict, config):
    """Guess head count from the first 2-D query projection weight.

    Picks the largest plausible head count whose head_dim lands in a
    common range — a heuristic, not ground truth.
    """
    for key, tensor in state_dict.items():
        lower_k = key.lower()
        if "q_proj" in lower_k or "query" in lower_k:
            if tensor.ndim == 2:
                hidden = config.get("likely_hidden_size", tensor.shape[1])
                if hidden > 0:
                    possible_heads = [h for h in (1, 2, 4, 6, 8, 12, 16, 24, 32, 40, 48, 64, 96, 128)
                                      if hidden % h == 0 and (hidden // h) in (32, 64, 80, 96, 128, 256)]
                    if possible_heads:
                        config["num_attention_heads"] = possible_heads[-1]
                        config["head_dim"] = hidden // possible_heads[-1]
                # Only the first 2-D query-like weight is examined.
                break


def _infer_intermediate_size(state_dict, config):
    """FFN intermediate width from the first up-projection-style weight."""
    for key, tensor in state_dict.items():
        lower_k = key.lower()
        if any(x in lower_k for x in ("intermediate.dense", "mlp.fc1", "mlp.c_fc",
                                      "ffn.0", "gate_proj", "up_proj", "wi_0", "fc1")):
            if tensor.ndim == 2:
                config["intermediate_size"] = tensor.shape[0]
                break


def _infer_max_positions(state_dict, config):
    """Max sequence length from the first 2-D positional-embedding table."""
    for key, tensor in state_dict.items():
        lower_k = key.lower()
        if any(x in lower_k for x in ("position_embed", "pos_embed", "wpe", "position_ids",
                                      "positional_embedding")):
            if tensor.ndim == 2:
                config["max_position_embeddings"] = tensor.shape[0]
                break


def _infer_output_size(state_dict, keys, config):
    """Output classes / vocab size from the last classifier-like parameter.

    Scans keys in reverse so the final head wins; accepts either a 2-D
    weight or a 1-D bias.
    """
    for key in reversed(keys):
        tensor = state_dict[key]
        lower_k = key.lower()
        if any(x in lower_k for x in ("classifier", "lm_head", "cls.predictions",
                                      "qa_outputs", "score", "fc_out")):
            if tensor.ndim == 2:
                config["num_classes_or_vocab_output"] = tensor.shape[0]
                break
            elif tensor.ndim == 1:
                config["num_classes_or_vocab_output"] = tensor.shape[0]
                break


def _infer_gqa(state_dict, config):
    """Detect grouped-query attention from mismatched Q vs K projection widths."""
    q_size = None
    kv_size = None
    for key, tensor in state_dict.items():
        if "q_proj.weight" in key and tensor.ndim == 2:
            q_size = tensor.shape[0]
        if "k_proj.weight" in key and tensor.ndim == 2:
            kv_size = tensor.shape[0]
    if q_size and kv_size and q_size != kv_size:
        config["grouped_query_attention"] = True
        # BUG FIX: the original computed kv_size // (q_size // num_heads)
        # directly, which raises ZeroDivisionError whenever the inferred
        # head count (derived from likely_hidden_size, possibly different
        # from this q_proj's width) exceeds q_size.  Guard head_dim == 0.
        num_heads = config.get("num_attention_heads", q_size)
        head_dim = q_size // num_heads if num_heads else 0
        if head_dim:
            config["num_kv_heads"] = kv_size // head_dim
def format_detection_report(detections: list, config: dict) -> str:
    """Format the architecture detection results as a readable report."""
    if not detections:
        return ("🔍 ARCHITECTURE DETECTION\n\n"
                "   No known architecture patterns matched.\n"
                "   This may be a custom model or an unusual format.\n")
    parts = ["🧬 ARCHITECTURE DETECTION REPORT\n", "=" * 65 + "\n\n"]
    # Headline block for the highest-ranked detection.
    top = detections[0]
    parts.append(f"🏆 PRIMARY MATCH: {top['family']}\n")
    parts.append(f"   Category:    {top['category']}\n")
    parts.append(f"   Description: {top['description']}\n")
    parts.append(f"   Confidence:  {top['confidence']}\n")
    parts.append(f"   Evidence:    {top['matched_keys_count']} matching parameters ")
    parts.append(f"({top['hit_ratio']:.1%} of model)\n\n")
    # Full table only when more than one pattern fired.
    if len(detections) > 1:
        parts.append("📋 ALL DETECTED PATTERNS:\n")
        parts.append(f"{'#':<4} {'Confidence':<14} {'Family':<35} {'Category':<25} {'Hits':>5}\n")
        parts.append("─" * 90 + "\n")
        for idx, det in enumerate(detections, 1):
            parts.append(f"{idx:<4} {det['confidence']:<14} {det['family']:<35} "
                         f"{det['category']:<25} {det['matched_keys_count']:>5}\n")
        parts.append("\n")
    # Inferred configuration section.
    if config:
        parts.append("⚙️ INFERRED MODEL CONFIGURATION:\n")
        parts.append("─" * 45 + "\n")
        label_map = {
            "likely_hidden_size": "Hidden Size",
            "num_layers": "Number of Layers",
            "vocab_size": "Vocabulary Size",
            "embedding_dim": "Embedding Dimension",
            "num_attention_heads": "Attention Heads",
            "head_dim": "Head Dimension",
            "intermediate_size": "FFN Intermediate Size",
            "max_position_embeddings": "Max Sequence Length",
            "num_classes_or_vocab_output": "Output Classes / Vocab",
            "grouped_query_attention": "Grouped-Query Attention",
            "num_kv_heads": "KV Heads (GQA)",
        }
        for cfg_key, cfg_val in config.items():
            if cfg_key == "common_dimensions":
                continue
            label = label_map.get(cfg_key, cfg_key)
            # Thousands separators for large integer values only.
            if isinstance(cfg_val, int) and cfg_val > 1000:
                parts.append(f"   • {label}: {cfg_val:,}\n")
            else:
                parts.append(f"   • {label}: {cfg_val}\n")
        if "common_dimensions" in config:
            parts.append(f"\n   📐 Most common tensor dimensions:\n")
            for dim, count in config["common_dimensions"][:8]:
                parts.append(f"      {dim:>8,} (appears {count} times)\n")
    # Summary of distinct categories across all detections.
    categories = list(set(det["category"] for det in detections))
    if categories:
        parts.append(f"\n🏷️ MODEL CATEGORIES: {', '.join(categories)}\n")
    return "".join(parts)