model_structure_viewer / backend /hf_model_utils.py
maomao88's picture
add support for vlm
24b8880
import torch
import torch.nn as nn
import json
import hashlib
import gc
from transformers import (
AutoConfig,
AutoModel,
AutoModelForCausalLM,
AutoModelForMaskedLM,
AutoModelForSequenceClassification,
AutoModelForTokenClassification,
AutoModelForQuestionAnswering,
AutoModelForSeq2SeqLM,
AutoModelForImageClassification,
AutoModelForImageTextToText
)
from accelerate import init_empty_weights
# Process-wide memo of get_model_structure() results (structure dicts).
# Entries are added by get_model_structure() and nothing in this module
# evicts them.
# NOTE(review): unbounded — grows with every distinct request; consider a
# size limit if many different models are queried.
model_structure_cache: dict = {}
def module_hash(module):
    """
    Generate a hash representing the structure of a module.

    The hash covers:
      - the module's class name,
      - the name and (recursive) hash of each direct child, and
      - the name, shape and ``requires_grad`` flag of each parameter
        registered directly on the module.

    Two modules hash equal only when all of these match. This is what
    hf_style_structural_dict relies on to collapse repeated layers.
    Parameter values are never read, only their shape and requires_grad.

    Args:
        module: the ``nn.Module`` to fingerprint.

    Returns:
        A 32-character hexadecimal MD5 digest string.
    """
    # Child names are part of the structure. Without them, two containers
    # whose children differ only by attribute name would collide and be
    # reported as repeats of each other.
    child_info = tuple(
        (child_name, module_hash(child))
        for child_name, child in module.named_children()
    )
    param_info = tuple(
        (param_name, tuple(p.shape), p.requires_grad)
        for param_name, p in module.named_parameters(recurse=False)
    )
    rep = (module.__class__.__name__, child_info, param_info)
    # MD5 is used purely as a fingerprint, not for security. Declaring that
    # keeps it usable on FIPS-restricted Python builds.
    digest = hashlib.md5(str(rep).encode("utf-8"), usedforsecurity=False)
    return digest.hexdigest()
def is_number_string(value):
    """Return True when *value* is a non-empty str made only of digit characters.

    hf_style_structural_dict uses this to decide whether a child's name is
    index-like (e.g. "0", "12"), which makes it eligible for repeat collapsing.
    Non-string inputs always yield False. Note that ``str.isdigit`` also
    accepts non-ASCII digit characters.
    """
    if not isinstance(value, str):
        return False
    return value.isdigit()
def hf_style_structural_dict(module):
    """
    Recursively convert a PyTorch module into a dict mirroring
    Hugging Face's print(model), only counting repeats when structure is identical.

    Each node contains:
      - "class_name": the module's class name;
      - "params" (present only if the module registers parameters directly):
        parameter name -> {"shape": list of ints, "requires_grad": bool};
      - "children" (present only if the module has children):
        child name -> node.

    A run of consecutive children is collapsed into a single entry when:
      - every name in the run is a numeric string, and
      - every child in the run has the same module_hash.

    The entry is keyed by the first child's name and carries
    "num_repeats" = run length.

    Args:
        module: the ``nn.Module`` to describe.

    Returns:
        A nested dict of str / int / bool / list / dict values.
    """
    result = {"class_name": module.__class__.__name__}

    params = {
        name: {"shape": list(p.shape), "requires_grad": p.requires_grad}
        for name, p in module.named_parameters(recurse=False)
    }
    if params:
        result["params"] = params

    children = list(module.named_children())
    if children:
        # Hash each index-named child exactly once. Named children can never
        # take part in a run, so their (whole-subtree) hash is not computed.
        hashes = [
            module_hash(child) if is_number_string(name) else None
            for name, child in children
        ]
        child_dict = {}
        i = 0
        while i < len(children):
            name, child = children[i]
            count = 1
            if hashes[i] is not None:
                # Extend the run only over siblings that are themselves
                # index-named. A differently-named sibling that happens to
                # share the structure must keep its own entry, not be
                # swallowed.
                while (
                    i + count < len(children)
                    and hashes[i + count] is not None
                    and hashes[i + count] == hashes[i]
                ):
                    count += 1
            entry = hf_style_structural_dict(child)
            if count > 1:
                entry["num_repeats"] = count
            child_dict[name] = entry
            i += count
        result["children"] = child_dict
    return result
# Explicit task name -> Auto* class used to instantiate the model skeleton.
_TASK_MODEL_CLASSES = {
    "causal": AutoModelForCausalLM,
    "masked": AutoModelForMaskedLM,
    "sequence": AutoModelForSequenceClassification,
    "token": AutoModelForTokenClassification,
    "qa": AutoModelForQuestionAnswering,
    "s2s": AutoModelForSeq2SeqLM,
    "vision": AutoModelForImageClassification,
    "vlm": AutoModelForImageTextToText,
}


def _build_empty_model(config, model_type):
    """Instantiate the model for *config* under init_empty_weights().

    A *model_type* listed in _TASK_MODEL_CLASSES selects that Auto* class.
    Any other value (including None) falls back to:
      - AutoModelForImageTextToText when the config has a ``vision_config``
        attribute (treated as a VLM);
      - AutoModel otherwise.
    """
    model_cls = _TASK_MODEL_CLASSES.get(model_type)
    if model_cls is None:
        if hasattr(config, "vision_config"):
            model_cls = AutoModelForImageTextToText
        else:
            model_cls = AutoModel
    with init_empty_weights():
        # trust_remote_code is passed for every task, matching the config
        # load, so models that ship custom code work for all heads.
        return model_cls.from_config(config, trust_remote_code=True)


def get_model_structure(model_name: str, model_type: str | None):
    """
    Return a description of a Hugging Face model's architecture.

    The model is built with ``from_config`` inside accelerate's
    ``init_empty_weights`` context, i.e. from its config alone, without
    loading a checkpoint. Results are memoised in ``model_structure_cache``
    under the key ``(model_name, model_type)``. Requesting a different head
    for the same checkpoint therefore builds a new structure instead of
    returning the one cached for another task.

    Args:
        model_name: identifier passed to ``AutoConfig.from_pretrained``.
        model_type: one of "causal", "masked", "sequence", "token", "qa",
            "s2s", "vision", "vlm". Any other value (including None)
            auto-selects a VLM or a base model from the config.

    Returns:
        A dict with these keys:
          - "model_type": the config's ``model_type``;
          - several size hyper-parameters, each None when the config lacks
            that attribute;
          - "layers": the nested output of hf_style_structural_dict.
        On a cache hit the same dict object is returned.
    """
    cache_key = (model_name, model_type)
    if cache_key in model_structure_cache:
        return model_structure_cache[cache_key]

    config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    model = _build_empty_model(config, model_type)
    try:
        structure = {
            "model_type": config.model_type,
            "hidden_size": getattr(config, "hidden_size", None),
            "num_hidden_layers": getattr(config, "num_hidden_layers", None),
            "num_attention_heads": getattr(config, "num_attention_heads", None),
            "image_size": getattr(config, "image_size", None),
            "intermediate_size": getattr(config, "intermediate_size", None),
            "patch_size": getattr(config, "patch_size", None),
            "vocab_size": getattr(config, "vocab_size", None),
            "layers": hf_style_structural_dict(model),
        }
    finally:
        # Release the module graph even if building the structure failed.
        del model
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    model_structure_cache[cache_key] = structure
    # Presumably JSON-serializable, provided the config attributes read
    # above are plain values.
    return structure