import os
import json
import time
import tarfile
import stat
import threading
import subprocess
from pathlib import Path
from typing import List, Dict, Optional

import requests
import gradio as gr

# ----------------------------
# Force UTF-8 everywhere
# ----------------------------
os.environ.setdefault("PYTHONIOENCODING", "utf-8")
os.environ.setdefault("LANG", "C.UTF-8")
os.environ.setdefault("LC_ALL", "C.UTF-8")

# ----------------------------
# Your model on HF
# ----------------------------
HF_REPO = os.environ.get("HF_REPO", "staeiou/bartleby-qwen3-0.6b")
HF_FILE = os.environ.get("HF_FILE", "bartleby-qwen3-0.6b.Q4_K_M.gguf")  # <-- your filename

# ----------------------------
# llama.cpp server settings
# ----------------------------
HOST = os.environ.get("LLAMA_HOST", "127.0.0.1")
PORT = int(os.environ.get("LLAMA_PORT", "8080"))
BASE_URL = f"http://{HOST}:{PORT}"

CTX_SIZE = int(os.environ.get("LLAMA_CTX", "4096"))

# 2 concurrent chats on ~2 vCPU:
N_THREADS = int(os.environ.get("LLAMA_THREADS", "1"))
N_THREADS_BATCH = int(os.environ.get("LLAMA_THREADS_BATCH", str(N_THREADS)))
PARALLEL = int(os.environ.get("LLAMA_PARALLEL", "2"))
THREADS_HTTP = int(os.environ.get("LLAMA_THREADS_HTTP", "2"))
BATCH_SIZE = int(os.environ.get("LLAMA_BATCH", "256"))
UBATCH_SIZE = int(os.environ.get("LLAMA_UBATCH", "128"))

SYSTEM_PROMPT_DEFAULT = os.environ.get("SYSTEM_PROMPT", "")

# Prefer /data if present (persistent), else /tmp
DATA_DIR = Path("/data") if Path("/data").exists() else Path("/tmp")
HF_HOME = Path(os.environ.get("HF_HOME", str(DATA_DIR / "hf_home")))
os.environ["HF_HOME"] = str(HF_HOME)

LLAMA_DIR = Path(os.environ.get("LLAMA_BIN_DIR", str(DATA_DIR / "llama_cpp_bin")))
LLAMA_DIR.mkdir(parents=True, exist_ok=True)

# ----------------------------
# CSS
# Key points:
# - prevent body/page growth (no shrinking scrollbar thumb)
# - fixed full-viewport flex layout
# - DO NOT override Chatbot's internal scrolling; let Gradio autoscroll work
# ----------------------------
CUSTOM_CSS = r"""
footer { visibility: hidden; }

html, body {
  height: 100%;
  margin: 0;
  overflow: hidden !important;  /* prevent page scroll leak */
}

.gradio-container {
  height: 100dvh !important;
  max-height: 100dvh !important;
  overflow: hidden !important;
}

/* Fullscreen layout */
#app_root {
  position: fixed;
  inset: 0;
  display: flex;
  flex-direction: column;
  overflow: hidden !important;
}

/* Transcript gets all remaining space */
#chat_wrap {
  flex: 1 1 auto;
  min-height: 0;                /* critical for flex scroll layouts */
  overflow: hidden !important;  /* prevent outer scroll; Chatbot will scroll internally */
}

/* Ensure the Chatbot component itself stretches */
#chat_wrap .gradio-chatbot,
#chat_wrap .gr-chatbot,
#chat_wrap [data-testid="chatbot"] {
  height: 100% !important;
  max-height: none !important;
}

/* Bottom input row */
#input_row {
  flex: 0 0 auto;
  padding: 6px 0 6px 0;
}

/* One-line textbox */
#msg_box textarea {
  min-height: 2.6em !important;
  max-height: 2.6em !important;
  height: 2.6em !important;
  line-height: 1.25 !important;
  overflow: hidden !important;
  resize: none !important;
}

/* Send button height matches */
#send_btn button {
  min-height: 2.6em !important;
  height: 2.6em !important;
  padding-top: 0.2em !important;
  padding-bottom: 0.2em !important;
}

/* Params below input; don't let it take over the screen when opened */
#params_bar { flex: 0 0 auto; }
#params_bar .gr-accordion-content,
#params_bar .accordion-content {
  max-height: 45dvh;
  overflow: auto;
}

@media (max-width: 768px) {
  .gradio-container { padding: 8px !important; }
}
@media (min-width: 769px) {
  .gradio-container { padding: 12px !important; }
}
"""

