# my_custom_olmoe/configuration_custom.py
# NOTE: depending on your transformers version, the import path of the official OLMoE configuration may need to be adjusted.
| from transformers.models.olmoe.configuration_olmoe import OlmoeConfig | |
class DenseBackwardOLMoEConfig(OlmoeConfig):
    """OLMoE configuration variant used by the DenseBackward model.

    Adds a ``model_marker`` field on top of ``OlmoeConfig`` and changes the
    defaults of two inherited options: ``intermediate_size`` defaults to
    1024 and ``torch_dtype`` defaults to ``"bfloat16"``. Both can still be
    overridden by passing them explicitly (or by loading a config file that
    contains them).
    """

    # Overrides the inherited model_type so this variant can be identified.
    model_type = "DenseBackward_olmoe"

    # Dotted "module.ClassName" paths intended to support the Auto* classes.
    # NOTE(review): this is a class-level attribute; confirm that the
    # installed transformers version serializes it into config.json, or
    # register the class for the Auto API explicitly instead.
    auto_map = {
        "AutoConfig": "configuration_custom.DenseBackwardOLMoEConfig",
        "AutoModelForCausalLM": "modeling_custom.DenseBackwardOLMoEForCausalLM",
    }

    def __init__(self, model_marker="DenseBackward_olmoe_marker", **kwargs):
        """Initialize the configuration.

        Args:
            model_marker: Free-form marker string stored on the config.
            **kwargs: Forwarded to ``OlmoeConfig.__init__``. When
                ``intermediate_size`` or ``torch_dtype`` are absent they
                default to ``1024`` and ``"bfloat16"`` respectively.
        """
        # Supply the defaults *before* the parent initializer runs instead of
        # assigning the attributes afterwards: assigning afterwards silently
        # discarded explicit caller values (and values loaded from a saved
        # config) and skipped the parent's own handling of these options.
        kwargs.setdefault("intermediate_size", 1024)
        kwargs.setdefault("torch_dtype", "bfloat16")
        super().__init__(**kwargs)
        self.model_marker = model_marker
# Smoke test: build the custom config and print it.
def main():
    """Build a sample DenseBackwardOLMoEConfig and print it to stdout."""
    sample_kwargs = {
        "model_marker": "DenseBackward_olmoe_marker",
        "torch_dtype": "bfloat16",
    }
    print(DenseBackwardOLMoEConfig(**sample_kwargs))


# Only run the smoke test when the file is executed as a script.
if __name__ == "__main__":
    main()