Sage / space_app.py
itriedcoding's picture
Upload folder using huggingface_hub
9f23f3f verified
Raw
History Blame Contribute Delete
3.55 kB
import gradio as gr
import torch
import torch.nn as nn
import math
import pickle
import json
from huggingface_hub import hf_hub_download
# Hugging Face Hub repository that hosts the Sage config, weights and tokenizer files.
REPO_ID: str = "itriedcoding/Sage"
# Custom model class matching Sage architecture
class TransformerLM(nn.Module):
    """Transformer language model over token ids.

    Token embeddings (scaled by sqrt(d_model)) are summed with learned absolute
    positional embeddings. The result is passed through a stack of
    ``nn.TransformerEncoderLayer`` modules and projected to vocabulary logits.

    The attribute names ``embedding``, ``pos_embedding``,
    ``transformer_encoder`` and ``output_layer`` determine the ``state_dict``
    keys. They must stay in sync with the checkpoint being loaded.
    """

    def __init__(self, vocab_size, d_model=256, nhead=8, num_layers=4, dim_feedforward=1024, max_seq_length=64):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        # Learned absolute positions: only indices 0..max_seq_length-1 exist.
        self.pos_embedding = nn.Embedding(max_seq_length, d_model)
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward, batch_first=True, dropout=0.1)
        # Stacks num_layers copies of encoder_layer.
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.output_layer = nn.Linear(d_model, vocab_size)
        self.max_seq_length = max_seq_length
        self.vocab_size = vocab_size

    def forward(self, src):
        """Compute vocabulary logits for every position of ``src``.

        Args:
            src: Integer tensor of token ids shaped (batch, seq_len). The
                encoder is built with ``batch_first=True``.

        Returns:
            Float tensor of shape (batch, seq_len, vocab_size).

        Raises:
            IndexError: if ``seq_len`` exceeds ``max_seq_length``. The check
                runs up front so that the positional-embedding lookup never
                indexes out of range. On CUDA that would otherwise surface as
                an unrecoverable device-side assert.
        """
        seq_len = src.size(1)
        if seq_len > self.max_seq_length:
            raise IndexError(
                f"sequence length {seq_len} exceeds max_seq_length="
                f"{self.max_seq_length}; feed at most the last "
                f"{self.max_seq_length} tokens"
            )
        pos = torch.arange(0, seq_len, device=src.device).unsqueeze(0)
        src_emb = self.embedding(src) * math.sqrt(self.embedding.embedding_dim)
        pos_emb = self.pos_embedding(pos)
        src_emb = src_emb + pos_emb
        # NOTE(review): no causal mask is passed, so each position also attends
        # to later positions. Confirm this matches how the checkpoint was trained.
        output = self.transformer_encoder(src_emb)
        logits = self.output_layer(output)
        return logits
# Download model files from Hugging Face
print("Downloading model files...")
config_path = hf_hub_download(repo_id=REPO_ID, filename="config.json")
state_path = hf_hub_download(repo_id=REPO_ID, filename="pytorch_model_state.bin")
tok_path = hf_hub_download(repo_id=REPO_ID, filename="tokenizer.pkl")

# Load config (hyperparameters under HF-style key names, read below).
with open(config_path, encoding="utf-8") as f:
    config = json.load(f)

# Load tokenizer.
# SECURITY NOTE: pickle.load can execute arbitrary code embedded in the file.
# This is only acceptable while REPO_ID is a trusted repository (presumably the
# Space author's own model repo). Never point this at an untrusted repo.
with open(tok_path, 'rb') as f:
    tokenizer = pickle.load(f)

# Load model
model = TransformerLM(
    vocab_size=config['vocab_size'],
    d_model=config['hidden_size'],
    nhead=config['num_attention_heads'],
    num_layers=config['num_hidden_layers'],
    dim_feedforward=config['intermediate_size'],
    max_seq_length=config['max_position_embeddings']
)
# weights_only=True restricts unpickling of the checkpoint to tensors and
# plain containers.
state_dict = torch.load(state_path, map_location='cpu', weights_only=True)
# strict=False tolerates key mismatches. Any parameter the checkpoint does not
# provide keeps its random initialisation, so report mismatches instead of
# silently serving an untrained model.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys:
    print(
        f"WARNING: checkpoint is missing {len(load_result.missing_keys)} "
        f"parameter(s), left randomly initialised: {load_result.missing_keys}"
    )
if load_result.unexpected_keys:
    print(
        f"WARNING: checkpoint has {len(load_result.unexpected_keys)} "
        f"unused key(s): {load_result.unexpected_keys}"
    )
model.eval()
def generate_text(prompt, max_length, temperature):
    """Sample a continuation of ``prompt`` from the module-level ``model``.

    Args:
        prompt: Text to continue.
        max_length: Maximum number of new tokens to sample. It is truncated with
            ``int()``, since the Gradio slider may deliver a float.
        temperature: Softmax temperature. Values <= 0 fall back to greedy
            (argmax) decoding instead of dividing by zero.

    Returns:
        The decoded prompt followed by the generated continuation. Generation
        stops early once the tokenizer's '.' token is sampled, if the
        vocabulary has one. An empty encoded prompt is returned unchanged.
    """
    # NOTE(review): with truncation=False the max_length=32 argument presumably
    # has no effect. Confirm against the pickled tokenizer's encode().
    input_ids = tokenizer.encode(prompt, max_length=32, padding=False, truncation=False, return_tensors='pt')
    generated = input_ids.clone()
    if generated.numel() == 0:
        # Nothing to condition on: logits[0, -1] would fail on an empty sequence.
        return prompt

    # Look up the stop token once. Without a '.' entry there is no stop token,
    # rather than treating id 0 as one.
    stop_id = tokenizer.char_to_idx.get('.')
    # The positional embedding only covers max_seq_length positions.
    context_len = model.max_seq_length

    with torch.no_grad():
        for _ in range(int(max_length)):
            # Sliding window: condition only on the most recent context_len tokens.
            logits = model(generated[:, -context_len:])
            next_logits = logits[0, -1, :]
            if temperature > 0:
                probs = torch.softmax(next_logits / temperature, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            else:
                next_token = torch.argmax(next_logits, dim=-1, keepdim=True)
            generated = torch.cat([generated, next_token.unsqueeze(0)], dim=1)
            if stop_id is not None and next_token.item() == stop_id:
                break
    return tokenizer.decode(generated[0])
# Gradio UI: a prompt box and two sliders feeding generate_text, in that order.
# Components are created in the same order as the inputs/outputs they fill.
_prompt_input = gr.Textbox(label="Prompt", value="Hello", placeholder="Enter your prompt here")
_length_input = gr.Slider(minimum=10, maximum=100, value=30, step=1, label="Max Length")
_temperature_input = gr.Slider(minimum=0.1, maximum=2.0, value=0.8, step=0.1, label="Temperature")
_generated_output = gr.Textbox(label="Generated Text")

# Each example row is [prompt, max length, temperature], matching the inputs.
_example_rows = [
    [seed_text, 30, 0.8]
    for seed_text in ("Hello", "The weather", "Deep learning")
]

demo = gr.Interface(
    fn=generate_text,
    inputs=[_prompt_input, _length_input, _temperature_input],
    outputs=_generated_output,
    title="Sage Text Generator",
    description="Custom character-level language model built from scratch with PyTorch.",
    examples=_example_rows,
)

# Start the web server only when run as a script, not on import.
if __name__ == "__main__":
    demo.launch()