File size: 941 Bytes
778d4b8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
#!/usr/bin/env python3
"""
Register custom model architectures for LLaMA-Omni2 models.
Run this before starting the model worker.
"""

from transformers import AutoConfig, AutoModelForCausalLM
from llama_omni2.model.language_model.omni2_speech2s_qwen2 import (
    Omni2Speech2SQwen2ForCausalLM,
    Omni2Speech2SConfig
)

# Register for Qwen2-based model
AutoConfig.register("omni2_speech2s_qwen2", Omni2Speech2SConfig)
AutoModelForCausalLM.register(Omni2Speech2SConfig, Omni2Speech2SQwen2ForCausalLM)

# Also register for llama model type (for compatibility)
class Omni2Speech2SLlamaConfig(Omni2Speech2SConfig):
    model_type = "omni_speech2s_llama"

AutoConfig.register("omni_speech2s_llama", Omni2Speech2SLlamaConfig)
AutoModelForCausalLM.register(Omni2Speech2SLlamaConfig, Omni2Speech2SQwen2ForCausalLM)

print("✓ Custom model architectures registered successfully")
print("  - omni2_speech2s_qwen2")
print("  - omni_speech2s_llama")