# bartlebyGPT / app.py
import os
import json
import time
import tarfile
import stat
import threading
import subprocess
from pathlib import Path
from typing import List, Dict, Optional
import requests
import gradio as gr
# ----------------------------
# UTF-8 everywhere
# ----------------------------
os.environ.setdefault("PYTHONIOENCODING", "utf-8")
os.environ.setdefault("LANG", "C.UTF-8")
os.environ.setdefault("LC_ALL", "C.UTF-8")
# ----------------------------
# Model on HF (GGUF)
# ----------------------------
HF_REPO = os.environ.get("HF_REPO", "staeiou/bartleby-qwen3-0.6b")
HF_FILE = os.environ.get("HF_FILE", "bartleby-qwen3-0.6b.Q4_K_M.gguf")
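# llama-server fetches this GGUF directly from the Hub via its -hf/--hf-file
# flags (see ensure_server_started below), so no separate download step is
# needed here.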
# ----------------------------
# llama.cpp server settings
# ----------------------------
HOST = os.environ.get("LLAMA_HOST", "127.0.0.1")
PORT = int(os.environ.get("LLAMA_PORT", "8080"))
BASE_URL = f"http://{HOST}:{PORT}"
# Context size: keep it small on 2 vCPU. 4096 can work but raises KV-cache
# cost; 2048 often feels snappier on CPU, and the default here is a very
# conservative 512.
CTX_SIZE = int(os.environ.get("LLAMA_CTX", "512"))
# ---- CPU-tuned defaults for the HF Spaces free tier (2 vCPU) ----
# Key choices:
# - parallel=1 (best latency and tokens/sec per user on a tiny CPU)
# - threads=2 (use both vCPUs)
# - http threads=1 (avoid contention)
# - smaller batches (large batches can hurt latency on CPU)
N_THREADS = int(os.environ.get("LLAMA_THREADS", "2"))
N_THREADS_BATCH = int(os.environ.get("LLAMA_THREADS_BATCH", str(N_THREADS)))
PARALLEL = int(os.environ.get("LLAMA_PARALLEL", "1"))
THREADS_HTTP = int(os.environ.get("LLAMA_THREADS_HTTP", "1"))
BATCH_SIZE = int(os.environ.get("LLAMA_BATCH", "128"))
UBATCH_SIZE = int(os.environ.get("LLAMA_UBATCH", "64"))
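# In llama.cpp, --batch-size is the logical batch (tokens queued per decode
# step) while --ubatch-size is the physical micro-batch actually pushed
# through the model; smaller values trade prompt-processing throughput for
# lower latency, which suits a 2 vCPU box.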
# Optional knobs
# Pin model in RAM if allowed (may fail under strict memory limits; safe to try)
USE_MLOCK = os.environ.get("LLAMA_MLOCK", "1") == "1"
# Disable continuous batching for best single-user latency on 2 vCPU
USE_CONT_BATCHING = os.environ.get("LLAMA_CONT_BATCHING", "0") == "1"
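# (--cont-batching interleaves token generation across request slots; with
# parallel=1 there is only one slot, so it is off by default here.)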
SYSTEM_PROMPT_DEFAULT = os.environ.get("SYSTEM_PROMPT", "")
# Prefer /data if present (persistent), else /tmp
DATA_DIR = Path("/data") if Path("/data").exists() else Path("/tmp")
HF_HOME = Path(os.environ.get("HF_HOME", str(DATA_DIR / "hf_home")))
os.environ["HF_HOME"] = str(HF_HOME)
LLAMA_DIR = Path(os.environ.get("LLAMA_BIN_DIR", str(DATA_DIR / "llama_cpp_bin")))
LLAMA_DIR.mkdir(parents=True, exist_ok=True)
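# Note: on HF Spaces, /data only exists when persistent storage is enabled;
# /tmp works everywhere but is wiped on restart, so the binary and model are
# re-downloaded after every cold start.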
# ----------------------------
# CSS
# ----------------------------
CUSTOM_CSS = r"""
footer { visibility: hidden; }
html, body {
    height: 100%;
    margin: 0;
    overflow: hidden !important;
}
.gradio-container {
    height: 100dvh !important;
    max-height: 100dvh !important;
    overflow: hidden !important;
}
#app_root {
    position: fixed;
    inset: 0;
    display: flex;
    flex-direction: column;
    overflow: hidden !important;
}
#chat_wrap {
    flex: 1 1 auto;
    min-height: 0;
    overflow: hidden !important;
}
#chat_wrap .gradio-chatbot,
#chat_wrap .gr-chatbot,
#chat_wrap [data-testid="chatbot"] {
    height: 100% !important;
    max-height: none !important;
}
#input_row {
    flex: 0 0 auto;
    padding: 6px 0 6px 0;
}
#msg_box textarea {
    min-height: 2.6em !important;
    max-height: 2.6em !important;
    height: 2.6em !important;
    line-height: 1.25 !important;
    overflow: hidden !important;
    resize: none !important;
}
#send_btn button {
    min-height: 2.6em !important;
    height: 2.6em !important;
    padding-top: 0.2em !important;
    padding-bottom: 0.2em !important;
}
#params_bar {
    flex: 0 0 auto;
}
#params_bar .gr-accordion-content,
#params_bar .accordion-content {
    max-height: 45dvh;
    overflow: auto;
}
@media (max-width: 768px) {
    .gradio-container { padding: 8px !important; }
}
@media (min-width: 769px) {
    .gradio-container { padding: 12px !important; }
}
"""
FOCUS_GUARD_JS = r"""
() => {
    const isMobile = /Mobi|Android|iPhone|iPad|iPod/i.test(navigator.userAgent);
    if (!isMobile) return;
    const inputSel = "#msg_box textarea";
    const chatSel = "#chat_wrap";
    let lastTouch = 0;
    const arm = () => {
        const input = document.querySelector(inputSel);
        const chat = document.querySelector(chatSel);
        if (!input || !chat) return;
        input.addEventListener("touchstart", () => { lastTouch = Date.now(); }, { passive: true });
        const blurIfUnintended = () => {
            const recent = (Date.now() - lastTouch) < 600;
            if (!recent && document.activeElement === input) input.blur();
        };
        const mo = new MutationObserver(() => blurIfUnintended());
        mo.observe(chat, { childList: true, subtree: true, characterData: true });
        document.addEventListener("focusin", (e) => {
            if (e.target === input) blurIfUnintended();
        }, true);
    };
    arm();
    setTimeout(arm, 500);
    setTimeout(arm, 1500);
}
"""
# ----------------------------
# Server lifecycle globals
# ----------------------------
_server_lock = threading.Lock()
_server_proc: subprocess.Popen | None = None
LLAMA_SERVER: Path | None = None
SERVER_MODEL_ID: str | None = None
# Reuse TCP connections (lower overhead)
SESSION = requests.Session()
def _make_executable(path: Path) -> None:
    st = os.stat(path)
    os.chmod(path, st.st_mode | stat.S_IEXEC)

def _safe_extract_tar(tf: tarfile.TarFile, out_dir: Path) -> None:
    # filter="data" (PEP 706, py3.12+) rejects unsafe members such as absolute
    # paths or ".." traversal; older Pythons fall back to plain extractall.
    try:
        tf.extractall(path=out_dir, filter="data")
    except TypeError:
        tf.extractall(path=out_dir)
def _download_llama_cpp_release() -> Path:
    # Reuse a previously extracted binary if it is already on disk.
    existing = list(LLAMA_DIR.rglob("llama-server"))
    for p in existing:
        if p.is_file():
            _make_executable(p)
            return p
    asset_url = None
    try:
        rel = SESSION.get(
            "https://api.github.com/repos/ggml-org/llama.cpp/releases/latest",
            timeout=20,
        ).json()
        for a in rel.get("assets", []):
            name = a.get("name", "")
            if "bin-ubuntu-x64" in name and name.endswith(".tar.gz"):
                asset_url = a.get("browser_download_url")
                break
    except Exception:
        asset_url = None
    if not asset_url:
        # NOTE: release assets are normally tagged (llama-<build>-bin-ubuntu-x64.tar.gz),
        # so this static fallback may 404; it is only a last resort when the
        # GitHub API lookup fails.
        asset_url = "https://github.com/ggml-org/llama.cpp/releases/latest/download/llama-bin-ubuntu-x64.tar.gz"
    tar_path = LLAMA_DIR / "llama-bin-ubuntu-x64.tar.gz"
    print(f"[app] Downloading llama.cpp release: {asset_url}", flush=True)
    with SESSION.get(asset_url, stream=True, timeout=180) as r:
        r.raise_for_status()
        with open(tar_path, "wb") as f:
            for chunk in r.iter_content(chunk_size=1024 * 1024):
                if chunk:
                    f.write(chunk)
    print("[app] Extracting llama.cpp tarball...", flush=True)
    with tarfile.open(tar_path, "r:gz") as tf:
        _safe_extract_tar(tf, LLAMA_DIR)
    candidates = list(LLAMA_DIR.rglob("llama-server"))
    if not candidates:
        raise RuntimeError("Downloaded llama.cpp release but could not find llama-server binary.")
    server_bin = candidates[0]
    _make_executable(server_bin)
    print(f"[app] llama-server path: {server_bin}", flush=True)
    return server_bin
def _wait_for_health(timeout_s: int = 180) -> None:
    deadline = time.time() + timeout_s
    last_err = None
    while time.time() < deadline:
        try:
            r = SESSION.get(f"{BASE_URL}/health", timeout=2)
            if r.status_code == 200:
                return
            last_err = f"health status {r.status_code}"
        except Exception as e:
            last_err = str(e)
        time.sleep(0.35)
    raise RuntimeError(f"llama-server not healthy in time. Last error: {last_err}")
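# (llama-server's /health endpoint returns 503 while the model is still
# loading and 200 once it is ready to serve, hence the polling above.)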
def _warmup() -> None:
    # A tiny request to force lazy init/JIT-ish paths and caches.
    try:
        payload = {
            "model": SERVER_MODEL_ID or HF_REPO,
            "messages": [{"role": "user", "content": "hi"}],
            "temperature": 0.0,
            "top_p": 1.0,
            "max_tokens": 4,
            "stream": False,
        }
        SESSION.post(f"{BASE_URL}/v1/chat/completions", json=payload, timeout=60)
    except Exception:
        pass
def ensure_server_started() -> None:
    global _server_proc, LLAMA_SERVER, SERVER_MODEL_ID
    with _server_lock:
        if _server_proc and _server_proc.poll() is None:
            return
        LLAMA_SERVER = _download_llama_cpp_release()
        HF_HOME.mkdir(parents=True, exist_ok=True)
        cmd = [
            str(LLAMA_SERVER),
            "--host", HOST,
            "--port", str(PORT),
            "--no-webui",
            "--jinja",
            "--ctx-size", str(CTX_SIZE),
            "--threads", str(N_THREADS),
            "--threads-batch", str(N_THREADS_BATCH),
            "--threads-http", str(THREADS_HTTP),
            "--parallel", str(PARALLEL),
            "--batch-size", str(BATCH_SIZE),
            "--ubatch-size", str(UBATCH_SIZE),
            "-hf", HF_REPO,
            "--hf-file", HF_FILE,
        ]
        # Latency-oriented defaults
        if USE_MLOCK:
            cmd.append("--mlock")
        # Continuous batching is usually worse for single-user latency on 2 vCPU.
        if USE_CONT_BATCHING:
            cmd.append("--cont-batching")
        print("[app] Starting llama-server with:", flush=True)
        print(" " + " ".join(cmd), flush=True)
        env = os.environ.copy()
        env["PYTHONIOENCODING"] = "utf-8"
        env["LANG"] = env.get("LANG", "C.UTF-8")
        env["LC_ALL"] = env.get("LC_ALL", "C.UTF-8")
        # Inherit stdout/stderr from the container; piping without a reader
        # can deadlock once the OS buffer fills.
        _server_proc = subprocess.Popen(cmd, stdout=None, stderr=None, env=env)
        _wait_for_health(timeout_s=180)
        try:
            j = SESSION.get(f"{BASE_URL}/v1/models", timeout=5).json()
            SERVER_MODEL_ID = j["data"][0]["id"]
        except Exception:
            SERVER_MODEL_ID = HF_REPO
        print(f"[app] llama-server healthy. model_id={SERVER_MODEL_ID}", flush=True)
        _warmup()
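# stream_chat speaks the OpenAI-compatible SSE protocol served by llama-server
# at /v1/chat/completions: each event is a "data: {json}" line whose
# choices[0].delta.content carries the next token(s), and the stream ends with
# "data: [DONE]". Connection errors trigger up to 10 retries, re-ensuring the
# server between attempts.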
def stream_chat(messages, temperature: float, top_p: float, max_tokens: int):
    payload = {
        "model": SERVER_MODEL_ID or HF_REPO,
        "messages": messages,
        "temperature": float(temperature),
        "top_p": float(top_p),
        "max_tokens": int(max_tokens),
        "stream": True,
    }
    headers = {
        "Accept": "text/event-stream",
        "Content-Type": "application/json; charset=utf-8",
        "Connection": "keep-alive",
    }
    last_err: Optional[Exception] = None
    for _attempt in range(10):
        try:
            with SESSION.post(
                f"{BASE_URL}/v1/chat/completions",
                json=payload,
                stream=True,
                timeout=600,
                headers=headers,
            ) as r:
                if r.status_code != 200:
                    body = r.text[:2000]
                    raise requests.exceptions.HTTPError(
                        f"{r.status_code} from llama-server: {body}",
                        response=r,
                    )
                # Efficient-ish SSE parsing
                for raw in r.iter_lines(decode_unicode=False, chunk_size=8192):
                    if not raw:
                        continue
                    line = raw.decode("utf-8", errors="replace")
                    if not line.startswith("data: "):
                        continue
                    data = line[6:].strip()
                    if data == "[DONE]":
                        return
                    try:
                        obj = json.loads(data)
                    except Exception:
                        continue
                    delta = obj["choices"][0].get("delta") or {}
                    tok = delta.get("content")
                    if tok:
                        yield tok
            return
        except (requests.exceptions.ConnectionError, requests.exceptions.Timeout) as e:
            last_err = e
            time.sleep(0.35)
            try:
                ensure_server_started()
            except Exception:
                pass
    if last_err:
        raise last_err
# ----------------------------
# Chat handlers
# messages format: list of {"role","content"}
# ----------------------------
ChatHistory = List[Dict[str, str]]
def _truncate(s: str, n: int) -> str:
    s = s if isinstance(s, str) else str(s)
    return s if len(s) <= n else s[:n]

def on_user_submit(user_text: str, history: ChatHistory):
    user_text = (user_text or "").strip()
    if not user_text:
        return "", history
    user_text = _truncate(user_text, 2000)
    history = history or []
    history = history + [{"role": "user", "content": user_text}, {"role": "assistant", "content": ""}]
    return "", history
def on_bot_respond(history: ChatHistory, system_message: str, max_tokens: int, temperature: float, top_p: float):
    ensure_server_started()
    history = history or []
    if len(history) < 2 or history[-1].get("role") != "assistant":
        yield history
        return
    user_msg = history[-2].get("content", "")
    msgs = []
    sys = (system_message or "").strip()
    if sys:
        msgs.append({"role": "system", "content": sys})
    # Send ONLY the latest user prompt, with no history: fastest option on CPU.
    msgs.append({"role": "user", "content": user_msg})
    out = ""
    for tok in stream_chat(
        msgs,
        temperature=float(temperature),
        top_p=float(top_p),
        max_tokens=int(max_tokens),
    ):
        out += tok
        history[-1]["content"] = out
        yield history
# ----------------------------
# UI
# ----------------------------
# css and js are gr.Blocks options (launch() does not accept them).
with gr.Blocks(
    title="BartlebyGPT",
    fill_height=True,
    css=CUSTOM_CSS,
    js=FOCUS_GUARD_JS,
) as demo:
    with gr.Column(elem_id="app_root"):
        with gr.Column(elem_id="chat_wrap"):
            chatbot = gr.Chatbot(
                value=[],
                type="messages",  # handlers use {"role", "content"} dicts
                show_label=False,
                autoscroll=True,
                height="100%",
                elem_id="chatbot",
            )
        with gr.Row(elem_id="input_row"):
            msg = gr.Textbox(
                placeholder="What do you want?",
                show_label=False,
                lines=1,
                max_lines=1,
                autofocus=False,
                elem_id="msg_box",
                scale=10,
            )
            send = gr.Button("Send", variant="primary", elem_id="send_btn", scale=1)
        with gr.Accordion("Params", open=False, elem_id="params_bar"):
            system_box = gr.Textbox(value=SYSTEM_PROMPT_DEFAULT, label="System message", lines=2)
            with gr.Row():
                max_tokens = gr.Slider(1, 512, value=256, step=1, label="Max new tokens")
                temperature = gr.Slider(0.0, 2.0, value=0.7, step=0.05, label="Temperature")
                top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")

    # Wire up events (queue=False on the submit step keeps the UI snappy).
    msg.submit(on_user_submit, [msg, chatbot], [msg, chatbot], queue=False).then(
        on_bot_respond,
        [chatbot, system_box, max_tokens, temperature, top_p],
        [chatbot],
    )
    send.click(on_user_submit, [msg, chatbot], [msg, chatbot], queue=False).then(
        on_bot_respond,
        [chatbot, system_box, max_tokens, temperature, top_p],
        [chatbot],
    )

# On 2 vCPU, concurrency > 1 usually makes everyone slower.
demo.queue(default_concurrency_limit=1, max_size=128)
# Warm start server (best effort)
try:
    ensure_server_started()
except Exception as e:
    print("[app] llama-server eager start failed:", repr(e), flush=True)

if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=int(os.environ.get("PORT", "7860")),
    )