import torch
from pathlib import Path
import gradio as gr
import json
from huggingface_hub import hf_hub_download

# -------------------- DEVICE --------------------
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# -------------------- MODEL CONFIG --------------------
MODEL_NAME = "FlameF0X/i3-80m"
LOCAL_SAFETENSORS = Path("model.safetensors")
LOCAL_BIN = Path("pytorch_model.bin")
VOCAB_JSON = Path("chunk_vocab_combined.json")
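# NOTE: chunk_vocab_combined.json is expected to sit next to this script.
# Unlike the weights, it is never fetched from the Hub below, so the app
# fails fast right here if the file is missing.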
# -------------------- LOAD VOCAB --------------------
with open(VOCAB_JSON, 'r') as f:
    vocab_data = json.load(f)
VOCAB_SIZE = vocab_data["vocab_size"]

# -------------------- IMPORT YOUR MODEL CLASS --------------------
from app_classes import i3Model, ChunkTokenizer
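# i3Model and ChunkTokenizer are defined in app_classes.py, which presumably
# lives alongside app.py in the Space repo (they are not part of a published
# package).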
tokenizer = ChunkTokenizer()
tokenizer.load(VOCAB_JSON)

model = i3Model(
    vocab_size=VOCAB_SIZE,
    d_model=512,
    n_heads=16,
    max_seq_len=256,
    d_state=32
).to(DEVICE)
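# These hyperparameters (d_model=512, n_heads=16, max_seq_len=256, d_state=32)
# must match the values the checkpoint was trained with, or load_state_dict
# below will fail with shape mismatches.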
# -------------------- LOAD WEIGHTS --------------------
try:
    if LOCAL_SAFETENSORS.exists():
        from safetensors.torch import load_file
        state_dict = load_file(LOCAL_SAFETENSORS)
        model.load_state_dict(state_dict)
        print("✅ Loaded weights from local safetensors")
    elif LOCAL_BIN.exists():
        state_dict = torch.load(LOCAL_BIN, map_location=DEVICE, weights_only=False)
        model.load_state_dict(state_dict)
        print("✅ Loaded weights from local .bin")
    else:
        print("⚡ Downloading model from HuggingFace...")
        bin_file = hf_hub_download(repo_id=MODEL_NAME, filename="pytorch_model.bin")
        state_dict = torch.load(bin_file, map_location=DEVICE, weights_only=False)
        model.load_state_dict(state_dict)
        print("✅ Loaded weights from HuggingFace")
except Exception as e:
    raise RuntimeError(f"Failed to load model weights: {e}") from e

model.eval()
# -------------------- GENERATION FUNCTION --------------------
def generate_text(prompt, max_tokens=100, temperature=0.8, top_k=40):
    if not prompt.strip():
        yield "⚠️ Please enter a prompt to generate text."
        return
    try:
        idx = torch.tensor([tokenizer.encode(prompt)], dtype=torch.long).to(DEVICE)
        # Use the streaming method from the model
        for out_idx in model.generate_stream(idx, max_new_tokens=max_tokens,
                                             temperature=temperature, top_k=top_k):
            # Decode the current sequence; .cpu() moves the tokens off the GPU
            # for decoding (a no-op when the model is running on CPU)
            current_text = tokenizer.decode(out_idx[0].cpu())
            yield current_text
    except Exception as e:
        yield f"❌ Generation error: {e}"
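# For reference, a hypothetical sketch of what a streaming temperature/top-k
# sampler like model.generate_stream typically looks like. The real
# implementation lives in app_classes.py and may differ (e.g. a state-space
# model can carry recurrent state instead of re-running the full sequence).
@torch.no_grad()
def _reference_generate_stream(model, idx, max_new_tokens, temperature=1.0, top_k=None):
    for _ in range(max_new_tokens):
        idx_cond = idx[:, -256:]  # crop to the model's max_seq_len
        # Assumes the model returns logits of shape [batch, seq, vocab]
        logits = model(idx_cond)[:, -1, :] / temperature
        if top_k is not None:
            # Mask everything outside the top-k logits before sampling
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits[logits < v[:, [-1]]] = -float("inf")
        probs = torch.softmax(logits, dim=-1)
        next_id = torch.multinomial(probs, num_samples=1)
        idx = torch.cat([idx, next_id], dim=1)
        yield idx  # yield the growing sequence, mirroring generate_stream's contract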
# -------------------- GRADIO UI --------------------
custom_css = """
.gradio-container {
    max-width: 1200px !important;
}
.main-header {
    text-align: center;
    margin-bottom: 2rem;
}
.param-card {
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    padding: 1.5rem;
    border-radius: 12px;
    margin-bottom: 1rem;
}
"""

# We remove 'css', 'head', and 'theme' arguments from Blocks() and inject the
# CSS via gr.HTML instead, to ensure compatibility across older Gradio versions.
with gr.Blocks() as demo:
    gr.HTML(f"<style>{custom_css}</style>")
    # Header
    with gr.Row():
        gr.Markdown(
            """
            # 🚀 i3-80M Text Generation
            ### Powered by Mamba-based Architecture
            Generate creative text using the i3-80M language model with customizable parameters.
            """,
            elem_classes="main-header"
        )

    # Main Generation Area
    with gr.Row():
        with gr.Column(scale=2):
            prompt_input = gr.Textbox(
                label="✍️ Enter Your Prompt",
                placeholder="Once upon a time in a distant galaxy...",
                lines=4,
                max_lines=8
            )
            with gr.Accordion("⚙️ Generation Parameters", open=True):
                with gr.Row():
                    max_tokens_input = gr.Slider(
                        10, 500,
                        value=100,
                        step=10,
                        label="Max Tokens",
                        info="Maximum number of tokens to generate"
                    )
                    temp_input = gr.Slider(
                        0.1, 2.0,
                        value=0.8,
                        step=0.05,
                        label="Temperature",
                        info="Higher = more creative, Lower = more focused"
                    )
                    topk_input = gr.Slider(
                        1, 100,
                        value=40,
                        step=1,
                        label="Top-k Sampling",
                        info="Number of top tokens to consider"
                    )
            with gr.Row():
                generate_btn = gr.Button("🎨 Generate Text", variant="primary", size="lg")
                clear_btn = gr.ClearButton(components=[prompt_input], value="🗑️ Clear", size="lg")

        with gr.Column(scale=2):
            output_text = gr.Textbox(
                label="📝 Generated Output",
                lines=12,
                max_lines=20
                # show_copy_button removed for compatibility with older Gradio versions
            )
    # Examples Section
    with gr.Row():
        gr.Examples(
            examples=[
                ["The future of artificial intelligence is", 150, 0.7, 50],
                ["In a world where technology and nature coexist", 200, 0.9, 40],
                ["The scientist discovered something remarkable", 120, 0.8, 45],
            ],
            inputs=[prompt_input, max_tokens_input, temp_input, topk_input],
            label="💡 Try These Examples"
        )

    # Developer Panel
    with gr.Accordion("🔧 Developer Info", open=False):
        total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        with gr.Row():
            with gr.Column():
                gr.Markdown(f"""
                **Model Architecture:**
                - **Model:** i3-80M
                - **Device:** {DEVICE}
                - **Vocab Size:** {VOCAB_SIZE:,}
                - **Parameters:** {total_params:,} ({total_params/1e6:.2f}M)
                """)
            with gr.Column():
                gr.Markdown(f"""
                **Configuration:**
                - **d_model:** 512
                - **n_heads:** 16
                - **max_seq_len:** 256
                - **d_state:** 32
                """)
    # Footer
    gr.Markdown(
        """
        ---
        <div style="text-align: center; color: #666;">
            <p>Built with ❤️ using Gradio | Model: FlameF0X/i3-80m</p>
        </div>
        """,
    )

    # Connect UI
    generate_btn.click(
        generate_text,
        inputs=[prompt_input, max_tokens_input, temp_input, topk_input],
        outputs=[output_text]
    )

# -------------------- RUN --------------------
if __name__ == "__main__":
    # queue() is required for streaming (generator) outputs to work correctly
    demo.queue().launch(share=False)
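# To run locally, the dependencies inferred from the imports above are
# (assumed; the Space's requirements.txt is authoritative):
#   pip install torch gradio huggingface_hub safetensors
#   python app.py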