Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import torch.nn as nn | |
| import math | |
| import pickle | |
| import json | |
| from huggingface_hub import hf_hub_download | |
| REPO_ID = "itriedcoding/Sage" | |
class CharacterTokenizer:
    """Character-level tokenizer mapping single characters to integer ids.

    Ids 0 and 1 are reserved for the ``<PAD>`` and ``<UNK>`` tokens; every
    other character seen by :meth:`fit` gets an id from 2 upwards, assigned
    in sorted character order.

    NOTE: an instance of this class is restored with ``pickle`` elsewhere in
    this file, so the class name and the attribute names below must stay
    stable.
    """

    def __init__(self):
        self.char_to_idx = {}
        self.idx_to_char = {}
        self.vocab_size = 0
        self.pad_token_id = 0
        self.unk_token_id = 1

    def fit(self, texts):
        """Build the vocabulary from an iterable of texts.

        Each item is converted with ``str()`` and split into characters.
        The vocabulary is rebuilt from scratch on every call, so re-fitting
        never leaves stale characters behind or produces colliding ids.
        """
        chars = set()
        for text in texts:
            chars.update(str(text))

        # Start from a fresh mapping: mutating the previous one would keep
        # old characters whose ids clash with the newly assigned ones.
        vocab = {'<PAD>': 0, '<UNK>': 1}
        for idx, char in enumerate(sorted(chars), start=2):
            vocab[char] = idx

        self.char_to_idx = vocab
        self.idx_to_char = {idx: char for char, idx in vocab.items()}
        self.vocab_size = len(vocab)

    def encode(self, text, max_length=None, padding=False, truncation=False, return_tensors=None):
        """Convert text into lists of token ids.

        Args:
            text: A single string, or an iterable of items (each converted
                with ``str()``) to encode as a batch.
            max_length: Target length used by ``padding`` and ``truncation``;
                both are no-ops when it is ``None`` or 0.
            padding: Right-pad each sequence with ``pad_token_id`` up to
                ``max_length``.
            truncation: Cut each sequence to at most ``max_length`` ids.
            return_tensors: ``'pt'`` to return a ``torch.long`` tensor
                instead of nested lists; the sequences must then all have
                the same length.

        Returns:
            A list with one id list per input text, or a 2-D tensor when
            ``return_tensors == 'pt'``. Characters missing from the
            vocabulary map to ``unk_token_id``.
        """
        if isinstance(text, str):
            text = [text]
        encoded = []
        for item in text:
            tokens = [self.char_to_idx.get(c, self.unk_token_id) for c in str(item)]
            if truncation and max_length:
                tokens = tokens[:max_length]
            if padding and max_length:
                # A negative count yields an empty list, so sequences that
                # are already long enough are left untouched.
                tokens = tokens + [self.pad_token_id] * (max_length - len(tokens))
            encoded.append(tokens)
        if return_tensors == 'pt':
            return torch.tensor(encoded, dtype=torch.long)
        return encoded

    def decode(self, token_ids):
        """Convert a flat sequence of token ids back into a string.

        Accepts a 1-D tensor or any iterable of ints. Ids that are not in
        the vocabulary are rendered as the literal text ``<UNK>``; the
        reserved ids render as ``<PAD>`` / ``<UNK>``.
        """
        if isinstance(token_ids, torch.Tensor):
            token_ids = token_ids.tolist()
        return ''.join(self.idx_to_char.get(idx, '<UNK>') for idx in token_ids)
| # Custom model class matching Sage architecture | |
class TransformerLM(nn.Module):
    """Transformer-encoder language model over integer token ids.

    Token embeddings, scaled by ``sqrt(d_model)``, are summed with learned
    position embeddings, run through a stack of
    ``nn.TransformerEncoderLayer`` blocks and projected to one logit per
    vocabulary entry.

    Args:
        vocab_size: Number of token ids (rows of the token embedding and
            width of the output logits).
        d_model: Embedding / hidden width.
        nhead: Attention heads per encoder layer.
        num_layers: Number of stacked encoder layers.
        dim_feedforward: Width of each layer's feed-forward sublayer.
        max_seq_length: Number of learned positions; inputs longer than
            this cannot be embedded.
    """

    def __init__(self, vocab_size, d_model=256, nhead=8, num_layers=4, dim_feedforward=1024, max_seq_length=64):
        super().__init__()
        self.max_seq_length = max_seq_length
        self.vocab_size = vocab_size

        # Submodule names double as state-dict key prefixes, and they are
        # created in a fixed order so seeded initialisation is reproducible.
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_embedding = nn.Embedding(max_seq_length, d_model)
        layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=0.1,
            batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(layer, num_layers=num_layers)
        self.output_layer = nn.Linear(d_model, vocab_size)

    def forward(self, src):
        """Return logits of shape ``(batch, seq_len, vocab_size)``.

        ``src`` holds token ids laid out as ``(batch, seq_len)``. No
        attention mask is passed to the encoder, so every position attends
        to the whole input sequence.
        """
        positions = torch.arange(src.shape[1], device=src.device)[None, :]
        scale = math.sqrt(self.embedding.embedding_dim)
        hidden = self.embedding(src) * scale + self.pos_embedding(positions)
        return self.output_layer(self.transformer_encoder(hidden))
