---
language:
- en
pipeline_tag: text-generation
tags:
- distillation
- model_hub_mixin
- pytorch_model_hub_mixin
- simple-stories
datasets:
- lennart-finke/SimpleStories
---
To load this model from within [simple_stories_train](https://github.com/danbraunai/simple_stories_train), you can run:

```python
from typing import Any

import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin

from simple_stories_train.models.llama import Llama, LlamaConfig
from simple_stories_train.models.model_configs import MODEL_CONFIGS_DICT

class LlamaTransformer(
    nn.Module,
    PyTorchModelHubMixin,
    repo_url="https://github.com/danbraunai/simple_stories_train",
    language=["en"],
    pipeline_tag="text-generation",
):
    def __init__(self, **config: Any):
        super().__init__()
        self.llama = Llama(LlamaConfig(**config))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.llama(x)

config = MODEL_CONFIGS_DICT["d12"]
HUB_REPO_NAME = "lennart-finke/SimpleStories-125M"

# from_pretrained is a classmethod, so no separate instantiation is needed:
# it returns a fresh LlamaTransformer with the weights from the Hub.
model = LlamaTransformer.from_pretrained(HUB_REPO_NAME)
```

- Library: https://github.com/danbraunai/simple_stories_train