Charlie81 committed on
Commit
78b85e8
·
1 Parent(s): c4785c5

fix import

Browse files
Files changed (2) hide show
  1. myolmoe/__init__.py +13 -1
  2. scripts/train.py +8 -2
myolmoe/__init__.py CHANGED
@@ -1 +1,13 @@
1
- from .modeling_myolmoe import MyOlmoeForCausalLM
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .modeling_myolmoe import (
2
+ MyOlmoeForCausalLM,
3
+ OlmoeModel,
4
+ OlmoePreTrainedModel,
5
+ OlmoeConfig # Add this line
6
+ )
7
+
8
+ __all__ = [
9
+ "MyOlmoeForCausalLM",
10
+ "OlmoeModel",
11
+ "OlmoePreTrainedModel",
12
+ "OlmoeConfig" # Add this line
13
+ ]
scripts/train.py CHANGED
@@ -12,8 +12,14 @@ from myolmoe import MyOlmoeForCausalLM, OlmoeConfig
12
  import os
13
 
14
  def main():
15
- # Load config and model
16
- config = OlmoeConfig.from_pretrained("myolmoe/config.json")
 
 
 
 
 
 
17
  model = MyOlmoeForCausalLM.from_pretrained(
18
  "myolmoe",
19
  config=config,
 
12
  import os
13
 
14
  def main():
15
+ # Load config - first try from local file, then from pretrained
16
+ config_path = os.path.join("myolmoe", "config.json")
17
+ if os.path.exists(config_path):
18
+ config = OlmoeConfig.from_json_file(config_path)
19
+ else:
20
+ config = OlmoeConfig.from_pretrained("myolmoe")
21
+
22
+ # Load model
23
  model = MyOlmoeForCausalLM.from_pretrained(
24
  "myolmoe",
25
  config=config,