"""Gradio chat demo for the İvme-Conversate-22M base language model.

At import time this script downloads the tokenizer, checkpoint and model
definition from the Hugging Face Hub, loads the model, builds the Gradio UI
and launches it.
"""

import os
import sys
from collections.abc import Iterator

import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from tokenizers import Tokenizer

# ── Download model artifacts from HF Hub ──────────────────────────────────────
REPO = "IvmeLabs/Ivme-Conversate-22M-Base"
tokenizer_path = hf_hub_download(repo_id=REPO, filename="ivme_tokenizer.json")
model_path = hf_hub_download(repo_id=REPO, filename="ivme_base_ema.pt")
model_py_path = hf_hub_download(repo_id=REPO, filename="model.py")

# Put model.py on the path so we can import it
sys.path.insert(0, os.path.dirname(model_py_path))
from model import IvmeConversate  # noqa: E402 (dynamic import)

# ── Load tokenizer & model ────────────────────────────────────────────────────
tokenizer = Tokenizer.from_file(tokenizer_path)
device = "cuda" if torch.cuda.is_available() else "cpu"
# NOTE(security): weights_only=False unpickles arbitrary Python objects. It is
# needed here because the checkpoint stores a config object under "cfg", but it
# means the checkpoint (and the downloaded model.py) must come from a trusted
# repository.
ckpt = torch.load(model_path, map_location=device, weights_only=False)
cfg = ckpt["cfg"]
cfg.attn_backend = "sdpa"
model = IvmeConversate(cfg).to(device)
model.load_state_dict(ckpt["model"])
model.eval()

# Looked up once instead of on every generated token. May be None if the
# tokenizer has no such token, in which case generation only stops at the
# token budget.
EOS_ID = tokenizer.token_to_id("<|eos|>")

# Context window advertised in the UI description below. The model input is
# cropped to this many trailing tokens.
# NOTE(review): assumes the model cannot usefully attend beyond this length —
# confirm against the checkpoint config.
MAX_CONTEXT_TOKENS = 1024

# ── Lottie throbber injection ─────────────────────────────────────────────────
# Read the JSON file at startup and embed it inline — no /file= serving needed.
_LOTTIE_JSON_PATH = os.path.join(os.path.dirname(__file__), "ivmeloading.json")
with open(_LOTTIE_JSON_PATH, "r", encoding="utf-8") as _f:
    _LOTTIE_JSON_STR = _f.read()

# NOTE(review): the template body is empty in this view of the file;
# presumably the throbber markup embedding _LOTTIE_JSON_STR belongs here.
LOTTIE_HTML = f"""
"""


# ── Inference ─────────────────────────────────────────────────────────────────
def build_prompt(history: list[dict], system: str) -> str:
    """Format a chat history into the model's special-token prompt format.

    Args:
        history: Chat messages, each a mapping with "role" and "content" keys.
        system: Optional system text; omitted from the prompt when empty.

    Returns:
        The concatenated prompt, ending with an open ``<|assistant|>`` tag so
        the model continues with the assistant's reply.
    """
    parts = []
    if system:
        parts.append(f"<|system|>{system}<|eos|>")
    for msg in history:
        role = msg["role"]  # "user" | "assistant"
        parts.append(f"<|{role}|>{msg['content']}<|eos|>")
    parts.append("<|assistant|>")
    return "".join(parts)


@torch.no_grad()
def _sample_next_token(
    generated: torch.Tensor,
    temperature: float,
    top_k: int,
    repetition_penalty: float,
) -> torch.Tensor:
    """Run one forward pass and sample the next token.

    Gradients are disabled per call rather than around the whole generation
    loop: grad mode is thread-local, and a ``with torch.no_grad()`` block that
    spans ``yield`` statements is not guaranteed to still be in effect when
    the generator is resumed (for example on a different worker thread).

    Args:
        generated: Token ids so far, indexed as (1, seq_len).
        temperature: Sampling temperature; clamped to at least 1e-6.
        top_k: Keep only the k highest logits; 0 or less disables filtering.
        repetition_penalty: Penalty applied to tokens already present in
            ``generated``; 1.0 disables it.

    Returns:
        A (1, 1) tensor holding the sampled token id.
    """
    logits = model(generated[:, -MAX_CONTEXT_TOKENS:])[:, -1, :]  # (1, vocab)

    # Repetition penalty. Dividing a negative logit by a penalty > 1 would
    # move it towards zero and make the token MORE likely, so negative logits
    # are multiplied instead.
    if repetition_penalty != 1.0:
        seen = torch.unique(generated[0])
        seen_logits = logits[0, seen]
        logits[0, seen] = torch.where(
            seen_logits < 0,
            seen_logits * repetition_penalty,
            seen_logits / repetition_penalty,
        )

    # Temperature + top-k sampling
    logits = logits / max(temperature, 1e-6)
    if top_k > 0:
        # torch.topk raises if k exceeds the vocabulary size.
        k = min(top_k, logits.size(-1))
        kth_value = torch.topk(logits, k).values[:, -1:]
        logits = logits.masked_fill(logits < kth_value, float("-inf"))

    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)


