Spaces:
Sleeping
Sleeping
| """ | |
| Tiny-LLM Demo - Text Generation with a 54M Parameter Model | |
| This model was trained from scratch on Wikipedia data. | |
| """ | |
| import gradio as gr | |
| import torch | |
| from huggingface_hub import hf_hub_download | |
| from model import TinyLLM, MODEL_CONFIG | |
# Hugging Face Hub coordinates of the checkpoint served by this demo.
MODEL_ID = "jonmabe/tiny-llm-54m"
MODEL_FILENAME = "final_model.pt"

# Tokenizer bootstrap.
#
# Preference order:
#   1. the tokenizer published in the model repo (MODEL_ID),
#   2. the stock "gpt2" tokenizer,
#   3. no tokenizer at all (tokenizer = None, USE_HF_TOKENIZER = False).
#
# There is no hand-written fallback tokenizer: when both loads fail (or
# `transformers` cannot be imported) the app still starts, and callers are
# expected to check USE_HF_TOKENIZER / tokenizer before using it.
try:
    from transformers import AutoTokenizer

    try:
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    except Exception as repo_error:
        # `Exception` rather than a bare `except:` so that KeyboardInterrupt
        # and SystemExit still propagate, and the reason for the fallback is
        # visible in the logs instead of being silently discarded.
        print(
            f"Could not load tokenizer from {MODEL_ID} ({repo_error}); "
            "falling back to the GPT-2 tokenizer"
        )
        tokenizer = AutoTokenizer.from_pretrained("gpt2")
    USE_HF_TOKENIZER = True
except Exception as e:
    # Top-level boundary: covers a missing/broken `transformers` install as
    # well as a failed GPT-2 fallback load. Deliberately best-effort.
    print(f"Could not load HuggingFace tokenizer: {e}")
    USE_HF_TOKENIZER = False
    tokenizer = None
# --- Fetch the checkpoint file from the Hub --------------------------------
print("Downloading model...")
model_path = hf_hub_download(repo_id=MODEL_ID, filename=MODEL_FILENAME)
print(f"Model downloaded to {model_path}")

# --- Deserialize the checkpoint on CPU --------------------------------------
print("Loading model...")
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary Python
# objects from a file downloaded at startup. Keep MODEL_ID pointed at a
# trusted repo, or confirm the checkpoint loads with weights_only=True and
# switch to that.
checkpoint = torch.load(model_path, map_location="cpu", weights_only=False)

# --- Resolve the model configuration ----------------------------------------
# A dict stored under "config" wins over the packaged MODEL_CONFIG; if that
# dict has a "model" entry, that entry is used as the config.
config = checkpoint.get("config")
if isinstance(config, dict):
    config = config.get("model", config)
else:
    config = MODEL_CONFIG

# --- Build the network and restore its weights ------------------------------
model = TinyLLM(config)

# The file is either a training checkpoint (weights under "model_state_dict")
# or is itself treated as the state dict.
state_dict = checkpoint.get("model_state_dict", checkpoint)

# strict=False: key mismatches are reported below rather than raised.
missing, unexpected = model.load_state_dict(state_dict, strict=False)
if missing:
    print(f"Warning: Missing keys: {missing[:5]}...")
if unexpected:
    print(f"Warning: Unexpected keys: {unexpected[:5]}...")

# --- Place the model on the target device, in eval mode ---------------------
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model = model.to(device)
model.eval()

total_params = sum(param.numel() for param in model.parameters())
print(f"Model loaded on {device} with {total_params:,} parameters")
def generate_text(
    prompt: str,
    max_tokens: int = 100,
    temperature: float = 0.8,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.1,
) -> str:
    """Generate text from ``prompt`` using the module-level model and tokenizer.

    Args:
        prompt: Text to continue. ``None``, empty and whitespace-only prompts
            are rejected with a user-facing message.
        max_tokens: Forwarded to ``model.generate`` as ``max_new_tokens``.
        temperature: Forwarded to ``model.generate`` unchanged.
        top_p: Forwarded to ``model.generate`` unchanged.
        top_k: Forwarded to ``model.generate`` (coerced to ``int``).
        repetition_penalty: Forwarded to ``model.generate`` unchanged.

    Returns:
        The decoded first sequence returned by ``model.generate`` (special
        tokens skipped), or a plain-text message when the prompt is empty or
        no tokenizer is available.
    """
    # Guard: nothing to continue. `not prompt` also covers a None prompt,
    # which would otherwise raise AttributeError on .strip().
    if not prompt or not prompt.strip():
        return "Please enter a prompt to generate text."

    # Guard: without a tokenizer neither encoding nor decoding is possible,
    # so bail out once here instead of re-checking before decode.
    if not USE_HF_TOKENIZER or tokenizer is None:
        return "Tokenizer not available. Please ensure transformers is installed."

    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

    # Inference only: no autograd bookkeeping needed.
    with torch.no_grad():
        output_ids = model.generate(
            input_ids,
            # Defensive: UI number widgets may hand over floats (e.g. 100.0);
            # these two are counts and are passed on as ints.
            max_new_tokens=int(max_tokens),
            temperature=temperature,
            top_p=top_p,
            top_k=int(top_k),
            repetition_penalty=repetition_penalty,
            eos_token_id=tokenizer.eos_token_id,
        )

    # NOTE(review): presumably output_ids[0] contains the prompt tokens
    # followed by the continuation — confirm against TinyLLM.generate.
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
# Example prompts offered in the UI. Kept as a flat tuple for readability and
# expanded below into one single-element row per prompt.
_EXAMPLE_PROMPTS = (
    "The history of artificial intelligence began",
    "In the year 2050, humanity",
    "The most important scientific discovery was",
    "Once upon a time, in a kingdom far away",
    "The universe is vast and",
    "Climate change affects",
    "The theory of relativity states that",
    "In ancient Rome,",
)
EXAMPLES = [[example_prompt] for example_prompt in _EXAMPLE_PROMPTS]
# Static Markdown rendered above the controls.
_INTRO_MARKDOWN = """
# 🤖 Tiny-LLM Text Generator
A **54 million parameter** language model trained **from scratch** on Wikipedia.
This demonstrates that meaningful language models can be trained on consumer hardware!
### Architecture
- **Parameters**: 54.93M
- **Layers**: 12
- **Hidden Size**: 512
- **Attention Heads**: 8
- **Position Encoding**: RoPE
- **Normalization**: RMSNorm
- **Activation**: SwiGLU
"""

