import json
from pathlib import Path

import torch
import gradio as gr
from huggingface_hub import hf_hub_download

# -------------------- DEVICE --------------------
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# -------------------- MODEL CONFIG --------------------
MODEL_NAME = "FlameF0X/i3-80m"  # HuggingFace repo name
LOCAL_SAFETENSORS = Path("model.safetensors")
LOCAL_BIN = Path("pytorch_model.bin")
VOCAB_JSON = Path("chunk_vocab_combined.json")

# -------------------- LOAD VOCAB --------------------
with open(VOCAB_JSON, "r") as f:
    vocab_data = json.load(f)
VOCAB_SIZE = vocab_data["vocab_size"]

# -------------------- IMPORT YOUR MODEL CLASS --------------------
# Make sure i3Model is in the same folder or installed as a package
from app_classes import i3Model, ChunkTokenizer

tokenizer = ChunkTokenizer()
tokenizer.load(VOCAB_JSON)

model = i3Model(
    vocab_size=VOCAB_SIZE,
    d_model=512,
    n_heads=16,
    max_seq_len=256,
    d_state=32,
).to(DEVICE)

# -------------------- LOAD WEIGHTS --------------------
try:
    if LOCAL_SAFETENSORS.exists():
        from safetensors.torch import load_file
        state_dict = load_file(LOCAL_SAFETENSORS)
        model.load_state_dict(state_dict)
        print("✅ Loaded weights from local safetensors")
    elif LOCAL_BIN.exists():
        # Checkpoint is trusted; weights_only=False allows pickled objects
        state_dict = torch.load(LOCAL_BIN, map_location=DEVICE, weights_only=False)
        model.load_state_dict(state_dict)
        print("✅ Loaded weights from local .bin")
    else:
        # HuggingFace fallback
        print("⚡ Downloading model from HuggingFace...")
        bin_file = hf_hub_download(repo_id=MODEL_NAME, filename="pytorch_model.bin")
        state_dict = torch.load(bin_file, map_location=DEVICE, weights_only=False)
        model.load_state_dict(state_dict)
        print("✅ Loaded weights from HuggingFace")
except Exception as e:
    raise RuntimeError(f"Failed to load model weights: {e}")

model.eval()

# -------------------- GENERATION FUNCTION --------------------
def generate_text(prompt, max_tokens=100, temperature=0.8, top_k=40):
    # Gradio sliders may pass floats; generation expects integer counts
    max_tokens, top_k = int(max_tokens), int(top_k)
    idx = torch.tensor([tokenizer.encode(prompt)], dtype=torch.long).to(DEVICE)
    # Inference only: disable gradient tracking to save memory
    with torch.no_grad():
        out_idx = model.generate(
            idx, max_new_tokens=max_tokens, temperature=temperature, top_k=top_k
        )
    return tokenizer.decode(out_idx[0].cpu())

# -------------------- GRADIO UI --------------------
with gr.Blocks() as demo:
    gr.Markdown("### i3-80M Text Generation")

    with gr.Row():
        prompt_input = gr.Textbox(label="Prompt", placeholder="Type something...")
        max_tokens_input = gr.Slider(10, 500, value=100, step=10, label="Max Tokens")
        temp_input = gr.Slider(0.1, 2.0, value=0.8, step=0.05, label="Temperature")
        topk_input = gr.Slider(1, 100, value=40, step=1, label="Top-k Sampling")

    output_text = gr.Textbox(label="Generated Text")
    generate_btn = gr.Button("Generate")

    # Connect UI
    generate_btn.click(
        generate_text,
        inputs=[prompt_input, max_tokens_input, temp_input, topk_input],
        outputs=[output_text],
    )

    # Developer Panel (shows model info)
    with gr.Accordion("Dev Panel: Model Info", open=False):
        gr.Markdown(f"**Device:** {DEVICE}")
        gr.Markdown(f"**Vocab size:** {VOCAB_SIZE}")
        total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        gr.Markdown(f"**Total Parameters:** {total_params:,} ({total_params/1e6:.2f}M)")

# -------------------- RUN --------------------
if __name__ == "__main__":
    demo.launch()
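# -------------------- EXAMPLE: DIRECT USE --------------------
# A minimal sketch of calling the generator without the Gradio UI, e.g.
# from a Python REPL once this module has loaded the weights. Importing
# is safe because demo.launch() is guarded by __main__. The module name
# "app" below is an assumption; rename to match this file:
#
#   >>> from app import generate_text
#   >>> print(generate_text("Once upon a time", max_tokens=50, top_k=40))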