#!/usr/bin/env python3
"""
Register custom model architectures for LLaMA-Omni2 models.

Registration mutates the in-memory ``AutoConfig`` / ``AutoModelForCausalLM``
registries of the *current* Python process only. Running this file as a
standalone script merely checks that registration succeeds; for the model
worker to resolve these model types, import this module inside the worker
process before it calls ``from_pretrained``.

Registered model types:

* ``omni2_speech2s_qwen2`` -> ``Omni2Speech2SConfig`` /
  ``Omni2Speech2SQwen2ForCausalLM``
* ``omni_speech2s_llama`` (compatibility alias) -> ``Omni2Speech2SLlamaConfig`` /
  ``Omni2Speech2SLlamaForCausalLM`` (a thin subclass of the Qwen2 model)
"""
from transformers import AutoConfig, AutoModelForCausalLM

from llama_omni2.model.language_model.omni2_speech2s_qwen2 import (
    Omni2Speech2SConfig,
    Omni2Speech2SQwen2ForCausalLM,
)

# Register for Qwen2-based model.
# NOTE(review): AutoConfig.register rejects a config whose `model_type` differs
# from the key passed here; confirm Omni2Speech2SConfig.model_type is
# "omni2_speech2s_qwen2".
AutoConfig.register("omni2_speech2s_qwen2", Omni2Speech2SConfig)
AutoModelForCausalLM.register(Omni2Speech2SConfig, Omni2Speech2SQwen2ForCausalLM)


# Also register for llama model type (for compatibility).
class Omni2Speech2SLlamaConfig(Omni2Speech2SConfig):
    """Omni2 speech-to-speech config exposed under the ``omni_speech2s_llama`` type."""

    # Must equal the key used with AutoConfig.register below.
    model_type = "omni_speech2s_llama"


class Omni2Speech2SLlamaForCausalLM(Omni2Speech2SQwen2ForCausalLM):
    """Qwen2 speech-to-speech model bound to ``Omni2Speech2SLlamaConfig``.

    ``AutoModelForCausalLM.register`` raises ``ValueError`` when the model
    class's ``config_class`` does not match the config class it is registered
    for. The Qwen2 model therefore cannot be registered for the alias config
    directly. This subclass changes nothing except ``config_class``.
    """

    config_class = Omni2Speech2SLlamaConfig


AutoConfig.register("omni_speech2s_llama", Omni2Speech2SLlamaConfig)
AutoModelForCausalLM.register(Omni2Speech2SLlamaConfig, Omni2Speech2SLlamaForCausalLM)

print("✓ Custom model architectures registered successfully")
print(" - omni2_speech2s_qwen2")
print(" - omni_speech2s_llama")