lennart-finke commited on
Commit
6c72502
·
verified ·
1 Parent(s): bcc051c

Extended readme

Browse files
Files changed (1) hide show
  1. README.md +34 -3
README.md CHANGED
@@ -7,8 +7,39 @@ tags:
7
  - model_hub_mixin
8
  - pytorch_model_hub_mixin
9
  - simple-stories
 
 
10
  ---
 
11
 
12
- This model has been pushed to the Hub using the [PytorchModelHubMixin](https://huggingface.co/docs/huggingface_hub/package_reference/mixins#huggingface_hub.PyTorchModelHubMixin) integration:
13
- - Library: https://github.com/danbraunai/simple_stories_train
14
- - Docs: [More Information Needed]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  - model_hub_mixin
8
  - pytorch_model_hub_mixin
9
  - simple-stories
10
+ datasets:
11
+ - lennart-finke/SimpleStories
12
  ---
13
+ For loading this model from within [https://github.com/danbraunai/simple_stories_train](), you can run:
14
 
15
+ ```python
16
+ from typing import Any
17
+
18
+ import torch.nn as nn
19
+ from huggingface_hub import PyTorchModelHubMixin
20
+
21
+ from simple_stories_train.models.llama import Llama, LlamaConfig
22
+ from simple_stories_train.models.model_configs import MODEL_CONFIGS_DICT
23
+
24
+ class LlamaTransformer(
25
+ nn.Module,
26
+ PyTorchModelHubMixin,
27
+ repo_url="https://github.com/danbraunai/simple_stories_train",
28
+ language=["en"],
29
+ pipeline_tag="text-generation"
30
+ ):
31
+ def __init__(self, **config : Any):
32
+ super().__init__()
33
+ self.llama = Llama(LlamaConfig(**config))
34
+
35
+ def forward(self, x : torch.Tensor):
36
+ return self.llama(x)
37
+
38
+ config = MODEL_CONFIGS_DICT["d12"]
39
+ model = LlamaTransformer(**config)
40
+ HUB_REPO_NAME = "lennart-finke/SimpleStories-125M"
41
+
42
+ model = model.from_pretrained(HUB_REPO_NAME)
43
+ ```
44
+
45
+ - Library: https://github.com/danbraunai/simple_stories_train