import json
from pathlib import Path

import gradio as gr
import torch
from huggingface_hub import hf_hub_download

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

MODEL_NAME = "FlameF0X/i3-80m"
LOCAL_SAFETENSORS = Path("model.safetensors")
LOCAL_BIN = Path("pytorch_model.bin")
VOCAB_JSON = Path("chunk_vocab_combined.json")

# Read the chunk vocabulary so the embedding table can be sized to match.
with open(VOCAB_JSON, "r") as f:
    vocab_data = json.load(f)
VOCAB_SIZE = vocab_data["vocab_size"]
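# Note: only the "vocab_size" field is consumed directly here; the
# ChunkTokenizer below reads its chunk-to-id mapping from the same file.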

# i3Model and ChunkTokenizer are defined in the local app_classes module.
from app_classes import i3Model, ChunkTokenizer

tokenizer = ChunkTokenizer()
tokenizer.load(VOCAB_JSON)
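# The tokenizer and the model are built from the same vocab file, so the
# token ids the tokenizer emits line up with the model's embedding table.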

model = i3Model(
    vocab_size=VOCAB_SIZE,
    d_model=512,
    n_heads=16,
    max_seq_len=256,
    d_state=32,
).to(DEVICE)
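# Sanity check on the hyperparameters: d_model=512 split across n_heads=16
# gives 32-dim heads, and d_state=32 suggests a state-space-style mixing
# layer; the actual architecture lives in app_classes.i3Model.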

# Weight resolution order: local safetensors, then local .bin, then a
# download from the Hugging Face Hub.
try:
    if LOCAL_SAFETENSORS.exists():
        from safetensors.torch import load_file

        state_dict = load_file(LOCAL_SAFETENSORS)
        model.load_state_dict(state_dict)
        print("✅ Loaded weights from local safetensors")
    elif LOCAL_BIN.exists():
        state_dict = torch.load(LOCAL_BIN, map_location=DEVICE, weights_only=False)
        model.load_state_dict(state_dict)
        print("✅ Loaded weights from local .bin")
    else:
        print("⚡ Downloading model from HuggingFace...")
        bin_file = hf_hub_download(repo_id=MODEL_NAME, filename="pytorch_model.bin")
        state_dict = torch.load(bin_file, map_location=DEVICE, weights_only=False)
        model.load_state_dict(state_dict)
        print("✅ Loaded weights from HuggingFace")
except Exception as e:
    raise RuntimeError(f"Failed to load model weights: {e}") from e
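# Security note: weights_only=False lets torch.load unpickle arbitrary
# objects, so only load checkpoints from trusted sources (here, local files
# or the model's own Hub repo).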

model.eval()

@torch.inference_mode()  # no gradients are needed at generation time
def generate_text(prompt, max_tokens=100, temperature=0.8, top_k=40):
    """Encode the prompt, autoregressively sample, and decode back to text."""
    idx = torch.tensor([tokenizer.encode(prompt)], dtype=torch.long).to(DEVICE)
    out_idx = model.generate(idx, max_new_tokens=max_tokens, temperature=temperature, top_k=top_k)
    return tokenizer.decode(out_idx[0].cpu())
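# Quick smoke test (uncomment to verify generation before launching the UI;
# the prompt string is purely illustrative):
# print(generate_text("Once upon a time", max_tokens=20))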

with gr.Blocks() as demo:
    gr.Markdown("### i3-80M Text Generation")

    with gr.Row():
        prompt_input = gr.Textbox(label="Prompt", placeholder="Type something...")
        max_tokens_input = gr.Slider(10, 500, value=100, step=10, label="Max Tokens")
        temp_input = gr.Slider(0.1, 2.0, value=0.8, step=0.05, label="Temperature")
        topk_input = gr.Slider(1, 100, value=40, step=1, label="Top-k Sampling")

    output_text = gr.Textbox(label="Generated Text")
    generate_btn = gr.Button("Generate")

    generate_btn.click(
        generate_text,
        inputs=[prompt_input, max_tokens_input, temp_input, topk_input],
        outputs=[output_text],
    )

    with gr.Accordion("Dev Panel: Model Info", open=False):
        gr.Markdown(f"**Device:** {DEVICE}")
        gr.Markdown(f"**Vocab size:** {VOCAB_SIZE}")
        total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        gr.Markdown(f"**Total Parameters:** {total_params:,} ({total_params / 1e6:.2f}M)")

if __name__ == "__main__":
    demo.launch()
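# Assumption: default launch settings suffice for this demo; when running
# locally, Gradio can also expose a temporary public URL via
# demo.launch(share=True).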