import os
import json
import time
import tarfile
import stat
import threading
import subprocess
from pathlib import Path
from typing import List, Dict, Optional

import requests
import gradio as gr
# ----------------------------
# UTF-8 everywhere
# ----------------------------
os.environ.setdefault("PYTHONIOENCODING", "utf-8")
os.environ.setdefault("LANG", "C.UTF-8")
os.environ.setdefault("LC_ALL", "C.UTF-8")

# ----------------------------
# Model on HF (GGUF)
# ----------------------------
HF_REPO = os.environ.get("HF_REPO", "staeiou/bartleby-qwen3-0.6b")
HF_FILE = os.environ.get("HF_FILE", "bartleby-qwen3-0.6b.Q4_K_M.gguf")

# ----------------------------
# llama.cpp server settings
# ----------------------------
HOST = os.environ.get("LLAMA_HOST", "127.0.0.1")
PORT = int(os.environ.get("LLAMA_PORT", "8080"))
BASE_URL = f"http://{HOST}:{PORT}"

# Context window: the small 512 default keeps KV-cache cost low on 2 vCPU.
# 2048 often feels snappier on CPU if you can spare the memory; 4096 can be
# OK but raises KV cost further.
CTX_SIZE = int(os.environ.get("LLAMA_CTX", "512"))

# ---- CPU-tuned defaults for HF Spaces free tier (2 vCPU) ----
# Rationale for these defaults:
#   - parallel=1: best latency and tokens/sec per user on a tiny CPU
#   - threads=2: use both vCPUs
#   - http threads=1: avoid contention with generation threads
#   - smaller batches: large batches can hurt latency on CPU
N_THREADS = int(os.environ.get("LLAMA_THREADS", "2"))
N_THREADS_BATCH = int(os.environ.get("LLAMA_THREADS_BATCH", str(N_THREADS)))
PARALLEL = int(os.environ.get("LLAMA_PARALLEL", "1"))
THREADS_HTTP = int(os.environ.get("LLAMA_THREADS_HTTP", "1"))
BATCH_SIZE = int(os.environ.get("LLAMA_BATCH", "128"))
UBATCH_SIZE = int(os.environ.get("LLAMA_UBATCH", "64"))

# Optional knobs
# Pin model in RAM if allowed (may fail under strict memory limits; safe to try)
USE_MLOCK = os.environ.get("LLAMA_MLOCK", "1") == "1"
# Disable continuous batching for best single-user latency on 2 vCPU
USE_CONT_BATCHING = os.environ.get("LLAMA_CONT_BATCHING", "0") == "1"
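
# Every knob above is a plain environment variable, so a Space (or a local
# run) can be retuned without code changes. The values below are purely
# illustrative, not recommendations:
#   LLAMA_CTX=2048 LLAMA_THREADS=2 LLAMA_BATCH=64 LLAMA_MLOCK=0 python app.py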

SYSTEM_PROMPT_DEFAULT = os.environ.get("SYSTEM_PROMPT", "")

# Prefer /data if present (persistent), else /tmp
DATA_DIR = Path("/data") if Path("/data").exists() else Path("/tmp")
HF_HOME = Path(os.environ.get("HF_HOME", str(DATA_DIR / "hf_home")))
os.environ["HF_HOME"] = str(HF_HOME)
LLAMA_DIR = Path(os.environ.get("LLAMA_BIN_DIR", str(DATA_DIR / "llama_cpp_bin")))
LLAMA_DIR.mkdir(parents=True, exist_ok=True)

# ----------------------------
# CSS
# ----------------------------
CUSTOM_CSS = r"""
footer { visibility: hidden; }
html, body {
    height: 100%;
    margin: 0;
    overflow: hidden !important;
}
.gradio-container {
    height: 100dvh !important;
    max-height: 100dvh !important;
    overflow: hidden !important;
}
#app_root {
    position: fixed;
    inset: 0;
    display: flex;
    flex-direction: column;
    overflow: hidden !important;
}
#chat_wrap {
    flex: 1 1 auto;
    min-height: 0;
    overflow: hidden !important;
}
#chat_wrap .gradio-chatbot,
#chat_wrap .gr-chatbot,
#chat_wrap [data-testid="chatbot"] {
    height: 100% !important;
    max-height: none !important;
}
#input_row {
    flex: 0 0 auto;
    padding: 6px 0 6px 0;
}
#msg_box textarea {
    min-height: 2.6em !important;
    max-height: 2.6em !important;
    height: 2.6em !important;
    line-height: 1.25 !important;
    overflow: hidden !important;
    resize: none !important;
}
#send_btn button {
    min-height: 2.6em !important;
    height: 2.6em !important;
    padding-top: 0.2em !important;
    padding-bottom: 0.2em !important;
}
#params_bar {
    flex: 0 0 auto;
}
#params_bar .gr-accordion-content,
#params_bar .accordion-content {
    max-height: 45dvh;
    overflow: auto;
}
@media (max-width: 768px) {
    .gradio-container { padding: 8px !important; }
}
@media (min-width: 769px) {
    .gradio-container { padding: 12px !important; }
}
"""

FOCUS_GUARD_JS = r"""
() => {
    const isMobile = /Mobi|Android|iPhone|iPad|iPod/i.test(navigator.userAgent);
    if (!isMobile) return;
    const inputSel = "#msg_box textarea";
    const chatSel = "#chat_wrap";
    let lastTouch = 0;
    const arm = () => {
        const input = document.querySelector(inputSel);
        const chat = document.querySelector(chatSel);
        if (!input || !chat) return;
        input.addEventListener("touchstart", () => { lastTouch = Date.now(); }, { passive: true });
        const blurIfUnintended = () => {
            const recent = (Date.now() - lastTouch) < 600;
            if (!recent && document.activeElement === input) input.blur();
        };
        const mo = new MutationObserver(() => blurIfUnintended());
        mo.observe(chat, { childList: true, subtree: true, characterData: true });
        document.addEventListener("focusin", (e) => {
            if (e.target === input) blurIfUnintended();
        }, true);
    };
    arm();
    setTimeout(arm, 500);
    setTimeout(arm, 1500);
}
"""

# ----------------------------
# Server lifecycle globals
# ----------------------------
_server_lock = threading.Lock()
_server_proc: Optional[subprocess.Popen] = None
LLAMA_SERVER: Optional[Path] = None
SERVER_MODEL_ID: Optional[str] = None

# Reuse TCP connections (lower overhead)
SESSION = requests.Session()
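# Optional sketch (not part of the original setup): requests lets you bound
# the connection pool via HTTPAdapter; with a single llama-server worker a
# tiny pool is plenty. The sizes here are assumptions:
# SESSION.mount("http://", requests.adapters.HTTPAdapter(pool_connections=2, pool_maxsize=2))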


def _make_executable(path: Path) -> None:
    # Add the owner-execute bit; the downloading user is also the one running it.
    st = os.stat(path)
    os.chmod(path, st.st_mode | stat.S_IEXEC)


def _safe_extract_tar(tf: tarfile.TarFile, out_dir: Path) -> None:
    try:
        # The "data" filter (Python 3.12+) rejects path traversal and other
        # unsafe tar members during extraction.
        tf.extractall(path=out_dir, filter="data")
    except TypeError:
        # Older Pythons without the filter argument.
        tf.extractall(path=out_dir)


def _download_llama_cpp_release() -> Path:
    # Reuse a previously extracted binary if one is already on disk.
    existing = list(LLAMA_DIR.rglob("llama-server"))
    for p in existing:
        if p.is_file():
            _make_executable(p)
            return p
    # Locate the Ubuntu x64 tarball in the latest GitHub release.
    asset_url = None
    try:
        rel = SESSION.get(
            "https://api.github.com/repos/ggml-org/llama.cpp/releases/latest",
            timeout=20,
        ).json()
        for a in rel.get("assets", []):
            name = a.get("name", "")
            if "bin-ubuntu-x64" in name and name.endswith(".tar.gz"):
                asset_url = a.get("browser_download_url")
                break
    except Exception:
        asset_url = None
    if not asset_url:
        asset_url = "https://github.com/ggml-org/llama.cpp/releases/latest/download/llama-bin-ubuntu-x64.tar.gz"
    tar_path = LLAMA_DIR / "llama-bin-ubuntu-x64.tar.gz"
    print(f"[app] Downloading llama.cpp release: {asset_url}", flush=True)
    with SESSION.get(asset_url, stream=True, timeout=180) as r:
        r.raise_for_status()
        with open(tar_path, "wb") as f:
            for chunk in r.iter_content(chunk_size=1024 * 1024):
                if chunk:
                    f.write(chunk)
    print("[app] Extracting llama.cpp tarball...", flush=True)
    with tarfile.open(tar_path, "r:gz") as tf:
        _safe_extract_tar(tf, LLAMA_DIR)
    candidates = list(LLAMA_DIR.rglob("llama-server"))
    if not candidates:
        raise RuntimeError("Downloaded llama.cpp release but could not find llama-server binary.")
    server_bin = candidates[0]
    _make_executable(server_bin)
    print(f"[app] llama-server path: {server_bin}", flush=True)
    return server_bin
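

# Optional sanity-check sketch (assumption: llama.cpp builds accept --version;
# this call is not part of the original flow, and "server_bin_path" is a
# placeholder for the path returned above):
# subprocess.run([server_bin_path, "--version"], check=False, timeout=10)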


def _wait_for_health(timeout_s: int = 180) -> None:
    deadline = time.time() + timeout_s
    last_err = None
    while time.time() < deadline:
        try:
            r = SESSION.get(f"{BASE_URL}/health", timeout=2)
            if r.status_code == 200:
                return
            last_err = f"health status {r.status_code}"
        except Exception as e:
            last_err = str(e)
        time.sleep(0.35)
    raise RuntimeError(f"llama-server not healthy in time. Last error: {last_err}")


def _warmup() -> None:
    # A tiny request to force lazy initialization paths and warm caches.
    try:
        payload = {
            "model": SERVER_MODEL_ID or HF_REPO,
            "messages": [{"role": "user", "content": "hi"}],
            "temperature": 0.0,
            "top_p": 1.0,
            "max_tokens": 4,
            "stream": False,
        }
        SESSION.post(f"{BASE_URL}/v1/chat/completions", json=payload, timeout=60)
    except Exception:
        pass


def ensure_server_started() -> None:
    global _server_proc, LLAMA_SERVER, SERVER_MODEL_ID
    with _server_lock:
        if _server_proc and _server_proc.poll() is None:
            return
        LLAMA_SERVER = _download_llama_cpp_release()
        HF_HOME.mkdir(parents=True, exist_ok=True)
        cmd = [
            str(LLAMA_SERVER),
            "--host", HOST,
            "--port", str(PORT),
            "--no-webui",
            "--jinja",
            "--ctx-size", str(CTX_SIZE),
            "--threads", str(N_THREADS),
            "--threads-batch", str(N_THREADS_BATCH),
            "--threads-http", str(THREADS_HTTP),
            "--parallel", str(PARALLEL),
            "--batch-size", str(BATCH_SIZE),
            "--ubatch-size", str(UBATCH_SIZE),
            "-hf", HF_REPO,
            "--hf-file", HF_FILE,
        ]
        # Latency-oriented defaults
        if USE_MLOCK:
            cmd.append("--mlock")
        # Continuous batching is usually worse for single-user latency on 2 vCPU.
        if USE_CONT_BATCHING:
            cmd.append("--cont-batching")
        print("[app] Starting llama-server with:", flush=True)
        print("  " + " ".join(cmd), flush=True)
        env = os.environ.copy()
        env["PYTHONIOENCODING"] = "utf-8"
        env["LANG"] = env.get("LANG", "C.UTF-8")
        env["LC_ALL"] = env.get("LC_ALL", "C.UTF-8")
        # Inherit logs to the container; avoids PIPE deadlock
        _server_proc = subprocess.Popen(cmd, stdout=None, stderr=None, env=env)
        _wait_for_health(timeout_s=180)
        try:
            j = SESSION.get(f"{BASE_URL}/v1/models", timeout=5).json()
            SERVER_MODEL_ID = j["data"][0]["id"]
        except Exception:
            SERVER_MODEL_ID = HF_REPO
        print(f"[app] llama-server healthy. model_id={SERVER_MODEL_ID}", flush=True)
        _warmup()
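

# Optional cleanup sketch (an addition, not in the original app): terminate the
# child server when the Python process exits, so restarts in the same
# container do not leave an orphaned llama-server behind.
# import atexit
# atexit.register(lambda: _server_proc and _server_proc.terminate())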


def stream_chat(messages, temperature: float, top_p: float, max_tokens: int):
    payload = {
        "model": SERVER_MODEL_ID or HF_REPO,
        "messages": messages,
        "temperature": float(temperature),
        "top_p": float(top_p),
        "max_tokens": int(max_tokens),
        "stream": True,
    }
    headers = {
        "Accept": "text/event-stream",
        "Content-Type": "application/json; charset=utf-8",
        "Connection": "keep-alive",
    }
    last_err: Optional[Exception] = None
    for _attempt in range(10):
        try:
            with SESSION.post(
                f"{BASE_URL}/v1/chat/completions",
                json=payload,
                stream=True,
                timeout=600,
                headers=headers,
            ) as r:
                if r.status_code != 200:
                    body = r.text[:2000]
                    raise requests.exceptions.HTTPError(
                        f"{r.status_code} from llama-server: {body}",
                        response=r,
                    )
                # Efficient-ish SSE parsing
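                # Each streamed chunk is an OpenAI-style SSE line; the shape
                # parsed below looks like this (content value illustrative):
                #   data: {"choices":[{"delta":{"content":"Hel"}}]}
                #   data: [DONE]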
                for raw in r.iter_lines(decode_unicode=False, chunk_size=8192):
                    if not raw:
                        continue
                    line = raw.decode("utf-8", errors="replace")
                    if not line.startswith("data: "):
                        continue
                    data = line[6:].strip()
                    if data == "[DONE]":
                        return
                    try:
                        obj = json.loads(data)
                    except Exception:
                        continue
                    # Guard against chunks without "choices" (e.g. error payloads).
                    choices = obj.get("choices") or [{}]
                    delta = choices[0].get("delta") or {}
                    tok = delta.get("content")
                    if tok:
                        yield tok
            return
        except (requests.exceptions.ConnectionError, requests.exceptions.Timeout) as e:
            last_err = e
            time.sleep(0.35)
            try:
                ensure_server_started()
            except Exception:
                pass
    if last_err:
        raise last_err
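

# Usage sketch for the generator above (illustrative call, not wired into the
# app): each yielded item is a text delta.
#   for tok in stream_chat([{"role": "user", "content": "hi"}], 0.7, 0.9, 64):
#       print(tok, end="", flush=True)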


# ----------------------------
# Chat handlers
# messages format: list of {"role","content"}
# ----------------------------
ChatHistory = List[Dict[str, str]]


def _truncate(s: str, n: int) -> str:
    s = s if isinstance(s, str) else str(s)
    return s if len(s) <= n else s[:n]


def on_user_submit(user_text: str, history: ChatHistory):
    user_text = (user_text or "").strip()
    if not user_text:
        return "", history
    user_text = _truncate(user_text, 2000)
    history = history or []
    history = history + [{"role": "user", "content": user_text}, {"role": "assistant", "content": ""}]
    return "", history


def on_bot_respond(history: ChatHistory, system_message: str, max_tokens: int, temperature: float, top_p: float):
    ensure_server_started()
    history = history or []
    if len(history) < 2 or history[-1].get("role") != "assistant":
        yield history
        return
    user_msg = history[-2].get("content", "")
    msgs = []
    sys = (system_message or "").strip()
    if sys:
        msgs.append({"role": "system", "content": sys})
    # Send ONLY the latest user prompt, with no chat history (fastest on CPU).
    msgs.append({"role": "user", "content": user_msg})
    out = ""
    for tok in stream_chat(
        msgs,
        temperature=float(temperature),
        top_p=float(top_p),
        max_tokens=int(max_tokens),
    ):
        out += tok
        history[-1]["content"] = out
        yield history


# ----------------------------
# UI
# ----------------------------
with gr.Blocks(title="BartlebyGPT", fill_height=True, css=CUSTOM_CSS, js=FOCUS_GUARD_JS) as demo:
    with gr.Column(elem_id="app_root"):
        with gr.Column(elem_id="chat_wrap"):
            chatbot = gr.Chatbot(
                value=[],
                type="messages",  # handlers use {"role", "content"} dicts
                show_label=False,
                autoscroll=True,
                height="100%",
                elem_id="chatbot",
            )
        with gr.Row(elem_id="input_row"):
            msg = gr.Textbox(
                placeholder="What do you want?",
                show_label=False,
                lines=1,
                max_lines=1,
                autofocus=False,
                elem_id="msg_box",
                scale=10,
            )
            send = gr.Button("Send", variant="primary", elem_id="send_btn", scale=1)
        with gr.Accordion("Params", open=False, elem_id="params_bar"):
            system_box = gr.Textbox(value=SYSTEM_PROMPT_DEFAULT, label="System message", lines=2)
            with gr.Row():
                max_tokens = gr.Slider(1, 512, value=256, step=1, label="Max new tokens")
                temperature = gr.Slider(0.0, 2.0, value=0.7, step=0.05, label="Temperature")
                top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")

    # Wire up (queue=False for the submit handler to keep the UI snappy)
    msg.submit(on_user_submit, [msg, chatbot], [msg, chatbot], queue=False).then(
        on_bot_respond,
        [chatbot, system_box, max_tokens, temperature, top_p],
        [chatbot],
    )
    send.click(on_user_submit, [msg, chatbot], [msg, chatbot], queue=False).then(
        on_bot_respond,
        [chatbot, system_box, max_tokens, temperature, top_p],
        [chatbot],
    )

# On 2 vCPU, concurrency > 1 usually makes everyone slower.
demo.queue(default_concurrency_limit=1, max_size=128)

# Warm-start the server (best effort)
try:
    ensure_server_started()
except Exception as e:
    print("[app] llama-server eager start failed:", repr(e), flush=True)

if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=int(os.environ.get("PORT", "7860")),
    )