File size: 895 Bytes
cce6a45
 
 
 
 
 
 
 
 
 
3dd836d
cce6a45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3dd836d
 
cce6a45
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin


class MyModel(
    nn.Module,
    PyTorchModelHubMixin,
    # Optional class keywords: metadata which gets pushed to the model card.
    repo_url="https://huggingface.co/Robzy/random-genre",
    pipeline_tag="audio-classification",
    license="mit",
):
    """Small module: a learnable additive offset followed by a linear layer.

    The offset ``param`` has shape ``(num_channels, hidden_size)`` and is
    added to the input before it is handed to ``linear``, which maps
    ``hidden_size`` features to ``num_classes`` outputs.
    """

    def __init__(self, num_channels: int, hidden_size: int, num_classes: int):
        """Create the offset parameter and the output projection.

        Args:
            num_channels: First dimension of the learnable offset.
            hidden_size: Second dimension of the offset and the input
                feature size of the linear layer.
            num_classes: Output feature size of the linear layer.
        """
        super().__init__()
        # The offset is drawn before the linear layer is built so the two
        # random initialisations happen in a fixed order.
        initial_offset = torch.rand(num_channels, hidden_size)
        self.param = nn.Parameter(initial_offset)
        self.linear = nn.Linear(in_features=hidden_size, out_features=num_classes)

    def forward(self, x):
        """Add the learned offset to ``x`` and apply the linear layer.

        NOTE(review): ``x`` presumably has to broadcast against
        ``(num_channels, hidden_size)`` -- confirm against callers.
        """
        shifted = x + self.param
        return self.linear(shifted)

# Constructor arguments for the model, kept in one mapping so they can be
# unpacked straight into MyModel.
config = dict(num_channels=3, hidden_size=32, num_classes=10)
model = MyModel(**config)

# Local-save alternative (currently disabled):
#   model.save_pretrained("trial-model")
# Upload the freshly initialised model to the Hub repository.
model.push_to_hub("Robzy/random-genre")