Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import json | |
| from transformers import GPT2Tokenizer | |
| from safetensors.torch import load_file | |
| from transformers import GPT2Config as GPTConfig | |
| import torch.nn as nn | |
| from torch.nn import functional as F | |
| from dataclasses import dataclass | |
# Define the GPTConfig class with filtering
@dataclass
class GPTConfig:
    """Minimal hyper-parameter container for the ``GPT`` model.

    Only the four fields the model reads are stored; ``from_dict`` drops any
    other keys present in a ``config.json``.

    Attributes:
        n_embd: embedding / model width.
        n_head: number of attention heads.
        n_layer: number of decoder layers.
        vocab_size: size of the token vocabulary.
    """

    n_embd: int
    n_head: int
    n_layer: int
    vocab_size: int

    @classmethod
    def from_dict(cls, config_dict):
        """Build a config from a mapping, ignoring keys the model does not use.

        Args:
            config_dict: mapping loaded from a JSON config file.

        Returns:
            A new ``GPTConfig`` instance.

        Raises:
            TypeError: if one of the required keys is missing.
        """
        # Without @classmethod, GPTConfig.from_dict(d) would bind d to `cls`.
        expected_keys = {"n_embd", "n_head", "n_layer", "vocab_size"}
        filtered_dict = {key: value for key, value in config_dict.items() if key in expected_keys}
        return cls(**filtered_dict)
# Define the GPT class
class GPT(nn.Module):
    """Small decoder-only language model built on ``nn.TransformerDecoder``.

    Token embeddings are passed to the decoder as both target and memory. A
    causal mask is applied to both attention paths, so position ``i`` only
    sees tokens ``0..i``.

    NOTE(review): there is no positional embedding, so token order is only
    conveyed through the causal mask. Confirm this matches how the weights in
    ``model.safetensors`` were trained.
    """

    def __init__(self, config):
        """Create the layers.

        Args:
            config: object exposing ``vocab_size``, ``n_embd``, ``n_head`` and
                ``n_layer`` (e.g. ``GPTConfig``).
        """
        super().__init__()
        # Token id -> vector lookup.
        self.embedding = nn.Embedding(config.vocab_size, config.n_embd)
        # NOTE(review): dim_feedforward=n_embd is narrower than the usual
        # 4 * n_embd; kept as-is so parameter shapes stay unchanged.
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=config.n_embd,
            nhead=config.n_head,
            dim_feedforward=config.n_embd,
            dropout=0.1,
        )
        self.transformer = nn.TransformerDecoder(decoder_layer, num_layers=config.n_layer)
        # Projects hidden states back to vocabulary logits.
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size)

    @staticmethod
    def _causal_mask(seq_len, device, dtype):
        """Return a (seq_len, seq_len) additive mask with -inf above the diagonal."""
        return torch.triu(
            torch.full((seq_len, seq_len), float("-inf"), device=device, dtype=dtype),
            diagonal=1,
        )

    def forward(self, input_ids):
        """Compute next-token logits for every position.

        Args:
            input_ids: integer tensor of shape ``(batch, seq_len)``.

        Returns:
            Logits of shape ``(batch, seq_len, vocab_size)``.
        """
        # (batch, seq, embd) -> (seq, batch, embd): the decoder layer uses the
        # default batch_first=False layout.
        hidden = self.embedding(input_ids).transpose(0, 1)
        mask = self._causal_mask(hidden.size(0), hidden.device, hidden.dtype)
        # Mask self-attention AND cross-attention over the (identical) memory;
        # otherwise every position could attend to later tokens.
        hidden = self.transformer(hidden, hidden, tgt_mask=mask, memory_mask=mask)
        return self.lm_head(hidden.transpose(0, 1))

    @torch.no_grad()
    def generate(self, input_ids, max_new_tokens, temperature, top_k):
        """Extend ``input_ids`` by sampling with top-k / temperature.

        Args:
            input_ids: integer tensor of shape ``(batch, seq_len)`` with
                ``seq_len >= 1``.
            max_new_tokens: number of tokens to append. Floats (e.g. Gradio
                slider values) are truncated to ``int``.
            temperature: softmax temperature; ``<= 0`` selects greedy decoding.
            top_k: sample only among the ``top_k`` highest-scoring tokens.
                Clamped to ``[1, vocab_size]``.

        Returns:
            Tensor of shape ``(batch, seq_len + max_new_tokens)``: the input
            followed by the generated tokens.

        Raises:
            ValueError: if ``input_ids`` contains no tokens to condition on.
        """
        if input_ids.size(-1) == 0:
            raise ValueError("generate() needs at least one input token")
        output_ids = input_ids
        for _ in range(int(max_new_tokens)):
            # Condition on the full sequence so far; keep only the logits of
            # the last position -> (batch, vocab_size).
            logits = self.forward(output_ids)[:, -1, :]
            k = max(1, min(int(top_k), logits.size(-1)))
            top_logits, top_indices = torch.topk(logits, k, dim=-1)
            if temperature <= 0:
                # topk returns values sorted descending, so column 0 is the argmax.
                choice = torch.zeros_like(top_indices[:, :1])
            else:
                # Softmax over the top-k logits equals renormalising the
                # full-vocabulary softmax restricted to those k tokens.
                probs = F.softmax(top_logits / temperature, dim=-1)
                choice = torch.multinomial(probs, num_samples=1)
            # Map positions within the top-k slice back to vocabulary ids.
            next_token = top_indices.gather(-1, choice)
            output_ids = torch.cat([output_ids, next_token], dim=1)
        return output_ids
