File size: 3,547 Bytes
9f23f3f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import gradio as gr
import torch
import torch.nn as nn
import math
import pickle
import json
from huggingface_hub import hf_hub_download

# Hugging Face Hub repository from which config.json, the model state dict
# and tokenizer.pkl are downloaded at startup.
REPO_ID: str = "itriedcoding/Sage"

# Custom model class matching Sage architecture
class TransformerLM(nn.Module):
    def __init__(self, vocab_size, d_model=256, nhead=8, num_layers=4, dim_feedforward=1024, max_seq_length=64):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_embedding = nn.Embedding(max_seq_length, d_model)
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward, batch_first=True, dropout=0.1)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.output_layer = nn.Linear(d_model, vocab_size)
        self.max_seq_length = max_seq_length
        self.vocab_size = vocab_size

    def forward(self, src):
        seq_len = src.size(1)
        pos = torch.arange(0, seq_len, device=src.device).unsqueeze(0)
        src_emb = self.embedding(src) * math.sqrt(self.embedding.embedding_dim)
        pos_emb = self.pos_embedding(pos)
        src_emb = src_emb + pos_emb
        output = self.transformer_encoder(src_emb)
        logits = self.output_layer(output)
        return logits

# Download model files from Hugging Face
print("Downloading model files...")
config_path = hf_hub_download(repo_id=REPO_ID, filename="config.json")
state_path = hf_hub_download(repo_id=REPO_ID, filename="pytorch_model_state.bin")
tok_path = hf_hub_download(repo_id=REPO_ID, filename="tokenizer.pkl")

# Load config
with open(config_path) as f:
    config = json.load(f)

# Load tokenizer
with open(tok_path, 'rb') as f:
    tokenizer = pickle.load(f)

# Load model
model = TransformerLM(
    vocab_size=config['vocab_size'],
    d_model=config['hidden_size'],
    nhead=config['num_attention_heads'],
    num_layers=config['num_hidden_layers'],
    dim_feedforward=config['intermediate_size'],
    max_seq_length=config['max_position_embeddings']
)
state_dict = torch.load(state_path, map_location='cpu', weights_only=True)
model.load_state_dict(state_dict, strict=False)
model.eval()

def generate_text(prompt, max_length, temperature):
    input_ids = tokenizer.encode(prompt, max_length=32, padding=False, truncation=False, return_tensors='pt')
    generated = input_ids.clone()
    
    with torch.no_grad():
        for _ in range(int(max_length)):
            logits = model(generated)
            next_logits = logits[0, -1, :] / temperature
            probs = torch.softmax(next_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            generated = torch.cat([generated, next_token.unsqueeze(0)], dim=1)
            if next_token.item() == tokenizer.char_to_idx.get('.', 0):
                break
    
    return tokenizer.decode(generated[0])

def _build_demo():
    """Assemble the Gradio UI around ``generate_text``.

    The input components map positionally onto
    ``generate_text(prompt, max_length, temperature)``.
    """
    inputs = [
        gr.Textbox(label="Prompt", value="Hello", placeholder="Enter your prompt here"),
        gr.Slider(minimum=10, maximum=100, value=30, step=1, label="Max Length"),
        gr.Slider(minimum=0.1, maximum=2.0, value=0.8, step=0.1, label="Temperature"),
    ]
    # Every example row uses the same (prompt, max_length, temperature) layout
    # and shares the default length/temperature.
    example_prompts = ("Hello", "The weather", "Deep learning")
    examples = [[text, 30, 0.8] for text in example_prompts]
    return gr.Interface(
        fn=generate_text,
        inputs=inputs,
        outputs=gr.Textbox(label="Generated Text"),
        title="Sage Text Generator",
        description="Custom character-level language model built from scratch with PyTorch.",
        examples=examples,
    )


demo = _build_demo()

if __name__ == "__main__":
    demo.launch()