# sage-space / app.py
# Uploaded by itriedcoding via huggingface_hub ("Upload app.py with huggingface_hub"),
# commit f8e9a1d (verified), 5.03 kB.
import gradio as gr
import torch
import torch.nn as nn
import math
import pickle
import json
from huggingface_hub import hf_hub_download
# Hugging Face Hub repository from which the model files below are downloaded
# (config.json, pytorch_model_state.bin, tokenizer.pkl).
REPO_ID = "itriedcoding/Sage"
class CharacterTokenizer:
    """Character-level tokenizer mapping single characters to integer ids.

    Ids 0 and 1 are reserved for the ``<PAD>`` and ``<UNK>`` special tokens.
    Every character seen by :meth:`fit` gets an id starting at 2, assigned in
    sorted character order.

    NOTE(review): the app loads its tokenizer with ``pickle.load`` (presumably
    an instance of this class). Unpickling restores the instance attributes
    without running ``__init__``, so methods should rely only on attributes a
    previously pickled instance already carries.
    """

    def __init__(self):
        self.char_to_idx = {}
        self.idx_to_char = {}
        self.vocab_size = 0
        self.pad_token_id = 0
        self.unk_token_id = 1

    def fit(self, texts):
        """Build the vocabulary from an iterable of texts.

        Any vocabulary from a previous ``fit`` call is discarded, so the
        resulting mapping depends only on ``texts``. Non-string items are
        converted with ``str()``.
        """
        chars = set()
        for text in texts:
            chars.update(str(text))
        # Rebuild from scratch. Updating the previous dict in place would
        # keep stale characters whose ids collide with the newly assigned ones.
        self.char_to_idx = {'<PAD>': 0, '<UNK>': 1}
        for i, char in enumerate(sorted(chars)):
            self.char_to_idx[char] = i + 2
        self.idx_to_char = {v: k for k, v in self.char_to_idx.items()}
        self.vocab_size = len(self.char_to_idx)

    def encode(self, text, max_length=None, padding=False, truncation=False, return_tensors=None):
        """Convert a string (or iterable of strings) into lists of token ids.

        Args:
            text: a single string, or an iterable of texts. Each item is
                converted with ``str()``.
            max_length: target length for ``truncation`` / ``padding``.
                Ignored when falsy.
            truncation: cut each sequence to at most ``max_length`` ids.
            padding: right-pad each sequence with ``pad_token_id`` up to
                ``max_length``. Longer sequences are left unchanged.
            return_tensors: ``'pt'`` returns a ``torch.long`` tensor with one
                row per text (all rows must then have equal length). Any other
                value returns a list of lists.

        Characters missing from the vocabulary map to ``unk_token_id``.
        """
        if isinstance(text, str):
            text = [text]
        encoded = []
        for t in text:
            tokens = [self.char_to_idx.get(c, self.unk_token_id) for c in str(t)]
            if truncation and max_length:
                tokens = tokens[:max_length]
            if padding and max_length:
                tokens = tokens + [self.pad_token_id] * (max_length - len(tokens))
            encoded.append(tokens)
        if return_tensors == 'pt':
            return torch.tensor(encoded, dtype=torch.long)
        return encoded

    def decode(self, token_ids):
        """Convert a flat sequence of ids (list or 1-D tensor) back into a string.

        Ids not in the vocabulary decode to ``'<UNK>'``. Special tokens are
        emitted literally; for example, id 0 becomes ``'<PAD>'``.
        """
        if isinstance(token_ids, torch.Tensor):
            token_ids = token_ids.tolist()
        return ''.join(self.idx_to_char.get(idx, '<UNK>') for idx in token_ids)
# Custom model class matching Sage architecture
class TransformerLM(nn.Module):
    """Character-level language model.

    The model applies token embeddings plus learned positional embeddings,
    then a stack of ``nn.TransformerEncoder`` layers, then a linear head that
    produces vocabulary logits at every position.

    NOTE(review): ``forward`` passes no attention mask, so every position
    attends to the whole input (no causal masking). This has to match how
    the checkpoint was trained. Confirm against the training code before
    changing it.
    """

    def __init__(self, vocab_size, d_model=256, nhead=8, num_layers=4, dim_feedforward=1024, max_seq_length=64):
        super().__init__()
        # Submodules are created in this exact order and with these attribute
        # names. The checkpoint's state_dict keys depend on the names.
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_embedding = nn.Embedding(max_seq_length, d_model)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            batch_first=True,
            dropout=0.1,
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.output_layer = nn.Linear(d_model, vocab_size)
        self.max_seq_length = max_seq_length
        self.vocab_size = vocab_size

    def forward(self, src):
        """Return per-position logits of shape ``(batch, seq_len, vocab_size)``.

        Args:
            src: integer token ids of shape ``(batch, seq_len)`` (the encoder
                is built with ``batch_first=True``).

        Raises:
            ValueError: if ``seq_len`` exceeds ``max_seq_length``. The learned
                positional table has no rows beyond that length.
        """
        seq_len = src.size(1)
        if seq_len > self.max_seq_length:
            raise ValueError(
                f"input length {seq_len} exceeds max_seq_length {self.max_seq_length}"
            )
        pos = torch.arange(0, seq_len, device=src.device).unsqueeze(0)
        # Token embeddings are scaled by sqrt(d_model) before adding positions.
        src_emb = self.embedding(src) * math.sqrt(self.embedding.embedding_dim)
        src_emb = src_emb + self.pos_embedding(pos)
        output = self.transformer_encoder(src_emb)
        return self.output_layer(output)
