"""Gradio demo for the Sage character-level Transformer language model.

On import this module downloads the model config, weights and tokenizer from
the Hugging Face Hub repository ``REPO_ID`` and rebuilds the model on CPU.
"""

import json
import math
import pickle

import gradio as gr
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download

REPO_ID = "itriedcoding/Sage"


# Custom model class matching Sage architecture
class TransformerLM(nn.Module):
    """Token + learned-position embeddings -> TransformerEncoder -> vocab logits.

    Args:
        vocab_size: Number of tokens in the tokenizer vocabulary.
        d_model: Embedding / hidden width.
        nhead: Number of attention heads per encoder layer.
        num_layers: Number of stacked encoder layers.
        dim_feedforward: Hidden width of each layer's feed-forward block.
        max_seq_length: Size of the learned position-embedding table; input
            sequences longer than this cannot be position-embedded.
    """

    def __init__(self, vocab_size, d_model=256, nhead=8, num_layers=4,
                 dim_feedforward=1024, max_seq_length=64):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_embedding = nn.Embedding(max_seq_length, d_model)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            batch_first=True,
            dropout=0.1,
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.output_layer = nn.Linear(d_model, vocab_size)
        self.max_seq_length = max_seq_length
        self.vocab_size = vocab_size

    def forward(self, src):
        """Return vocabulary logits for a batch of token ids.

        ``src`` is indexed as (batch, seq_len) (the encoder is batch_first);
        the result is (batch, seq_len, vocab_size). seq_len must not exceed
        ``self.max_seq_length`` or the position-embedding lookup fails.

        NOTE(review): no attention mask is passed to the encoder, so every
        position attends to the whole input, including later tokens. Confirm
        this matches how the checkpoint was trained.
        """
        seq_len = src.size(1)
        pos = torch.arange(0, seq_len, device=src.device).unsqueeze(0)
        # Scale token embeddings by sqrt(d_model) before adding positions.
        src_emb = self.embedding(src) * math.sqrt(self.embedding.embedding_dim)
        pos_emb = self.pos_embedding(pos)
        src_emb = src_emb + pos_emb
        output = self.transformer_encoder(src_emb)
        logits = self.output_layer(output)
        return logits


# Download model files from Hugging Face
print("Downloading model files...")
config_path = hf_hub_download(repo_id=REPO_ID, filename="config.json")
state_path = hf_hub_download(repo_id=REPO_ID, filename="pytorch_model_state.bin")
tok_path = hf_hub_download(repo_id=REPO_ID, filename="tokenizer.pkl")

# Load config
with open(config_path, encoding="utf-8") as f:
    config = json.load(f)

# Load tokenizer.
# SECURITY: pickle.load can execute arbitrary code embedded in the file. This
# is only acceptable because the file comes from the repository named by
# REPO_ID, which must be trusted; never point this at an untrusted repo.
with open(tok_path, 'rb') as f:
    tokenizer = pickle.load(f)

# Load model
model = TransformerLM(
    vocab_size=config['vocab_size'],
    d_model=config['hidden_size'],
    nhead=config['num_attention_heads'],
    num_layers=config['num_hidden_layers'],
    dim_feedforward=config['intermediate_size'],
    max_seq_length=config['max_position_embeddings'],
)
state_dict = torch.load(state_path, map_location='cpu', weights_only=True)
# strict=False tolerates key mismatches, but any mismatch means some
# parameters keep their random initialisation. Report them instead of
# failing silently so a broken checkpoint is visible in the logs.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys:
    print(f"Warning: checkpoint is missing {len(load_result.missing_keys)} "
          f"parameter(s), left randomly initialised: {load_result.missing_keys}")
if load_result.unexpected_keys:
    print(f"Warning: checkpoint has {len(load_result.unexpected_keys)} "
          f"unused key(s): {load_result.unexpected_keys}")
model.eval()  # inference mode: disables the encoder's dropout


def generate_text(prompt, max_length, temperature):
    """Sample a continuation of ``prompt`` from the Sage model.

    Args:
        prompt: Text to continue.
        max_length: Maximum number of new tokens to sample. It is cast to int
            because the Gradio slider may deliver a float.
        temperature: Softmax temperature applied to the next-token logits.
            Values below 1 sharpen the distribution and values above 1 flatten
            it. Must be > 0.

    Returns:
        The decoded prompt followed by the sampled continuation. Sampling stops
        early once the stop token (the '.' id) is drawn.

    Raises:
        gr.Error: If the prompt tokenizes to an empty sequence, because there
            is nothing to condition on.
    """
    input_ids = tokenizer.encode(prompt, max_length=32, padding=False, truncation=False, return_tensors='pt')
    generated = input_ids.clone()
    if generated.size(1) == 0:
        # logits[0, -1] below would be out of range for an empty sequence.
        raise gr.Error("Prompt produced no tokens; please enter some text.")

    # The learned position table only has model.max_seq_length entries, so
    # the model can never see more than that many tokens at once.
    context_len = model.max_seq_length
    # Id that ends generation. If '.' is not in the vocabulary this falls back
    # to id 0 (presumably a padding/unknown id; confirm against the tokenizer).
    stop_id = tokenizer.char_to_idx.get('.', 0)

    with torch.no_grad():
        for _ in range(int(max_length)):
            # Sliding window: condition only on the most recent tokens so long
            # prompts / generations do not overflow the position embeddings.
            logits = model(generated[:, -context_len:])
            next_logits = logits[0, -1, :] / temperature
            probs = torch.softmax(next_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            generated = torch.cat([generated, next_token.unsqueeze(0)], dim=1)
            if next_token.item() == stop_id:
                break
    return tokenizer.decode(generated[0])


demo = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(label="Prompt", value="Hello", placeholder="Enter your prompt here"),
        gr.Slider(minimum=10, maximum=100, value=30, step=1, label="Max Length"),
        gr.Slider(minimum=0.1, maximum=2.0, value=0.8, step=0.1, label="Temperature"),
    ],
    outputs=gr.Textbox(label="Generated Text"),
    title="Sage Text Generator",
    description="Custom character-level language model built from scratch with PyTorch.",
    examples=[
        ["Hello", 30, 0.8],
        ["The weather", 30, 0.8],
        ["Deep learning", 30, 0.8],
    ],
)

if __name__ == "__main__":
    demo.launch()