# Static Markdown rendered below the examples.
_ABOUT_MARKDOWN = """
---
### About This Model
**Model**: [jonmabe/tiny-llm-54m](https://huggingface.co/jonmabe/tiny-llm-54m)
This is a decoder-only transformer trained from scratch on Wikipedia text.
It demonstrates that meaningful language models can be trained on consumer hardware
with modest compute budgets (~3 hours on an RTX 5090).
#### Training Details
- **Training Steps**: 50,000
- **Tokens**: ~100M
- **Hardware**: NVIDIA RTX 5090 (32GB)
- **Training Time**: ~3 hours
#### Limitations
- Small model size limits knowledge and capabilities
- Trained only on Wikipedia - limited domain coverage
- May generate factually incorrect information
- Not instruction-tuned
#### Intended Use
- Educational: Understanding transformer training
- Experimental: Testing fine-tuning approaches
- Research: Lightweight model for NLP experiments
"""

# Build the Gradio interface: prompt and sampling controls on the left,
# generated text on the right, example prompts and model notes underneath.
#
# NOTE(review): the nesting of rows/columns below was reconstructed from a
# listing that had lost its indentation — confirm the layout (in particular
# where the repetition-penalty slider and the button sit) against the
# running app.
with gr.Blocks(title="Tiny-LLM Text Generator") as demo:
    gr.Markdown(_INTRO_MARKDOWN)

    with gr.Row():
        with gr.Column(scale=2):
            prompt_input = gr.Textbox(
                label="Prompt",
                placeholder="Enter your prompt here...",
                lines=3,
                value="The history of artificial intelligence began",
            )
            with gr.Row():
                with gr.Column():
                    # NOTE(review): with minimum=10 and step=10 the slider
                    # grid ends at 250, so maximum=256 is not selectable.
                    max_tokens = gr.Slider(
                        minimum=10,
                        maximum=256,
                        value=100,
                        step=10,
                        label="Max New Tokens",
                    )
                    temperature = gr.Slider(
                        minimum=0.1,
                        maximum=2.0,
                        value=0.8,
                        step=0.1,
                        label="Temperature",
                        info="Higher = more random",
                    )
                with gr.Column():
                    top_p = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        value=0.9,
                        step=0.05,
                        label="Top-p (Nucleus Sampling)",
                    )
                    # step=1 (was 5): with minimum=1 a step of 5 yields the
                    # grid 1, 6, 11, ..., 96, so neither the default of 50
                    # nor the maximum of 100 was actually selectable.
                    top_k = gr.Slider(
                        minimum=1,
                        maximum=100,
                        value=50,
                        step=1,
                        label="Top-k",
                    )
            repetition_penalty = gr.Slider(
                minimum=1.0,
                maximum=2.0,
                value=1.1,
                step=0.05,
                label="Repetition Penalty",
                info="Higher = less repetition",
            )
            generate_btn = gr.Button("✨ Generate", variant="primary", size="lg")

        with gr.Column(scale=2):
            output_text = gr.Textbox(
                label="Generated Text",
                lines=15,
                interactive=False,
            )

    gr.Markdown("### 📝 Example Prompts")
    gr.Examples(
        examples=EXAMPLES,
        inputs=prompt_input,
    )

    # Event handlers: clicking the button and pressing Enter in the prompt
    # box run the same generation call with the same inputs, so both are
    # registered from one definition to keep them from drifting apart.
    generation_inputs = [
        prompt_input,
        max_tokens,
        temperature,
        top_p,
        top_k,
        repetition_penalty,
    ]
    for register_event in (generate_btn.click, prompt_input.submit):
        register_event(
            fn=generate_text,
            inputs=generation_inputs,
            outputs=output_text,
        )

    gr.Markdown(_ABOUT_MARKDOWN)


if __name__ == "__main__":
    demo.launch()