# Download model files from Hugging Face
print("Downloading model files...")
config_path = hf_hub_download(repo_id=REPO_ID, filename="config.json")
state_path = hf_hub_download(repo_id=REPO_ID, filename="pytorch_model_state.bin")
tok_path = hf_hub_download(repo_id=REPO_ID, filename="tokenizer.pkl")
# Load config (JSON is UTF-8 by specification; don't depend on the locale).
with open(config_path, encoding="utf-8") as f:
    config = json.load(f)
# Load tokenizer.
# SECURITY NOTE(review): pickle.load can execute arbitrary code embedded in
# the file, so this fully trusts whoever can push to REPO_ID. Consider shipping
# the vocabulary as JSON instead.
with open(tok_path, 'rb') as f:
    tokenizer = pickle.load(f)
# Load model: the constructor arguments are mapped from the checkpoint's config.
model = TransformerLM(
    vocab_size=config['vocab_size'],
    d_model=config['hidden_size'],
    nhead=config['num_attention_heads'],
    num_layers=config['num_hidden_layers'],
    dim_feedforward=config['intermediate_size'],
    max_seq_length=config['max_position_embeddings']
)
state_dict = torch.load(state_path, map_location='cpu', weights_only=True)
# strict=False tolerates key-name mismatches. Silently ignoring them would
# leave the affected parameters randomly initialised, so report them.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys:
    print(f"Warning: checkpoint is missing {len(load_result.missing_keys)} parameter(s): {load_result.missing_keys}")
if load_result.unexpected_keys:
    print(f"Warning: checkpoint has {len(load_result.unexpected_keys)} unused key(s): {load_result.unexpected_keys}")
model.eval()
def generate_text(prompt, max_length, temperature):
    """Sample a continuation of ``prompt`` one character at a time.

    Uses the module-level ``tokenizer`` and ``model``.

    Args:
        prompt: text to continue. Characters outside the tokenizer's
            vocabulary are encoded as ``<UNK>``.
        max_length: maximum number of new tokens to sample. It is coerced
            with ``int()`` because the UI slider value may arrive as a number
            of either type.
        temperature: softmax temperature. Values <= 0 fall back to greedy
            (argmax) decoding instead of dividing by zero.

    Returns:
        The decoded prompt followed by the sampled characters. Sampling stops
        early right after emitting the id of ``'.'``, or id 0 if ``'.'`` is
        not in the vocabulary. An empty prompt is returned unchanged, since
        there is no position to predict from.
    """
    input_ids = tokenizer.encode(prompt, return_tensors='pt')
    if input_ids.size(1) == 0:
        return prompt
    stop_id = tokenizer.char_to_idx.get('.', 0)
    # The learned positional table only covers model.max_seq_length positions.
    # Condition on a sliding window of the most recent tokens so long
    # prompts or long generations don't index past it.
    context_len = model.max_seq_length
    generated = input_ids.clone()
    with torch.no_grad():
        for _ in range(int(max_length)):
            logits = model(generated[:, -context_len:])
            next_logits = logits[0, -1, :]
            if temperature > 0:
                probs = torch.softmax(next_logits / temperature, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            else:
                # Greedy decoding; shape (1,) to match multinomial's output.
                next_token = torch.argmax(next_logits).unsqueeze(0)
            generated = torch.cat([generated, next_token.unsqueeze(0)], dim=1)
            if next_token.item() == stop_id:
                break
    return tokenizer.decode(generated[0])
# Gradio UI: prompt, new-token budget and sampling temperature in;
# generated text out.
demo = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(placeholder="Enter your prompt here", value="Hello", label="Prompt"),
        gr.Slider(label="Max Length", minimum=10, maximum=100, step=1, value=30),
        gr.Slider(label="Temperature", minimum=0.1, maximum=2.0, step=0.1, value=0.8),
    ],
    outputs=gr.Textbox(label="Generated Text"),
    title="Sage Text Generator",
    description="Custom character-level language model built from scratch with PyTorch.",
    # Each example pairs a prompt with the sliders' default length and temperature.
    examples=[[text, 30, 0.8] for text in ("Hello", "The weather", "Deep learning")],
)

if __name__ == "__main__":
    demo.launch()