# hello-world / index.py
# Uploaded by vinothkannans — commit d990db2 ("Initial commit", verified).
from transformers import PreTrainedModel, PretrainedConfig
import torch.nn as nn
import torch
class HelloWorldConfig(PretrainedConfig):
    """Configuration for the toy ``HelloWorldModel``.

    Declares no hyperparameters of its own; apart from the ``model_type``
    identifier below, everything is inherited from ``PretrainedConfig``.
    """

    # String identifier for this config family.
    model_type = "hello-world"
class HelloWorldModel(PreTrainedModel):
    """Toy model whose forward pass always returns the text "Hello World".

    All inputs are ignored. The single linear layer is a demo parameter
    and is never used by ``forward``.
    """

    config_class = HelloWorldConfig

    def __init__(self, config):
        """Build the model from a ``HelloWorldConfig`` instance."""
        super().__init__(config)
        # A single 1 -> 1 linear layer just for demo; forward() never uses it.
        # NOTE(review): presumably kept so the model has weights to
        # serialize with save_pretrained — confirm before removing.
        self.layer = nn.Linear(1, 1)

    def forward(self, input_ids=None, **kwargs):
        """Return a fixed greeting regardless of the inputs.

        Args:
            input_ids: Accepted for API compatibility; ignored.
            **kwargs: Any other model inputs (e.g. ``attention_mask`` or
                ``token_type_ids`` from a tokenizer). Accepted so that
                ``model(**encoded_inputs)`` does not raise ``TypeError``;
                ignored.

        Returns:
            dict: Always ``{"text": "Hello World"}``.
        """
        # Always returns "Hello World"
        return {"text": "Hello World"}
def main(save_directory="./"):
    """Create a default ``HelloWorldModel`` and save it to *save_directory*.

    Args:
        save_directory: Folder to write the model files into. Defaults to
            the current working directory, matching the original script.

    ``model.save_pretrained`` also writes the model's config. That config
    is the same ``HelloWorldConfig`` instance created here, so a separate
    ``config.save_pretrained`` call would only write the same file again.
    """
    # Create the model + config
    config = HelloWorldConfig()
    model = HelloWorldModel(config)
    # Save weights and config to the target folder
    model.save_pretrained(save_directory)


# Only write files when run as a script. Importing this module (for example,
# to reuse the model classes) must not write into the working directory.
if __name__ == "__main__":
    main()