# Download model files from Hugging Face
print("Downloading model files...")
config_path = hf_hub_download(repo_id=REPO_ID, filename="config.json")
state_path = hf_hub_download(repo_id=REPO_ID, filename="pytorch_model_state.bin")
tok_path = hf_hub_download(repo_id=REPO_ID, filename="tokenizer.pkl")

# Load config (explicit encoding so parsing does not depend on the host locale)
with open(config_path, encoding="utf-8") as f:
    config = json.load(f)

# Load tokenizer
# SECURITY: pickle.load can execute arbitrary code embedded in the file.
# This is only acceptable while REPO_ID points at a repository you trust;
# prefer shipping the vocabulary as JSON if that ever changes.
with open(tok_path, 'rb') as f:
    tokenizer = pickle.load(f)

# Build the model with the hyper-parameters stored alongside the weights
model = TransformerLM(
    vocab_size=config['vocab_size'],
    d_model=config['hidden_size'],
    nhead=config['num_attention_heads'],
    num_layers=config['num_hidden_layers'],
    dim_feedforward=config['intermediate_size'],
    max_seq_length=config['max_position_embeddings'],
)
state_dict = torch.load(state_path, map_location='cpu', weights_only=True)

# strict=False tolerates key mismatches, but a mismatch means some layers
# keep their random initialisation, so report it instead of hiding it.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys:
    print(f"Warning: weights missing from checkpoint: {load_result.missing_keys}")
if load_result.unexpected_keys:
    print(f"Warning: unexpected weights in checkpoint: {load_result.unexpected_keys}")

# Inference only: switch off dropout
model.eval()
def generate_text(prompt, max_length, temperature):
    """Sample a continuation of ``prompt`` from the module-level model.

    Args:
        prompt: Text to continue. An empty (or ``None``) prompt returns an
            empty string, because there is no position to sample from.
        max_length: Maximum number of new tokens to sample (cast to int).
        temperature: Softmax temperature; values <= 0 are clamped to a tiny
            positive number, which makes sampling effectively greedy.

    Returns:
        The decoded prompt followed by the sampled tokens. Sampling stops
        early once the id mapped to ``'.'`` is drawn (id 0 if the
        vocabulary has no ``'.'``).
    """
    if not prompt:
        return ""

    input_ids = tokenizer.encode(prompt, return_tensors='pt')
    temperature = max(float(temperature), 1e-6)

    # The model only has position embeddings for max_seq_length tokens, so
    # it is always fed the most recent window of the running sequence.
    context_size = model.max_seq_length
    stop_token_id = tokenizer.char_to_idx.get('.', 0)

    generated = input_ids
    with torch.no_grad():
        for _ in range(int(max_length)):
            logits = model(generated[:, -context_size:])
            next_logits = logits[0, -1, :] / temperature
            probs = torch.softmax(next_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            generated = torch.cat([generated, next_token.unsqueeze(0)], dim=1)
            if next_token.item() == stop_token_id:
                break
    return tokenizer.decode(generated[0])
# Gradio front end: one prompt box plus two sliders feeding generate_text.
# Every example row reuses the same default length and temperature.
_EXAMPLE_PROMPTS = ("Hello", "The weather", "Deep learning")
_DEFAULT_MAX_LENGTH = 30
_DEFAULT_TEMPERATURE = 0.8

demo = gr.Interface(
    fn=generate_text,
    title="Sage Text Generator",
    description="Custom character-level language model built from scratch with PyTorch.",
    inputs=[
        gr.Textbox(label="Prompt", value="Hello", placeholder="Enter your prompt here"),
        gr.Slider(label="Max Length", minimum=10, maximum=100, step=1, value=_DEFAULT_MAX_LENGTH),
        gr.Slider(label="Temperature", minimum=0.1, maximum=2.0, step=0.1, value=_DEFAULT_TEMPERATURE),
    ],
    outputs=gr.Textbox(label="Generated Text"),
    examples=[
        [example_prompt, _DEFAULT_MAX_LENGTH, _DEFAULT_TEMPERATURE]
        for example_prompt in _EXAMPLE_PROMPTS
    ],
)

# Start the web server only when run as a script, not on import.
if __name__ == "__main__":
    demo.launch()