def respond(
    message: str,
    history: list[dict],
    system_prompt: str,
    max_new_tokens: int,
    temperature: float,
    top_k: int,
    repetition_penalty: float,
) -> Iterator[str]:
    """Stream the model's reply to ``message`` token by token.

    Args:
        message: The new user message.
        history: Prior chat messages (not including ``message``).
        system_prompt: Optional system text placed at the start of the prompt.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature.
        top_k: Top-k cutoff; 0 disables it.
        repetition_penalty: Penalty for already-seen tokens; 1.0 disables it.

    Yields:
        The decoded reply so far, once per generated token. Nothing is
        yielded if the first sampled token is the end-of-sequence token.
    """
    history = history + [{"role": "user", "content": message}]
    prompt = build_prompt(history, system_prompt)
    generated = torch.tensor([tokenizer.encode(prompt).ids], device=device)

    response_tokens: list[int] = []
    for _ in range(max_new_tokens):
        next_tok = _sample_next_token(
            generated, temperature, top_k, repetition_penalty
        )
        tok_id = next_tok.item()
        if tok_id == EOS_ID:
            break
        response_tokens.append(tok_id)
        generated = torch.cat([generated, next_tok], dim=1)
        # Yield partial decode on every token
        yield tokenizer.decode(response_tokens)


# ── UI ────────────────────────────────────────────────────────────────────────
CSS = """
/* Clean, readable chat UI */
body, .gradio-container { font-family: 'Inter', system-ui, sans-serif; }
#component-0 { max-width: 780px; margin: 0 auto; padding: 16px; }
.chatbot { border-radius: 12px; }
footer { display: none !important; }
"""

with gr.Blocks(css=CSS, title="İvme-Conversate-22M") as demo:
    # Lottie throbber (invisible until generation starts)
    gr.HTML(LOTTIE_HTML)

    gr.Markdown(
        "## İvme-Conversate-22M-Base\n"
        "22M-parameter decoder-only model · base (not instruction-tuned) · "
        "1024-token context · [model card ↗](https://huggingface.co/IvmeLabs/Ivme-Conversate-22M-Base)"
    )

    chatbot = gr.Chatbot(
        type="messages",
        height=480,
        show_label=False,
        avatar_images=(None, "https://cdn-uploads.huggingface.co/production/uploads/670562d6ac129959c16f84d4/Gi8oMz-Q8n2CImbtVyHOy.png"),
    )

    with gr.Row():
        msg_box = gr.Textbox(
            placeholder="Continue the prompt…",
            show_label=False,
            scale=8,
            container=False,
        )
        send_btn = gr.Button("Send", scale=1, variant="primary")

    with gr.Accordion("Settings", open=False):
        system_prompt = gr.Textbox(
            label="System prompt",
            value="",
            placeholder="Optional system context (note: base model may ignore it)",
        )
        with gr.Row():
            max_tokens = gr.Slider(16, 512, value=200, step=8, label="Max new tokens")
            temperature = gr.Slider(0.1, 2.0, value=0.8, step=0.05, label="Temperature")
        with gr.Row():
            top_k = gr.Slider(0, 200, value=40, step=1, label="Top-k (0 = disabled)")
            rep_penalty = gr.Slider(1.0, 2.0, value=1.1, step=0.05, label="Repetition penalty")

    # Wire up submit
    def user_turn(message, history):
        """Append the user's message to the chat and clear the input box."""
        if not message or not message.strip():
            # Blank submission: leave the conversation untouched.
            return "", history
        return "", history + [{"role": "user", "content": message}]

    def bot_turn(history, system_prompt, max_tokens, temperature, top_k, rep_penalty):
        """Stream the assistant's reply to the last user message in ``history``."""
        if not history or history[-1]["role"] != "user":
            # Nothing to answer (e.g. a blank submission was ignored).
            yield history
            return

        # Last entry is the user message
        user_msg = history[-1]["content"]
        prior = history[:-1]
        history = history + [{"role": "assistant", "content": ""}]
        for partial in respond(
            user_msg,
            prior,
            system_prompt,
            int(max_tokens),
            temperature,
            int(top_k),
            rep_penalty,
        ):
            history[-1]["content"] = partial
            yield history

    generation_inputs = [
        chatbot, system_prompt, max_tokens, temperature, top_k, rep_penalty
    ]
    # Pressing Enter in the textbox and clicking "Send" run the same chain:
    # record the user turn immediately, then stream the reply.
    for trigger in (msg_box.submit, send_btn.click):
        trigger(
            user_turn, [msg_box, chatbot], [msg_box, chatbot], queue=False
        ).then(
            bot_turn, generation_inputs, chatbot,
        )

demo.launch()