# Source: Hugging Face Space file view (author: ereniko)
# Commit: "Update app.py"
# Revision: 345d16b (verified)
import torch
import gradio as gr
from tokenizers import Tokenizer
from huggingface_hub import hf_hub_download
import os, sys
# ── Download model artifacts from HF Hub ──────────────────────────────────────
REPO = "IvmeLabs/Ivme-Conversate-22M-Base"
# Fetch the tokenizer, the checkpoint and the model definition, in that order.
tokenizer_path, model_path, model_py_path = (
    hf_hub_download(repo_id=REPO, filename=artifact)
    for artifact in ("ivme_tokenizer.json", "ivme_base_ema.pt", "model.py")
)
# model.py was downloaded next to the other artifacts; expose its directory to
# the import system so the class can be imported like a regular module.
_model_dir = os.path.dirname(model_py_path)
sys.path.insert(0, _model_dir)
from model import IvmeConversate  # noqa: E402 (dynamic import)
# ── Load tokenizer & model ─────────────────────────────────────────────────────
# Rebuild the tokenizer from the JSON file fetched from the Hub.
tokenizer = Tokenizer.from_file(tokenizer_path)
# Use the GPU when one is visible; the checkpoint is mapped onto the same device.
device = "cuda" if torch.cuda.is_available() else "cpu"
# SECURITY NOTE(review): weights_only=False makes torch.load unpickle arbitrary
# Python objects, which can execute code embedded in the checkpoint. Presumably
# it is needed here because the checkpoint stores a config object under "cfg"
# alongside the tensors — only load checkpoints from a trusted repository.
ckpt = torch.load(model_path, map_location=device, weights_only=False)
# The checkpoint is a dict holding the model config ("cfg") and weights ("model").
cfg = ckpt["cfg"]
# Set the attention backend on the config before the model is constructed.
cfg.attn_backend = "sdpa"
model = IvmeConversate(cfg).to(device)
model.load_state_dict(ckpt["model"])
# Inference only: switch the module to evaluation mode.
model.eval()
# ── Lottie throbber injection ──────────────────────────────────────────────────
# Read the JSON file at startup and embed it inline — no /file= serving needed.
# The animation file is expected to sit next to this script; a missing file
# raises at import time, before the UI is built.
_LOTTIE_JSON_PATH = os.path.join(os.path.dirname(__file__), "ivmeloading.json")
with open(_LOTTIE_JSON_PATH, "r", encoding="utf-8") as _f:
    _LOTTIE_JSON_STR = _f.read()
