Codes to create this repo:
import os
import torch
from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer, GenerationConfig, set_seed)
from transformers.models.qwen3_moe.configuration_qwen3_moe import Qwen3MoeConfig
def create_moe_config(source_model_id, num_experts=3, num_experts_per_tok=1):
    """Build a ``Qwen3MoeConfig`` that mirrors a dense source model's geometry.

    The attention/embedding dimensions are inherited from the source config;
    the MoE-specific fields (expert count, routing, aux loss) are set here.

    Args:
        source_model_id: Model id or local path of the dense source model.
        num_experts: Number of experts in each MoE layer.
        num_experts_per_tok: Experts activated per token (top-k routing).

    Returns:
        A fully populated ``Qwen3MoeConfig``.
    """
    base = AutoConfig.from_pretrained(source_model_id, trust_remote_code=True)

    cfg = Qwen3MoeConfig()
    cfg.architectures = ["Qwen3MoeForCausalLM"]
    cfg.model_type = "qwen3_moe"
    cfg._name_or_path = source_model_id

    # MoE-specific routing settings.
    cfg.num_experts = num_experts
    cfg.num_experts_per_tok = num_experts_per_tok
    cfg.norm_topk_prob = True
    cfg.router_aux_loss_coef = 0.001
    cfg.output_router_logits = False

    # Dimensions carried over unchanged from the dense source model.
    inherited = (
        "vocab_size",
        "hidden_size",
        "intermediate_size",
        "num_hidden_layers",
        "num_attention_heads",
        "num_key_value_heads",
        "max_position_embeddings",
        "max_window_layers",
        "head_dim",
        "tie_word_embeddings",
    )
    for attr in inherited:
        setattr(cfg, attr, getattr(base, attr))

    # Each expert MLP reuses the dense model's intermediate width.
    cfg.moe_intermediate_size = base.intermediate_size

    return cfg
def copy_weights(source_model, target_model, seed=42):
    """Fill ``target_model``'s parameters in place from ``source_model``.

    Policy:
      * ``router`` parameters have no dense counterpart and are always
        randomly initialized (normal, std=0.02).
      * ``expert`` parameters are copied from the corresponding dense
        parameter obtained by stripping the ``.experts.<idx>`` path segment,
        so every expert starts from the same pretrained MLP weights.
      * All other parameters are copied one-to-one by name.
      * Any parameter that cannot be copied falls back to random init.

    Args:
        source_model: Dense model providing pretrained weights.
        target_model: MoE model whose parameters are modified in place.
        seed: Seed passed to ``set_seed`` so random init is reproducible.
    """
    set_seed(seed)

    def _random_init(param, name, reason):
        # Shared fallback: random init with the usual transformer std.
        torch.nn.init.normal_(param, 0, 0.02)
        print(f"randomly initialized ({reason}): {name}, shape: {param.shape}")

    with torch.no_grad():
        for name, param in target_model.named_parameters():
            if "router" in name:
                torch.nn.init.normal_(param, 0, 0.02)
                print(f"randomly initialized: {name}, shape: {param.shape}")
                continue
            if "expert" in name:
                # Extract the expert index from a path like "...experts.3.up_proj.weight".
                parts = name.split(".")
                expert_idx = None
                for i, part in enumerate(parts):
                    if part == "experts":
                        expert_idx = int(parts[i + 1])
                        break
                if expert_idx is None:
                    _random_init(param, name, "cannot parse expert")
                    continue
                # Map the expert parameter back to its dense counterpart.
                source_name = name.replace(f".experts.{expert_idx}", "")
                try:
                    source_param = source_model.get_parameter(source_name)
                    param.copy_(source_param)
                    print(f"copied from source: {name} <- {source_name}, shape: {param.shape}")
                except (AttributeError, RuntimeError):
                    # AttributeError: no such parameter in the source model;
                    # RuntimeError: shape mismatch on copy_. Was a bare
                    # `except:` which also swallowed KeyboardInterrupt.
                    _random_init(param, name, "not found in source")
            else:
                try:
                    source_param = source_model.get_parameter(name)
                    param.copy_(source_param)
                    print(f"copied from source: {name}, shape: {param.shape}")
                except (AttributeError, RuntimeError):
                    _random_init(param, name, "not found in source")
def create_moe_model(source_model_id, target_model_id, num_experts=3, num_experts_per_tok=1, seed=42):
    """Convert a dense checkpoint into a Qwen3 MoE model and save it.

    Saves the tokenizer, the converted weights, and the generation config
    of the source model into ``target_model_id``.

    Args:
        source_model_id: Model id or local path of the dense source model.
        target_model_id: Output directory for the converted MoE model.
        num_experts: Number of experts per MoE layer.
        num_experts_per_tok: Experts activated per token.
        seed: Seed forwarded to ``copy_weights`` for reproducible init.
    """
    os.makedirs(target_model_id, exist_ok=True)

    # Reuse the source tokenizer verbatim.
    tokenizer = AutoTokenizer.from_pretrained(source_model_id, trust_remote_code=True)
    tokenizer.save_pretrained(target_model_id)

    # Instantiate an empty MoE skeleton, then load the dense donor model.
    moe_config = create_moe_config(source_model_id, num_experts, num_experts_per_tok)
    moe_model = AutoModelForCausalLM.from_config(
        moe_config, torch_dtype=torch.bfloat16, trust_remote_code=True
    )
    dense_model = AutoModelForCausalLM.from_pretrained(
        source_model_id, torch_dtype=torch.bfloat16, trust_remote_code=True
    )
    moe_model.generation_config = GenerationConfig.from_pretrained(
        source_model_id, trust_remote_code=True
    )

    copy_weights(dense_model, moe_model, seed)
    moe_model.save_pretrained(target_model_id)
    print(f"model saved to {target_model_id}")
def valid_moe_model(source_model_id, target_model_id):
    """Smoke-test the converted model by generating from both checkpoints.

    Runs the same Chinese prompt through the source (dense) and target (MoE)
    models and prints both generations for manual comparison.

    Args:
        source_model_id: Path/id of the dense source model.
        target_model_id: Path/id of the converted MoE model.
    """
    import torch
    from transformers import pipeline, AutoTokenizer

    def get_output(model_path, prompt):
        # Build a fresh pipeline per model; strip the echoed prompt if present.
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        generator = pipeline("text-generation", model=model_path, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True)
        outputs = generator(prompt, max_new_tokens=512, temperature=0.7, top_p=0.9, do_sample=True, eos_token_id=tokenizer.eos_token_id)
        generated_text = outputs[0]['generated_text']
        return generated_text[len(prompt):].strip() if generated_text.startswith(prompt) else generated_text

    prompt = "请介绍一下人工智能的发展历史。"
    # BUG FIX: the labels were swapped — "source model output" previously
    # printed the target model's generation and vice versa.
    print("source model output:")
    print(get_output(source_model_id, prompt))
    print("target model output:")
    print(get_output(target_model_id, prompt))
if __name__ == "__main__":
    # Convert a dense Qwen3-0.6B checkpoint into a 3-expert (top-1) MoE
    # and then sanity-check both models on the same prompt.
    dense_path = "./Qwen3-0.6B"
    moe_path = "./Qwen3-2B-A0.6B"
    create_moe_model(dense_path, moe_path, num_experts=3, num_experts_per_tok=1, seed=42)
    valid_moe_model(dense_path, moe_path)
- Downloads last month
- 6
Inference Providers
NEW
This model isn't deployed by any Inference Provider.
🙋
Ask for provider support