# ----------------------------
# (Optional) Mobile focus guard JS
# Keeps keyboard from popping back up mid-stream unless user tapped input recently.
# If you don't need it anymore, set FOCUS_GUARD_JS = "".
# ----------------------------
FOCUS_GUARD_JS = r"""
() => {
  const isMobile = /Mobi|Android|iPhone|iPad|iPod/i.test(navigator.userAgent);
  if (!isMobile) return;

  const inputSel = "#msg_box textarea";
  const chatSel = "#chat_wrap";
  let lastTouch = 0;

  const arm = () => {
    const input = document.querySelector(inputSel);
    const chat = document.querySelector(chatSel);
    if (!input || !chat) return;

    input.addEventListener("touchstart", () => { lastTouch = Date.now(); }, { passive: true });

    const blurIfUnintended = () => {
      const recent = (Date.now() - lastTouch) < 600;
      if (!recent && document.activeElement === input) input.blur();
    };

    const mo = new MutationObserver(() => blurIfUnintended());
    mo.observe(chat, { childList: true, subtree: true, characterData: true });

    document.addEventListener("focusin", (e) => {
      if (e.target === input) blurIfUnintended();
    }, true);
  };

  arm();
  setTimeout(arm, 500);
  setTimeout(arm, 1500);
}
"""

# ----------------------------
# Server lifecycle globals
# ----------------------------
_server_lock = threading.Lock()
_server_proc: subprocess.Popen | None = None
LLAMA_SERVER: Path | None = None
SERVER_MODEL_ID: str | None = None


def _make_executable(path: Path) -> None:
    st = os.stat(path)
    os.chmod(path, st.st_mode | stat.S_IEXEC)


def _safe_extract_tar(tf: tarfile.TarFile, out_dir: Path) -> None:
    try:
        tf.extractall(path=out_dir, filter="data")  # py3.12+
    except TypeError:
        tf.extractall(path=out_dir)


def _download_llama_cpp_release() -> Path:
    existing = list(LLAMA_DIR.rglob("llama-server"))
    for p in existing:
        if p.is_file():
            _make_executable(p)
            return p

    asset_url = None
    try:
        rel = requests.get(
            "https://api.github.com/repos/ggml-org/llama.cpp/releases/latest",
            timeout=20,
        ).json()
        for a in rel.get("assets", []):
            name = a.get("name", "")
            if "bin-ubuntu-x64" in name and name.endswith(".tar.gz"):
                asset_url = a.get("browser_download_url")
                break
    except Exception:
        asset_url = None

    if not asset_url:
        asset_url = "https://github.com/ggml-org/llama.cpp/releases/latest/download/llama-bin-ubuntu-x64.tar.gz"

    tar_path = LLAMA_DIR / "llama-bin-ubuntu-x64.tar.gz"
    print(f"[app] Downloading llama.cpp release: {asset_url}", flush=True)
    with requests.get(asset_url, stream=True, timeout=180) as r:
        r.raise_for_status()
        with open(tar_path, "wb") as f:
            for chunk in r.iter_content(chunk_size=1024 * 1024):
                if chunk:
                    f.write(chunk)

    print("[app] Extracting llama.cpp tarball...", flush=True)
    with tarfile.open(tar_path, "r:gz") as tf:
        _safe_extract_tar(tf, LLAMA_DIR)

    candidates = list(LLAMA_DIR.rglob("llama-server"))
    if not candidates:
        raise RuntimeError("Downloaded llama.cpp release but could not find llama-server binary.")

    server_bin = candidates[0]
    _make_executable(server_bin)
    print(f"[app] llama-server path: {server_bin}", flush=True)
    return server_bin


def _wait_for_health(timeout_s: int = 180) -> None:
    deadline = time.time() + timeout_s
    last_err = None
    while time.time() < deadline:
        try:
            r = requests.get(f"{BASE_URL}/health", timeout=2)
            if r.status_code == 200:
                return
            last_err = f"health status {r.status_code}"
        except Exception as e:
            last_err = str(e)
        time.sleep(0.5)
    raise RuntimeError(f"llama-server not healthy in time. Last error: {last_err}")