# HTML fragment containing the lottie-web loader, a <style> block, the throbber
# container and a script that toggles it.
#
# This is an f-string: doubled braces ({{ and }}) are literal CSS/JS braces and
# the only interpolation is {_LOTTIE_JSON_STR}, whose text is pasted verbatim
# as the JavaScript value of `animationData`.
# NOTE(review): the file content is neither validated as JSON nor escaped, so a
# "</script>" sequence inside it would terminate the script early — confirm the
# asset is trusted.
#
# Script behaviour: a MutationObserver watches <gradio-app> (or <body> as a
# fallback) for child-list and class-attribute changes and shows the throbber
# while any element matching `.generating` exists, hiding and stopping the
# animation otherwise.
# NOTE(review): whether <script> tags embedded in an HTML component are executed
# depends on how the UI framework inserts the markup — verify that the throbber
# actually runs in the deployed Gradio version.
LOTTIE_HTML = f"""
<!-- lottie-web (bodymovin) from cdnjs — stable, no web-component registration race -->
<script src="https://cdnjs.cloudflare.com/ajax/libs/lottie-web/5.12.2/lottie.min.js"></script>
<style>
/* Hide the default generating indicator */
.generating > span,
.message.bot.generating .dot-flashing,
.message.bot.generating span[class*='dots'] {{
visibility: hidden !important;
}}
#ivme-throbber {{
display: none;
position: fixed;
bottom: 88px;
left: 50%;
transform: translateX(-50%);
width: 72px;
height: 72px;
z-index: 9999;
pointer-events: none;
}}
</style>
<div id="ivme-throbber"></div>
<script>
(function () {{
// Inline animation data — no network request needed
var animationData = {_LOTTIE_JSON_STR};
var anim = null;
var container = document.getElementById('ivme-throbber');
function initLottie() {{
if (anim || !container) return;
anim = lottie.loadAnimation({{
container: container,
renderer: 'svg',
loop: true,
autoplay: false,
animationData: animationData,
}});
}}
function setVisible(show) {{
if (!container) return;
if (show) {{
container.style.display = 'block';
if (!anim) initLottie();
else anim.play();
}} else {{
container.style.display = 'none';
if (anim) anim.stop();
}}
}}
var obs = new MutationObserver(function () {{
setVisible(!!document.querySelector('.generating'));
}});
function startObserver() {{
initLottie();
var root = document.querySelector('gradio-app') || document.body;
obs.observe(root, {{
subtree: true,
childList: true,
attributes: true,
attributeFilter: ['class'],
}});
}}
if (document.readyState === 'loading') {{
document.addEventListener('DOMContentLoaded', startObserver);
}} else {{
startObserver();
}}
}})();
</script>
"""
# ── Inference ──────────────────────────────────────────────────────────────────
def build_prompt(history: list[dict], system: str) -> str:
    """Serialize a chat transcript into the model's special-token prompt.

    Each turn is rendered as ``<|role|>content<|eos|>``. A non-empty *system*
    string is emitted first as a ``<|system|>`` turn, and the prompt always
    ends with an open ``<|assistant|>`` tag for the model to continue from.

    Args:
        history: Messages as dicts with ``"role"`` and ``"content"`` keys.
        system: Optional system text; skipped when empty.

    Returns:
        The concatenated prompt string.
    """
    system_part = f"<|system|>{system}<|eos|>" if system else ""
    turns = "".join(
        f"<|{turn['role']}|>{turn['content']}<|eos|>" for turn in history
    )
    return f"{system_part}{turns}<|assistant|>"
def respond(message: str, history: list[dict], system_prompt: str,
            max_new_tokens: int, temperature: float, top_k: int,
            repetition_penalty: float):
    """Stream a sampled reply to *message*, one decoded prefix per token.

    The prompt is built from *history* plus the new user message, then tokens
    are sampled autoregressively until ``<|eos|>`` is drawn or
    *max_new_tokens* have been produced.

    Args:
        message: The new user message.
        history: Earlier turns as dicts with ``"role"`` and ``"content"``.
        system_prompt: Optional system text placed before the turns.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Softmax temperature; clamped to at least ``1e-6``.
        top_k: Keep only the *top_k* highest logits; ``0`` disables it.
            Values above the vocabulary size are clamped.
        repetition_penalty: Penalty applied to every token id already in the
            sequence (prompt included); ``1.0`` disables it.

    Yields:
        The decoded text of all tokens generated so far, after each new token.
        Nothing is yielded if the first sampled token is ``<|eos|>``.
    """
    history = history + [{"role": "user", "content": message}]
    prompt = build_prompt(history, system_prompt)
    # NOTE(review): the sequence is never truncated, so a long conversation may
    # exceed the model's context window — confirm how the model handles that.
    generated = torch.tensor([tokenizer.encode(prompt).ids], device=device)

    # Loop invariants: look them up once instead of on every step.
    eos_id = tokenizer.token_to_id("<|eos|>")
    temperature = max(temperature, 1e-6)
    response_tokens: list[int] = []

    for _ in range(max_new_tokens):
        # Keep no_grad scoped to the computation only: yielding from inside the
        # context would leave grad mode disabled in the caller while this
        # generator is suspended.
        with torch.no_grad():
            logits = model(generated)[:, -1, :]  # (1, vocab)

            # Repetition penalty. Dividing a negative logit would make the
            # token *more* likely, so negative logits are multiplied instead.
            if repetition_penalty != 1.0:
                seen = torch.unique(generated[0])
                seen_logits = logits[0, seen]
                logits[0, seen] = torch.where(
                    seen_logits < 0,
                    seen_logits * repetition_penalty,
                    seen_logits / repetition_penalty,
                )

            # Temperature + top-k sampling
            logits = logits / temperature
            if top_k > 0:
                # torch.topk raises if k exceeds the vocabulary size.
                k = min(top_k, logits.size(-1))
                topk_vals, _ = torch.topk(logits, k)
                logits[logits < topk_vals[:, -1:]] = float("-inf")

            probs = torch.softmax(logits, dim=-1)
            next_tok = torch.multinomial(probs, num_samples=1)

        token_id = next_tok.item()
        if token_id == eos_id:
            break
        response_tokens.append(token_id)
        generated = torch.cat([generated, next_tok], dim=1)
        # Yield partial decode on every token
        yield tokenizer.decode(response_tokens)