# Module-level state shared by load_model() and generate_text();
# both start empty and are filled in on the first load.
model, tokenizer = None, None
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def load_model():
    """Load the Leap0 model and tokenizer into the module-level globals.

    Reads ``config.json`` and ``model.safetensors`` from the working
    directory and the ``gpt2`` tokenizer via ``GPT2Tokenizer.from_pretrained``.
    The globals ``model`` and ``tokenizer`` are assigned only after every step
    has succeeded, so a failure never publishes a half-initialised model.

    Raises:
        Exception: whatever the underlying loading step raised, re-raised
            after the error is printed.
    """
    global model, tokenizer
    # Paths to config and model files
    config_path = "config.json"
    model_path = "model.safetensors"
    try:
        print(f"Loading configuration from {config_path}...")
        with open(config_path, "r", encoding="utf-8") as f:
            config_dict = json.load(f)
        print("Configuration loaded. Creating model config...")
        config = GPTConfig.from_dict(config_dict)
        print(f"Model config created: {config}")

        print(f"Loading model weights from {model_path}...")
        tensors = load_file(model_path)
        print("Instantiating model...")
        new_model = GPT(config)
        print("Loading weights into model...")
        # strict=False tolerates checkpoint/model key mismatches. Report them,
        # because any unmatched parameter silently keeps its random init.
        result = new_model.load_state_dict(tensors, strict=False)
        if result.missing_keys:
            print(
                f"Warning: {len(result.missing_keys)} model parameters were not found "
                f"in the checkpoint (first few: {result.missing_keys[:10]})"
            )
        if result.unexpected_keys:
            print(
                f"Warning: {len(result.unexpected_keys)} checkpoint tensors were not used "
                f"by the model (first few: {result.unexpected_keys[:10]})"
            )
        new_model.to(device)
        new_model.eval()

        print("Loading tokenizer...")
        new_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    except Exception as e:
        print(f"Error loading model: {str(e)}")
        raise
    # Publish both globals together, only on success.
    model, tokenizer = new_model, new_tokenizer
    print("Model and tokenizer loaded successfully")
def generate_text(prompt, max_length=50, temperature=0.7, top_k=40):
    """Generate a continuation of *prompt* and return it as text.

    Loads the model on first use if ``load_model`` has not run yet.

    Args:
        prompt: text to continue. An empty (or ``None``) prompt encodes to zero
            tokens, so it is seeded with the tokenizer's end-of-text token to
            give the model something to condition on.
        max_length: number of new tokens to generate. Gradio sliders may pass
            a float, so it is converted to ``int``.
        temperature: sampling temperature.
        top_k: number of highest-scoring tokens to sample from (cast to ``int``).

    Returns:
        The prompt followed by the generated continuation, decoded with
        special tokens removed.
    """
    if model is None or tokenizer is None:
        load_model()
    # Tokenize the input text
    input_ids = tokenizer.encode(prompt or "", return_tensors="pt").to(device)
    if input_ids.numel() == 0:
        input_ids = torch.tensor([[tokenizer.eos_token_id]], dtype=torch.long, device=device)
    with torch.no_grad():
        output_ids = model.generate(
            input_ids,
            max_new_tokens=int(max_length),
            temperature=float(temperature),
            top_k=int(top_k),
        )
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
# Create the Gradio interface
def create_interface():
    """Build and return the Gradio Blocks UI for the model.

    The left column holds the prompt box, a row of three sampling sliders
    (max length, temperature, top-k) and a button. The right column holds the
    output box. Clicking the button sends the prompt and slider values to
    ``generate_text`` and shows its result in the output box.
    """
    # Keyword arguments for each slider, in the order generate_text expects them.
    slider_specs = (
        dict(minimum=1, maximum=200, value=50, step=1, label="Max Length"),
        dict(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature"),
        dict(minimum=1, maximum=100, value=40, step=1, label="Top K"),
    )

    with gr.Blocks(css="footer {visibility: hidden}") as demo:
        gr.Markdown("# Leap0 Language Model")
        gr.Markdown("A GPT-2 based model trained on the Tiny Stories dataset")

        with gr.Row():
            with gr.Column():
                prompt_box = gr.Textbox(
                    label="Enter your prompt",
                    placeholder="once upon a time in the village of",
                    lines=3,
                )
                with gr.Row():
                    sliders = [gr.Slider(**spec) for spec in slider_specs]
                run_button = gr.Button("Generate Text")
            with gr.Column():
                output_box = gr.Textbox(
                    label="Generated Output",
                    lines=10,
                    placeholder="Your generated text will appear here...",
                )

        run_button.click(
            fn=generate_text,
            inputs=[prompt_box, *sliders],
            outputs=output_box,
        )
    return demo
# Warm up the global model and tokenizer at start-up so the first request
# does not pay the loading cost, then build the UI at module level.
load_model()
demo = create_interface()

if __name__ == "__main__":
    # Start the web server only when executed as a script.
    demo.launch()