Last error: {last_err}") def ensure_server_started() -> None: global _server_proc, LLAMA_SERVER, SERVER_MODEL_ID with _server_lock: if _server_proc and _server_proc.poll() is None: return LLAMA_SERVER = _download_llama_cpp_release() HF_HOME.mkdir(parents=True, exist_ok=True) cmd = [ str(LLAMA_SERVER), "--host", HOST, "--port", str(PORT), "--no-webui", "--jinja", "--ctx-size", str(CTX_SIZE), "--threads", str(N_THREADS), "--threads-batch", str(N_THREADS_BATCH), "--threads-http", str(THREADS_HTTP), "--parallel", str(PARALLEL), "--cont-batching", "--batch-size", str(BATCH_SIZE), "--ubatch-size", str(UBATCH_SIZE), "-hf", HF_REPO, "--hf-file", HF_FILE, ] print("[app] Starting llama-server with:", flush=True) print(" " + " ".join(cmd), flush=True) env = os.environ.copy() env["PYTHONIOENCODING"] = "utf-8" env["LANG"] = env.get("LANG", "C.UTF-8") env["LC_ALL"] = env.get("LC_ALL", "C.UTF-8") # Inherit logs to container; avoids PIPE deadlock _server_proc = subprocess.Popen(cmd, stdout=None, stderr=None, env=env) _wait_for_health(timeout_s=180) try: j = requests.get(f"{BASE_URL}/v1/models", timeout=5).json() SERVER_MODEL_ID = j["data"][0]["id"] except Exception: SERVER_MODEL_ID = HF_REPO print(f"[app] llama-server healthy. model_id={SERVER_MODEL_ID}", flush=True) def stream_chat(messages, temperature: float, top_p: float, max_tokens: int): payload = { "model": SERVER_MODEL_ID or HF_REPO, "messages": messages, "temperature": float(temperature), "top_p": float(top_p), "max_tokens": int(max_tokens), "stream": True, } headers = { "Accept": "text/event-stream", "Content-Type": "application/json; charset=utf-8", } last_err: Optional[Exception] = None for _attempt in range(12): try: with requests.post( f"{BASE_URL}/v1/chat/completions", json=payload, stream=True, timeout=600, headers=headers, ) as r: if r.status_code != 200: body = r.text[:2000] raise requests.exceptions.HTTPError( f"{r.status_code} from llama-server: {body}", response=r, ) for raw in r.iter_lines(decode_unicode=False): if not raw: continue line = raw.decode("utf-8", errors="replace") if not line.startswith("data: "): continue data = line[len("data: "):].strip() if data == "[DONE]": return try: obj = json.loads(data) except Exception: continue delta = obj["choices"][0].get("delta") or {} tok = delta.get("content") if tok: yield tok return except (requests.exceptions.ConnectionError, requests.exceptions.Timeout) as e: last_err = e time.sleep(0.5) try: ensure_server_started() except Exception: pass if last_err: raise last_err # ---------------------------- # Chat handlers (messages format: list of {"role","content"}) # ---------------------------- ChatHistory = List[Dict[str, str]] def _truncate(s: str, n: int) -> str: s = s if isinstance(s, str) else str(s) return s if len(s) <= n else s[:n] def on_user_submit(user_text: str, history: ChatHistory): user_text = (user_text or "").strip() if not user_text: return "", history user_text = _truncate(user_text, 2000) history = history or [] history = history + [{"role": "user", "content": user_text}, {"role": "assistant", "content": ""}] return "", history def on_bot_respond(history: ChatHistory, system_message: str, max_tokens: int, temperature: float, top_p: float): ensure_server_started() history = history or [] if len(history) < 2 or history[-1].get("role") != "assistant": yield history return user_msg = history[-2].get("content", "") msgs = [] sys = (system_message or "").strip() if sys: msgs.append({"role": "system", "content": sys}) # Your rule: ONLY latest user prompt msgs.append({"role": 
"user", "content": user_msg}) out = "" for tok in stream_chat( msgs, temperature=float(temperature), top_p=float(top_p), max_tokens=int(max_tokens), ): out += tok history[-1]["content"] = out yield history # ---------------------------- # UI # ---------------------------- with gr.Blocks(title="BartlebyGPT", fill_height=True) as demo: with gr.Column(elem_id="app_root"): with gr.Column(elem_id="chat_wrap"): # Critical: set height to "100%" so it fills chat_wrap, and autoscroll=True chatbot = gr.Chatbot( value=[], show_label=False, autoscroll=True, height="100%", elem_id="chatbot", ) with gr.Row(elem_id="input_row"): msg = gr.Textbox( placeholder="What do you want?", show_label=False, lines=1, max_lines=1, autofocus=False, elem_id="msg_box", scale=10, ) send = gr.Button("Send", variant="primary", elem_id="send_btn", scale=1) with gr.Accordion("Params", open=False, elem_id="params_bar"): system_box = gr.Textbox(value=SYSTEM_PROMPT_DEFAULT, label="System message", lines=2) with gr.Row(): max_tokens = gr.Slider(1, 2048, value=1024, step=1, label="Max new tokens") temperature = gr.Slider(0.0, 2.0, value=0.7, step=0.05, label="Temperature") top_p = gr.Slider(0.1, 1.0, value=0.7, step=0.05, label="Top-p") # Wire up msg.submit(on_user_submit, [msg, chatbot], [msg, chatbot], queue=False).then( on_bot_respond, [chatbot, system_box, max_tokens, temperature, top_p], [chatbot], ) send.click(on_user_submit, [msg, chatbot], [msg, chatbot], queue=False).then( on_bot_respond, [chatbot, system_box, max_tokens, temperature, top_p], [chatbot], ) demo.queue(default_concurrency_limit=2, max_size=256) # Warm start server try: ensure_server_started() except Exception as e: print("[app] llama-server eager start failed:", repr(e), flush=True) if __name__ == "__main__": demo.launch( server_name="0.0.0.0", server_port=int(os.environ.get("PORT", "7860")), css=CUSTOM_CSS, js=FOCUS_GUARD_JS, )