# ── UI ─────────────────────────────────────────────────────────────────────────
# Custom stylesheet for the page; it is handed to gr.Blocks(css=...) below.
# NOTE(review): "#component-0" looks like an auto-assigned element id rather
# than one set in this file — confirm it still matches the intended container
# after layout changes or framework upgrades.
CSS = """
/* Clean, readable chat UI */
body, .gradio-container { font-family: 'Inter', system-ui, sans-serif; }
#component-0 { max-width: 780px; margin: 0 auto; padding: 16px; }
.chatbot { border-radius: 12px; }
footer { display: none !important; }
"""
with gr.Blocks(css=CSS, title="İvme-Conversate-22M") as demo:
    # Lottie throbber (invisible until generation starts)
    gr.HTML(LOTTIE_HTML)
    gr.Markdown(
        "## İvme-Conversate-22M-Base\n"
        "22M-parameter decoder-only model · base (not instruction-tuned) · "
        "1024-token context · [model card ↗](https://huggingface.co/IvmeLabs/Ivme-Conversate-22M-Base)"
    )
    chatbot = gr.Chatbot(
        type="messages",
        height=480,
        show_label=False,
        avatar_images=(None, "https://cdn-uploads.huggingface.co/production/uploads/670562d6ac129959c16f84d4/Gi8oMz-Q8n2CImbtVyHOy.png"),
    )
    with gr.Row():
        msg_box = gr.Textbox(
            placeholder="Continue the prompt…",
            show_label=False,
            scale=8,
            container=False,
        )
        send_btn = gr.Button("Send", scale=1, variant="primary")
    with gr.Accordion("Settings", open=False):
        system_prompt = gr.Textbox(
            label="System prompt",
            value="",
            placeholder="Optional system context (note: base model may ignore it)",
        )
        with gr.Row():
            max_tokens = gr.Slider(16, 512, value=200, step=8, label="Max new tokens")
            temperature = gr.Slider(0.1, 2.0, value=0.8, step=0.05, label="Temperature")
        with gr.Row():
            top_k = gr.Slider(0, 200, value=40, step=1, label="Top-k (0 = disabled)")
            rep_penalty = gr.Slider(1.0, 2.0, value=1.1, step=0.05, label="Repetition penalty")

    # Wire up submit
    def user_turn(message, history):
        """Clear the textbox and append the user's message to the history.

        Blank or whitespace-only submissions are ignored so that no empty
        user turn is added to the conversation.
        """
        if not message or not message.strip():
            return "", history
        return "", history + [{"role": "user", "content": message}]

    def bot_turn(history, system_prompt, max_tokens, temperature, top_k, rep_penalty):
        """Stream the assistant's reply into a new trailing history entry.

        Yields the full history after every generated token. If there is no
        pending user turn (empty history, or the last entry is not from the
        user), the history is yielded unchanged and nothing is generated.
        """
        if not history or history[-1]["role"] != "user":
            yield history
            return
        # Last entry is the user message
        user_msg = history[-1]["content"]
        prior = history[:-1]
        history = history + [{"role": "assistant", "content": ""}]
        for partial in respond(user_msg, prior, system_prompt,
                               int(max_tokens), temperature, int(top_k), rep_penalty):
            history[-1]["content"] = partial
            yield history

    # Pressing Enter in the textbox and clicking "Send" run the same pipeline:
    # record the user turn immediately (unqueued), then stream the reply.
    generation_inputs = [chatbot, system_prompt, max_tokens, temperature, top_k, rep_penalty]
    for trigger in (msg_box.submit, send_btn.click):
        trigger(
            user_turn, [msg_box, chatbot], [msg_box, chatbot], queue=False
        ).then(
            bot_turn,
            generation_inputs,
            chatbot,
        )

demo.launch()