# AIMan / app.py — Hugging Face Space by PioTio
# Last commit: "Add tokenizer normalization retry in load_model" (ef417e5, verified)
import threading
import time
import os
from typing import List, Tuple, Optional
import gradio as gr
import torch
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
TextIteratorStreamer,
)
# Optional imports (best-effort features): each flag records whether the
# corresponding optional feature can be used in this environment.
HAS_BNB = False
HAS_PEFT = False
try:
    from transformers import BitsAndBytesConfig
except Exception:
    pass
else:
    HAS_BNB = True
try:
    from peft import PeftModel
except Exception:
    pass
else:
    HAS_PEFT = True
# ---------------------------------------------------------------------------
# Config / defaults
# ---------------------------------------------------------------------------
# Hub repo loaded in the background at startup (unless SKIP_AUTOLOAD=1) and
# pre-filled in the "Model repo" textbox.
DEFAULT_MODEL = "PioTio/nanbeige-4.1-aiman-merged"
CPU_DEMO_MODEL = "distilgpt2" # fast, small CPU-friendly fallback for demos
# Initial value of the "System prompt" textbox in the UI.
DEFAULT_SYSTEM_PROMPT = "You are a helpful, honest assistant. Answer succinctly unless asked otherwise."
# globals populated by load_model()
MODEL = None  # the loaded causal LM, or None when nothing is loaded
TOKENIZER = None  # tokenizer for MODEL, or None
MODEL_NAME = None  # repo id of the most recent load attempt
DEVICE = "cpu"  # "cuda" or "cpu", chosen by load_model()
# Held by load_model() for the whole load so two loads never run concurrently.
MODEL_LOCK = threading.Lock()
# flag: whether a model load is currently in progress (prevents requests)
MODEL_LOADING = False
# flag: whether the loaded tokenizer exposes a chat template helper
USE_CHAT_TEMPLATE = False
# ----------------------------- Utilities ------------------------------------
def _get_tok_vocab_size(tok: AutoTokenizer) -> Optional[int]:
try:
return int(getattr(tok, "vocab_size"))
except Exception:
try:
return int(tok.get_vocab_size())
except Exception:
return len(tok.get_vocab()) if hasattr(tok, "get_vocab") else None
def _diagnose_and_fix_tokenizer_model(tok: "AutoTokenizer", mdl: "AutoModelForCausalLM"):
    """Fix common tokenizer<->model mismatches (SentencePiece piece-id edge-cases).

    This mirrors the notebook fixes so Spaces will not hit `piece id out of range`.

    Steps (all best-effort: failures are printed and skipped, never raised):
      1. Ensure the tokenizer has a pad token. Use the EOS token, or a new
         ``[PAD]`` token when there is no EOS. Mirror its id into
         ``tok.pad_token_id`` and ``mdl.config.pad_token_id``.
      2. Compute how many ids the embedding matrix must cover: the maximum of
         - the base vocab size;
         - ``len(tok)`` (which includes added tokens);
         - the current embedding rows;
         - every special/pad id + 1.
      3. Grow ``tok.vocab_size`` where the backend allows it. Resize the model's
         input embeddings when they are smaller than that requirement.

    The pad token is handled *before* sizing so that a newly added pad token is
    covered by the resize. Otherwise its id would index past the embeddings.
    """
    # 1) ensure a pad token exists and ids/config align
    if getattr(tok, "pad_token", None) is None:
        # `or` also covers tokenizers whose eos_token attribute exists but is None
        tok.pad_token = getattr(tok, "eos_token", None) or "[PAD]"
        # Be defensive: different tokenizer backends expect different arg types
        try:
            tok.add_special_tokens({"pad_token": tok.pad_token})
        except TypeError:
            # try list form or add_tokens fallback
            try:
                tok.add_special_tokens([tok.pad_token])
            except Exception:
                try:
                    tok.add_tokens([tok.pad_token])
                except Exception as e:
                    print(f"_diagnose_and_fix_tokenizer_model: could not register pad token: {e}")
        except Exception as e:
            print(f"_diagnose_and_fix_tokenizer_model: could not register pad token: {e}")

    pad_id = None
    try:
        pad_id = tok.convert_tokens_to_ids(tok.pad_token)
        tok.pad_token_id = pad_id
        mdl.config.pad_token_id = pad_id
    except Exception as e:
        print(f"_diagnose_and_fix_tokenizer_model: could not align pad_token_id: {e}")

    # 2) number of ids the embedding matrix must be able to index
    tok_vs = _get_tok_vocab_size(tok) or 0
    try:
        tok_len = len(tok)  # unlike vocab_size, this counts added tokens too
    except Exception:
        tok_len = 0
    try:
        emb_rows = mdl.get_input_embeddings().weight.shape[0]
    except Exception:
        emb_rows = 0
    special_ids = [i for i in (getattr(tok, "all_special_ids", None) or []) if isinstance(i, int)]
    if isinstance(pad_id, int):
        special_ids.append(pad_id)
    required = max(tok_vs, tok_len, emb_rows, max(special_ids, default=0) + 1)

    # 3a) update tokenizer.vocab_size if it's smaller than required
    #     (read-only property on most HF tokenizers, hence best-effort)
    if tok_vs < required:
        try:
            tok.vocab_size = required
        except Exception:
            pass
    # 3b) resize model embeddings if model is smaller
    if emb_rows < required:
        try:
            mdl.resize_token_embeddings(required)
            mdl.config.vocab_size = required
        except Exception as e:
            print(f"_diagnose_and_fix_tokenizer_model: could not resize embeddings to {required}: {e}")
# Helper: detect Git-LFS pointer files and fetch real tokenizer.model from the Hub
def _is_lfs_pointer_file(path: str) -> bool:
try:
with open(path, "rb") as f:
start = f.read(128)
return b"git-lfs.github.com/spec/v1" in start
except Exception:
return False
def _download_tokenizer_model_from_hub(
    hf_repo: str,
    dest_path: str,
    hf_token: Optional[str] = None,
    revision: str = "main",
) -> bool:
    """Download ``tokenizer.model`` from the HF Hub into *dest_path*.

    The body is streamed in chunks to ``dest_path + ".tmp"``. It is moved into
    place with ``os.replace`` only after the whole download succeeded, so a
    failure never leaves a truncated tokenizer at *dest_path*.

    Args:
        hf_repo: Hub repo id, e.g. ``"owner/name"``.
        dest_path: local file to write.
        hf_token: optional token, sent as a Bearer ``Authorization`` header
            (private/gated repos).
        revision: branch, tag or commit to download from (default ``"main"``).

    Returns:
        True on success. False on any error; the error is printed and the
        temporary file is removed.
    """
    import shutil
    import urllib.parse
    import urllib.request

    tmp_path = dest_path + ".tmp"
    url = (
        f"https://huggingface.co/{urllib.parse.quote(hf_repo)}"
        f"/resolve/{urllib.parse.quote(revision, safe='')}/tokenizer.model"
    )
    try:
        req = urllib.request.Request(url, headers={"User-Agent": "spaces-nanbeige-chat/1.0"})
        if hf_token:
            req.add_header("Authorization", f"Bearer {hf_token}")
        # stream to disk instead of buffering the whole file in memory
        with urllib.request.urlopen(req, timeout=30) as resp, open(tmp_path, "wb") as out:
            shutil.copyfileobj(resp, out)
        os.replace(tmp_path, dest_path)
        return True
    except Exception as e:
        print("_download_tokenizer_model_from_hub failed:", e)
        try:
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
        except OSError:
            pass
        return False
def _ensure_local_tokenizer_model(repo_path: str, hf_token: Optional[str] = None) -> bool:
"""If tokenizer.model in repo_path is a Git-LFS pointer, try to download the real file from the Hub.
Tries to infer a Hub repo id from the local git remote; falls back to `PioTio/<dirname>` for Nanbeige folders.
"""
tm = os.path.join(repo_path, "tokenizer.model")
if not os.path.exists(tm):
return False
if not _is_lfs_pointer_file(tm):
return True
# try to get repo id from git remote origin
repo_id = None
try:
import subprocess
out = subprocess.check_output(["git", "-C", repo_path, "config", "--get", "remote.origin.url"], text=True).strip()
if out and "huggingface.co" in out:
# parse https://huggingface.co/owner/repo(.git)
parts = out.rstrip(".git").split("/")
repo_id = f"{parts[-2]}/{parts[-1]}"
except Exception:
repo_id = None
# fallback: guess owner for common Nanbeige folder names
if repo_id is None:
guessed = os.path.basename(repo_path)
if guessed.lower().startswith("nanbeige") or "nanbeige" in guessed.lower():
repo_id = f"PioTio/{guessed}"
if repo_id:
return _download_tokenizer_model_from_hub(repo_id, tm, hf_token=hf_token)
return False
# Helper: upload tokenizer files (from a local tokenizer dir) back to a Hub repo
def _upload_tokenizer_files_to_hub(repo_id: str, local_tokenizer_dir: str, hf_token: Optional[str] = None) -> bool:
"""Upload tokenizer files (tokenizer.model, tokenizer_config.json, tokenizer.json, special_tokens_map.json)
Returns True if at least one file was uploaded successfully.
"""
try:
from huggingface_hub import HfApi
api = HfApi()
candidates = [
"tokenizer.model",
"tokenizer_config.json",
"tokenizer.json",
"special_tokens_map.json",
"chat_template.jinja",
]
uploaded = 0
for fn in candidates:
p = os.path.join(local_tokenizer_dir, fn)
if not os.path.exists(p):
continue
try:
api.upload_file(
path_or_fileobj=p,
path_in_repo=fn,
repo_id=repo_id,
token=hf_token,
commit_message=f"Auto-fix tokenizer: {fn}",
)
print(f"_upload_tokenizer_files_to_hub: uploaded {fn} to {repo_id}")
uploaded += 1
except Exception as e:
print(f"_upload_tokenizer_files_to_hub: failed to upload {fn}: {e}")
return uploaded > 0
except Exception as e:
print("_upload_tokenizer_files_to_hub failed:", e)
return False
def _repair_and_upload_tokenizer(repo_id: str, hf_token: Optional[str] = None) -> bool:
"""Fetch the correct base tokenizer (Nanbeige4.1 if detected, otherwise DEFAULT_MODEL),
then upload tokenizer files to the target repo. Returns True on success.
"""
try:
base = "Nanbeige/Nanbeige4.1-3B" if "4.1" in repo_id.lower() else DEFAULT_MODEL
from transformers import AutoTokenizer
import tempfile, shutil
tmp = tempfile.mkdtemp(prefix="tokenizer_fix_")
tok = AutoTokenizer.from_pretrained(base, use_fast=False, trust_remote_code=True)
tok.save_pretrained(tmp)
ok = _upload_tokenizer_files_to_hub(repo_id, tmp, hf_token=hf_token)
shutil.rmtree(tmp)
return ok
except Exception as e:
print("_repair_and_upload_tokenizer failed:", e)
return False
def repair_tokenizer_on_hub(repo_id: str) -> str:
    """UI action: push a working base tokenizer to the Hub repo *repo_id*.

    Needs an ``HF_TOKEN`` environment variable with write access to the target
    repo. Always returns a human-readable status string and never raises.
    """
    token = os.environ.get("HF_TOKEN")
    if token:
        try:
            succeeded = _repair_and_upload_tokenizer(repo_id, hf_token=token)
        except Exception as exc:
            return f"Repair failed: {exc}"
        if succeeded:
            return "Uploaded tokenizer files to repo"
        return "Repair attempt failed (see logs)"
    return "HF_TOKEN not set — cannot upload tokenizer to Hub. Add HF_TOKEN and retry."
# ----------------------------- Model loading -------------------------------
def _safe_model_from_pretrained(repo_id, *args, **kwargs):
    """Call ``AutoModelForCausalLM.from_pretrained`` with one auth-less retry.

    Some remote-code model classes forward unexpected kwargs (e.g.
    ``use_auth_token``) into ``__init__`` and fail with a ``TypeError``. When
    that happens and an auth kwarg (``use_auth_token`` / ``token``) was actually
    passed, the call is retried once without those kwargs.

    If no auth kwarg was passed, a retry would repeat the identical, and
    potentially very expensive, load. In that case the original error is
    re-raised instead.
    """
    try:
        return AutoModelForCausalLM.from_pretrained(repo_id, *args, **kwargs)
    except TypeError as e:
        msg = str(e)
        auth_keys = [k for k in ("use_auth_token", "token") if k in kwargs]
        if auth_keys and ("use_auth_token" in msg or "unexpected keyword argument" in msg):
            # retry without auth-token kwargs (some remote `from_pretrained` may leak kwargs)
            kwargs2 = {k: v for k, v in kwargs.items() if k not in auth_keys}
            print(f"_safe_model_from_pretrained: retrying without auth-token due to: {e}")
            return AutoModelForCausalLM.from_pretrained(repo_id, *args, **kwargs2)
        raise
def _normalize_special_tokens_map_file(path: str) -> bool:
    """Rewrite dict-valued special-token entries in a tokenizer JSON file as plain strings.

    Some tokenizer classes fail with ``Input must be a List[Union[str, AddedToken]]``
    when ``additional_special_tokens`` (or bos/eos/pad/...) are stored as
    serialized AddedToken dicts. Each such dict is replaced by its ``content``
    (or ``token``) value. The file is rewritten only if something changed.

    Returns:
        True if the file was rewritten.
    """
    import json

    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    if not isinstance(data, dict):
        return False
    changed = False
    extra = data.get("additional_special_tokens")
    if isinstance(extra, list):
        normalized = []
        for item in extra:
            if isinstance(item, dict):
                normalized.append(item.get("content") or item.get("token") or str(item))
                changed = True
            else:
                normalized.append(item)
        data["additional_special_tokens"] = normalized
    for key in ("bos_token", "eos_token", "pad_token", "unk_token", "sep_token", "cls_token", "mask_token"):
        value = data.get(key)
        if isinstance(value, dict) and "content" in value:
            data[key] = value["content"]
            changed = True
    if changed:
        with open(path, "w", encoding="utf-8") as f:
            json.dump(data, f, ensure_ascii=False, indent=2)
    return changed


def _load_tokenizer_from_normalized_copy(repo_id: str, hf_token: Optional[str]):
    """Download *repo_id*'s tokenizer files to a temp dir, normalize them, load from there.

    The temporary directory is removed afterwards, even if loading fails.
    Raises whatever the Hub import or the final load raises.
    """
    import shutil
    import tempfile
    from huggingface_hub import hf_hub_download

    candidates = ["tokenizer.json", "tokenizer_config.json", "special_tokens_map.json", "tokenizer.model", "added_tokens.json"]
    with tempfile.TemporaryDirectory(prefix="tokfix_") as tmp:
        for fn in candidates:
            try:
                src = hf_hub_download(repo_id=repo_id, filename=fn, token=hf_token)
                shutil.copy(src, tmp)
            except Exception:
                # ignore missing files — AutoTokenizer is tolerant
                pass
        for fn in ("special_tokens_map.json", "tokenizer_config.json"):
            path = os.path.join(tmp, fn)
            if not os.path.exists(path):
                continue
            try:
                if _normalize_special_tokens_map_file(path):
                    print(f"Normalized {fn} in temp dir")
            except Exception as e:
                print(f"Could not normalize {fn}: {e}")
        tok = AutoTokenizer.from_pretrained(tmp, use_fast=False, trust_remote_code=True)
    print("Tokenizer reloaded from normalized temp copy")
    return tok


def _load_tokenizer_after_normalization(repo_id: str, hf_token: Optional[str]):
    """Recover from the 'Input must be a List[...]' error.

    First tries a local normalized copy. Only if that fails, and HF_TOKEN is
    available, repairs the Hub repo and reloads from it.
    """
    print("Detected tokenizer add-tokens type error; attempting in-place normalization and retry...")
    try:
        return _load_tokenizer_from_normalized_copy(repo_id, hf_token)
    except Exception as e_localnorm:
        print("In-place normalization retry failed:", e_localnorm)
    # only reached when the local normalization did not work
    if not hf_token:
        raise RuntimeError("Normalization + auto-repair could not proceed (no HF_TOKEN)")
    print("Attempting repo-side auto-repair/upload from base tokenizer...")
    _repair_and_upload_tokenizer(repo_id, hf_token=hf_token)
    tok = AutoTokenizer.from_pretrained(repo_id, use_fast=False, trust_remote_code=True, use_auth_token=hf_token)
    print("Tokenizer reloaded after repo repair")
    return tok


def _load_tokenizer_fallbacks(repo_id: str, hf_token: Optional[str]):
    """Run the generic fallback chain for a failed tokenizer load.

    Tries, in order:
      - fetch a real tokenizer.model for a local clone made without git-lfs, then reload;
      - the bundled workspace tokenizer (../models/Nanbeige4.1-3B);
      - a known base tokenizer from the Hub, plus an optional repo auto-repair.
    If all of these fail, a last-resort (possibly fast) tokenizer is loaded from
    *repo_id*. Raises if that also fails.
    """
    try:
        # If a local repo was cloned without git-lfs, tokenizer.model may be a pointer file — try auto-fetch
        if os.path.isdir(repo_id) and _ensure_local_tokenizer_model(repo_id, hf_token=hf_token):
            print(f"Found LFS pointer at {repo_id}/tokenizer.model — fetched real tokenizer.model; retrying tokenizer load...")
            tok = AutoTokenizer.from_pretrained(repo_id, use_fast=False, trust_remote_code=True, use_auth_token=hf_token)
            print(f"Tokenizer loaded from local repo after fetching LFS: {repo_id}")
            return tok
        # Local workspace fallback: use bundled Nanbeige4.1 tokenizer if available
        local_fallback = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "models", "Nanbeige4.1-3B"))
        if os.path.isdir(local_fallback):
            print(f"Attempting local workspace tokenizer fallback: {local_fallback}")
            try:
                tok = AutoTokenizer.from_pretrained(local_fallback, use_fast=False, trust_remote_code=True)
            except Exception as e_local:
                print(f"Local tokenizer fallback failed: {e_local}")
                raise
            print(f"Tokenizer loaded from local workspace: {local_fallback}")
            return tok
        # Try known base tokenizer on the Hub (Nanbeige4.1 if repo looks like 4.1)
        base = "Nanbeige/Nanbeige4.1-3B" if "4.1" in repo_id.lower() else "PioTio/Nanbeige2.5"
        print(f"Falling back to base tokenizer: {base}")
        tok = AutoTokenizer.from_pretrained(base, use_fast=False, trust_remote_code=True, use_auth_token=hf_token)
        # If HF token is available, attempt to auto-repair/upload tokenizer files to the target repo
        if hf_token:
            try:
                uploaded = _repair_and_upload_tokenizer(repo_id, hf_token=hf_token)
                print(f"Auto-repair attempt to {repo_id}: {'succeeded' if uploaded else 'no-change/failure'}")
            except Exception as e_rep:
                print(f"Auto-repair attempt failed: {e_rep}")
        return tok
    except Exception as e_base:
        # last-resort: try fast tokenizer (may still fail or produce garbled output)
        print(f"All fallbacks failed: {e_base}. Trying generic AutoTokenizer as last resort...")
        return AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True, use_auth_token=hf_token)


def _load_tokenizer(repo_id: str, hf_token: Optional[str]):
    """Load the slow tokenizer for *repo_id*, falling back to repairs/fallbacks on failure.

    A ``Input must be a List[...]`` failure goes to the normalization/repair
    path. Any other failure goes to the generic fallback chain. Raises when
    every option failed.
    """
    try:
        tok = AutoTokenizer.from_pretrained(repo_id, use_fast=False, trust_remote_code=True, use_auth_token=hf_token)
    except Exception as e_tok:
        print(f"Tokenizer load from {repo_id} failed: {e_tok}")
        if "Input must be a List" in str(e_tok):
            return _load_tokenizer_after_normalization(repo_id, hf_token)
        return _load_tokenizer_fallbacks(repo_id, hf_token)
    print(f"Tokenizer loaded from repo: {repo_id}")
    return tok


def load_model(repo_id: str = DEFAULT_MODEL, force_reload: bool = False) -> str:
    """Load model + tokenizer for *repo_id* and publish them in the module globals.

    Sets MODEL, TOKENIZER, MODEL_NAME, DEVICE and USE_CHAT_TEMPLATE. Holds
    MODEL_LOCK for the whole load and sets MODEL_LOADING while it runs.
    MODEL_LOADING is cleared in a ``finally`` block, so no return path or
    exception can leave the UI stuck in the "still loading" state.

    Tokenizer: a slow tokenizer (``use_fast=False``) from *repo_id*; the
    fallback chain is in ``_load_tokenizer``. HF_TOKEN from the environment is
    passed to Hub calls (private repos / higher rate limits).

    Model: 4-bit (bitsandbytes) on CUDA when available, otherwise fp16 on CUDA
    or fp32 on CPU. The model is published in MODEL only after the
    tokenizer/model sanity fixes ran, so handlers never see a half-configured
    model.

    Returns:
        A status string: "Model already loaded: ...", "Loaded ...",
        "Tokenizer load failed: ..." or "Model load failed: ...".
    """
    global MODEL, TOKENIZER, MODEL_NAME, DEVICE, MODEL_LOADING, USE_CHAT_TEMPLATE
    with MODEL_LOCK:
        if MODEL is not None and MODEL_NAME == repo_id and not force_reload:
            return f"Model already loaded: {MODEL_NAME} (@ {DEVICE})"
        # mark loading state so UI handlers can guard incoming requests
        MODEL_LOADING = True
        print(f"Model load started: {repo_id}")
        try:
            MODEL = None
            TOKENIZER = None
            USE_CHAT_TEMPLATE = False
            MODEL_NAME = repo_id
            DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
            hf_token = os.environ.get("HF_TOKEN")

            # 1) tokenizer (slow tokenizer is required for the Nanbeige family)
            try:
                tok = _load_tokenizer(repo_id, hf_token)
            except Exception as e_tok:
                print("Tokenizer load failed:", e_tok)
                return f"Tokenizer load failed: {e_tok}"
            TOKENIZER = tok
            # every recent tokenizer class has `apply_chat_template`; it is only
            # usable when the tokenizer actually ships a chat template
            USE_CHAT_TEMPLATE = hasattr(tok, "apply_chat_template") and bool(getattr(tok, "chat_template", None))
            print(f"USE_CHAT_TEMPLATE={USE_CHAT_TEMPLATE}")

            # 2) model: prefer 4-bit on GPU when bitsandbytes is available
            if DEVICE == "cuda" and HAS_BNB:
                mdl = None
                try:
                    mdl = _safe_model_from_pretrained(
                        repo_id,
                        device_map="auto",
                        quantization_config=BitsAndBytesConfig(load_in_4bit=True),
                        trust_remote_code=True,
                        use_auth_token=hf_token,
                    )
                    mdl.eval()
                    _diagnose_and_fix_tokenizer_model(tok, mdl)
                    MODEL = mdl
                    print(f"Model load finished (4-bit): {repo_id}")
                    return f"Loaded {repo_id} (4-bit, device_map=auto)"
                except Exception as e:
                    print("bnb/4bit load failed - falling back:", e)
                    # release partially loaded weights before the fp16 attempt
                    mdl = None
                    torch.cuda.empty_cache()

            # 3) FP16 on GPU / FP32 on CPU
            try:
                if DEVICE == "cuda":
                    mdl = _safe_model_from_pretrained(
                        repo_id,
                        device_map="auto",
                        torch_dtype=torch.float16,
                        trust_remote_code=True,
                        use_auth_token=hf_token,
                    )
                else:
                    mdl = _safe_model_from_pretrained(
                        repo_id,
                        low_cpu_mem_usage=True,
                        torch_dtype=torch.float32,
                        trust_remote_code=True,
                        use_auth_token=hf_token,
                    )
                    mdl.to("cpu")
                mdl.eval()
                _diagnose_and_fix_tokenizer_model(tok, mdl)
                MODEL = mdl
                print(f"Model load finished: {repo_id} (@{DEVICE})")
                return f"Loaded {repo_id} (@{DEVICE})"
            except Exception as e:
                MODEL = None
                TOKENIZER = None
                USE_CHAT_TEMPLATE = False
                print(f"Model load failed: {repo_id} -> {e}")
                return f"Model load failed: {e} (hint: check HF_TOKEN, repo contents and ensure tokenizer.model is present)"
        finally:
            # cleared on every path: success, error message or exception
            MODEL_LOADING = False
# ----------------------------- Prompt building -----------------------------
# Alpaca-style instruction template that build_prompt() falls back to when no
# chat template is used. The two `{}` slots are filled positionally via
# str.format(instruction, input). The text ends right after "### Response:\n",
# so the model's continuation is the answer.
ALPACA_TMPL = (
    "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
    "### Instruction:\n{}\n\n"
    "### Input:\n{}\n\n"
    "### Response:\n"
)
def _normalize_history(raw_history) -> List[Tuple[str, str]]:
"""Accept either:
- List of (user_str, assistant_str) tuples (legacy Gradio Chatbot)
- List of dicts {role: "user"|"assistant"|"system", content: str} (new messages API)
and return a list of (user, assistant) pairs suitable for prompt construction.
Behavior: pairs each user message with the next assistant message (assistant may be "" if not present).
NOTE: For chat-first models (Nanbeige4.1) we prefer `tokenizer.apply_chat_template` later
so this function only normalizes the history shape.
"""
if not raw_history:
return []
# already in tuple form?
try:
# quick check: sequence of 2-tuples
if all(isinstance(x, (list, tuple)) and len(x) == 2 for x in raw_history):
return [(str(u or ""), str(a or "")) for u, a in raw_history]
except Exception:
pass
# handle messages-as-dicts
pairs: List[Tuple[str, str]] = []
pending_user: Optional[str] = None
for item in raw_history:
if isinstance(item, dict):
role = item.get("role") or item.get("type") or "user"
content = item.get("content") or item.get("value") or ""
content = str(content or "")
if role.lower() == "system":
# system messages are ignored for pairing (but could be injected elsewhere)
continue
if role.lower() == "user":
# if there's already a pending user without assistant, flush it first
if pending_user is not None:
pairs.append((pending_user, ""))
pending_user = content
elif role.lower() == "assistant":
if pending_user is None:
# assistant without user -> pair with empty user
pairs.append(("", content))
else:
pairs.append((pending_user, content))
pending_user = None
else:
# unknown shape -> stringify and treat as user turn
s = str(item)
if pending_user is not None:
pairs.append((pending_user, ""))
pending_user = s
if pending_user is not None:
pairs.append((pending_user, ""))
return pairs
def build_prompt(history, user_input: str, system_prompt: str, max_history: int = 6) -> str:
    """Build the generation prompt for the current user turn.

    Args:
        history: previous turns, in any shape accepted by ``_normalize_history``.
        user_input: the new user message.
        system_prompt: optional system instruction ("" / None to omit it).
        max_history: number of most recent (user, assistant) pairs to keep.
            Values <= 0 keep no history.

    Returns:
        If the loaded tokenizer can render a chat template, the string from
        ``TOKENIZER.apply_chat_template(..., add_generation_prompt=True,
        tokenize=False)``. Otherwise, an Alpaca-style prompt built from
        ALPACA_TMPL.
    """
    # normalize incoming history (supports both tuple-list and messages dicts)
    pairs = _normalize_history(history or [])
    # pairs[-0:] would keep *everything*, so non-positive limits are handled explicitly
    keep = int(max_history)
    pairs = pairs[-keep:] if keep > 0 else []

    # Read the module-level tokenizer directly. A `from __main__ import ...` only
    # works when this file is the entry script.
    tok = TOKENIZER
    if tok is not None and hasattr(tok, "apply_chat_template"):
        # build messages list with optional system prompt first
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        for u, a in pairs:
            messages.append({"role": "user", "content": u})
            if a:
                messages.append({"role": "assistant", "content": a})
        # current user turn
        messages.append({"role": "user", "content": user_input})
        try:
            return tok.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
        except Exception:
            # no/invalid chat template -> fall back to the Alpaca format below
            pass

    # Default / fallback: ALPACA-style instruction template
    parts: List[str] = [f"System: {system_prompt}"] if system_prompt else []
    for u, a in pairs:
        # include previous turns as completed instruction/response pairs
        parts.append(ALPACA_TMPL.format(u, "") + (a or ""))
    # append current user input as the instruction to complete
    parts.append(ALPACA_TMPL.format(user_input, ""))
    return "\n\n".join(parts)
# ----------------------------- Generation ---------------------------------
# Maximum number of prompt tokens fed to the model. Longer prompts keep their *end*.
MAX_PROMPT_TOKENS = 2048


def _build_generation_kwargs(prompt: str, temperature: float, top_p: float, top_k: int, max_new_tokens: int) -> dict:
    """Tokenize *prompt* and assemble the keyword arguments for ``MODEL.generate``.

    * A prompt rendered from a chat template already contains any special tokens
      it needs, so ``add_special_tokens`` is disabled when the tokenizer has a
      chat template.
    * Over-long prompts are truncated from the *left*. This keeps the current
      user turn and generation prompt, which right-side truncation would cut.
    * ``temperature <= 0`` selects greedy decoding, because sampling requires a
      strictly positive temperature.
    * Token id 0 is a valid EOS id, so ids are checked against ``None`` rather
      than by truthiness.

    Raises:
        RuntimeError: if no model/tokenizer is loaded.
    """
    if MODEL is None or TOKENIZER is None:
        raise RuntimeError("Model is not loaded. Press 'Load model' first.")
    add_special_tokens = not bool(getattr(TOKENIZER, "chat_template", None))
    enc = TOKENIZER(prompt, return_tensors="pt", add_special_tokens=add_special_tokens)
    device = next(MODEL.parameters()).device
    input_ids = enc.input_ids[:, -MAX_PROMPT_TOKENS:].to(device)
    eos_id = TOKENIZER.eos_token_id
    gen_kwargs = dict(
        input_ids=input_ids,
        attention_mask=torch.ones_like(input_ids),
        max_new_tokens=int(max_new_tokens),
        pad_token_id=eos_id if eos_id is not None else 0,
        eos_token_id=eos_id,
    )
    if float(temperature) > 0:
        gen_kwargs.update(
            do_sample=True,
            temperature=float(temperature),
            top_p=float(top_p),
            top_k=int(top_k),
        )
    else:
        gen_kwargs["do_sample"] = False
    return gen_kwargs


def _generate_text(prompt: str, temperature: float, top_p: float, top_k: int, max_new_tokens: int) -> str:
    """Generate a complete reply for *prompt* and return it without the prompt tokens."""
    gen_kwargs = _build_generation_kwargs(prompt, temperature, top_p, top_k, max_new_tokens)
    outputs = MODEL.generate(**gen_kwargs)
    # strip prompt from the generated output
    prompt_len = gen_kwargs["input_ids"].shape[1]
    return TOKENIZER.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()


def _generate_stream(prompt: str, temperature: float, top_p: float, top_k: int, max_new_tokens: int):
    """Yield the growing reply text while the model generates (uses TextIteratorStreamer).

    ``generate`` runs in a worker thread. If it raises, the streamer is ended
    so the loop below cannot block forever, and the error is re-raised here
    after the thread is joined.
    """
    gen_kwargs = _build_generation_kwargs(prompt, temperature, top_p, top_k, max_new_tokens)
    streamer = TextIteratorStreamer(TOKENIZER, skip_prompt=True, skip_special_tokens=True)
    gen_kwargs["streamer"] = streamer
    model = MODEL  # pin the model in case the global is swapped mid-stream
    errors: List[BaseException] = []

    def _worker():
        try:
            model.generate(**gen_kwargs)
        except Exception as exc:
            errors.append(exc)
            streamer.end()  # unblock the consuming loop

    thread = threading.Thread(target=_worker)
    thread.start()
    out = ""
    for piece in streamer:
        out += piece
        yield out
    thread.join()
    if errors:
        raise errors[0]
# ----------------------------- Gradio app handlers -------------------------
def submit_message(user_message: str, history, system_prompt: str, temperature: float, top_p: float, top_k: int, max_new_tokens: int, stream: bool, max_history: int, force_cpu: bool = False):
    """Handle a chat submission and stream the updated conversation to the UI.

    Generator: yields ``(pairs, "")``. ``pairs`` is the conversation as
    ``[(user, assistant), ...]`` tuples for Gradio's Chatbot, and ``""`` clears
    the input textbox. *history* may be in tuple-list or messages-dict form.

    Each of these guards yields one warning reply and stops: an empty message,
    a model that is still loading, no model loaded, and the full default model
    on CPU without *force_cpu*.

    Generation errors are shown as the assistant reply. They are *yielded*,
    because a ``return value`` inside a generator is discarded by Gradio.
    """
    user_message = str(user_message or "")
    # Normalize incoming history to tuple pairs
    pairs = _normalize_history(history or [])

    if not user_message.strip():
        # nothing to send: keep the conversation unchanged
        yield pairs, ""
        return

    # Append current user turn (assistant reply empty until generated)
    pairs.append((user_message, ""))

    def _reply(text):
        pairs[-1] = (user_message, text)
        return pairs, ""

    # Guard: block generation while model is loading or not loaded
    if MODEL_LOADING:
        yield _reply("⚠️ Model is still loading — please wait and try again. Check 'Status' for progress.")
        return
    if MODEL is None:
        yield _reply("⚠️ Model is not loaded — click 'Load model' first.")
        return

    # If user is running the full Nanbeige model on CPU, warn and suggest options
    if MODEL_NAME == DEFAULT_MODEL and DEVICE == "cpu" and not force_cpu:
        warning = (
            "⚠️ **Nanbeige is too large for CPU inference and will be extremely slow.**\n\n"
            "Options:\n"
            "- Enable GPU in Space settings (recommended)\n"
            f"- Click **Load fast CPU demo ({CPU_DEMO_MODEL})** for a quick, low-cost demo\n"
            "- Or check 'Force CPU generation' to proceed on CPU (not recommended)")
        yield _reply(warning)
        return

    prompt = build_prompt(pairs[:-1], user_message, system_prompt, max_history)
    try:
        if stream:
            # stream partial assistant outputs
            for partial in _generate_stream(prompt, temperature, top_p, top_k, max_new_tokens):
                yield _reply(partial)
            return
        out = _generate_text(prompt, temperature, top_p, top_k, max_new_tokens)
    except Exception as e:
        yield _reply(f"<Error during generation: {e}>")
        return
    yield _reply(out)
def clear_chat() -> List[Tuple[str, str]]:
    """Return a fresh, empty conversation used to reset the chat."""
    empty: List[Tuple[str, str]] = list()
    return empty
def regenerate(history, system_prompt: str, temperature: float, top_p: float, top_k: int, max_new_tokens: int, stream: bool, max_history: int, force_cpu: bool = False):
    """Re-run the last user message, replacing the last assistant reply.

    This is a generator (``yield from submit_message(...)``) so Gradio treats it
    as a streaming handler. A plain function that *returns* the
    ``submit_message`` generator would hand Gradio the generator object itself
    as the output value. With an empty history, a single ``(history or [], "")``
    update is yielded.
    """
    if not history:
        yield history or [], ""
        return
    pairs = _normalize_history(history)
    # regenerate last assistant reply using the last user message
    last_user = pairs[-1][0] if pairs else ""
    yield from submit_message(last_user, pairs[:-1], system_prompt, temperature, top_p, top_k, max_new_tokens, stream, max_history, force_cpu)
def load_model_ui(repo: str):
    """Force-load *repo* and return ``(status_text, send_button_update)`` for the UI.

    The Send button is made interactive only when the status text starts with
    "loaded" (case-insensitive). " — chat-template detected" is appended when
    USE_CHAT_TEMPLATE is set.
    """
    message = load_model(repo, force_reload=True)
    succeeded = str(message).lower().startswith("loaded")
    note = " — chat-template detected" if USE_CHAT_TEMPLATE else ""
    from gradio import update as gr_update
    return message + note, gr_update(interactive=succeeded)
def apply_lora_adapter(adapter_repo: str):
    """Wrap the loaded base model with a LoRA adapter from the Hub (via ``peft``).

    Returns a status string for the UI and never raises. Requires ``peft``, a
    non-empty adapter repo id and a loaded base model. ``HF_TOKEN`` (if set) is
    forwarded for private adapters. The swap happens under MODEL_LOCK so it
    cannot interleave with load_model().
    """
    global MODEL
    if not HAS_PEFT:
        return "peft not installed in this environment. Add `peft` to requirements.txt to enable LoRA loading."
    adapter_repo = (adapter_repo or "").strip()
    if not adapter_repo:
        return "Enter a LoRA adapter repo id first."
    if MODEL_LOADING:
        return "Model is still loading — apply the adapter once loading has finished."
    hf_token = os.environ.get("HF_TOKEN")
    with MODEL_LOCK:
        if MODEL is None:
            return "Load base model first."
        try:
            # allow huggingface auth token for private adapters
            peft_model = PeftModel.from_pretrained(MODEL, adapter_repo, use_auth_token=hf_token)
            peft_model.eval()
        except Exception as e:
            return f"Failed to apply adapter: {e} (hint: check adapter name and HF_TOKEN)"
        MODEL = peft_model
    return f"Applied LoRA adapter from {adapter_repo}"
# ----------------------------- Build UI -----------------------------------
def _bg_initial_load() -> str:
    """Start loading DEFAULT_MODEL in a daemon thread and return the initial status text.

    A background thread cannot push updates into the Gradio UI, so the result is
    only printed here. The Status box is refreshed on page load via
    ``_current_status``, and ``submit_message`` refuses to generate while
    ``MODEL_LOADING`` is set.
    """

    def _worker():
        res = load_model(DEFAULT_MODEL, force_reload=False)
        print(f"Background model load finished: {res}")

    threading.Thread(target=_worker, daemon=True).start()
    return "Loading model in background..."


def _current_status() -> str:
    """Summarize the current model state (read from the module globals) for the Status box."""
    if MODEL_LOADING:
        return f"Loading {MODEL_NAME or DEFAULT_MODEL} in background..."
    if MODEL is not None:
        suffix = " — chat-template detected" if USE_CHAT_TEMPLATE else ""
        return f"Loaded {MODEL_NAME} (@{DEVICE}){suffix}"
    if MODEL_NAME is None:
        if os.environ.get("SKIP_AUTOLOAD", "0") == "1":
            return "Auto-load skipped (SKIP_AUTOLOAD=1)"
        return "Model not loaded"
    return f"Model not loaded (last attempt: {MODEL_NAME} — see logs)"


# For local smoke tests you can skip automatic model loading by setting
# environment variable `SKIP_AUTOLOAD=1` so the UI starts without loading
# the large model into memory.
if os.environ.get("SKIP_AUTOLOAD", "0") == "1":
    _INITIAL_STATUS = "Auto-load skipped (SKIP_AUTOLOAD=1)"
else:
    _INITIAL_STATUS = _bg_initial_load()

with gr.Blocks(title="Nanbeige2.5 — Chat UI") as demo:
    gr.Markdown("# 🦙 Nanbeige2.5 — Chat (Hugging Face Space)\nA lightweight, streaming chat UI with tokenizer/model sanity checks and optional LoRA support.")
    with gr.Row():
        model_input = gr.Textbox(value=DEFAULT_MODEL, label="Model repo (HF)", interactive=True)
        load_btn = gr.Button("Load model")
        repair_btn = gr.Button("Repair tokenizer on Hub")
        model_demo_btn = gr.Button(f"Load fast CPU demo ({CPU_DEMO_MODEL})")
    model_status = gr.Textbox(value=_INITIAL_STATUS, label="Status", interactive=False)
    with gr.Row():
        system_prompt = gr.Textbox(value=DEFAULT_SYSTEM_PROMPT, label="System prompt (applies to all turns)", lines=2)
    chatbot = gr.Chatbot(label="Conversation")
    # kept for compatibility; the conversation history itself is read from `chatbot`
    state = gr.State([])
    with gr.Row():
        txt = gr.Textbox(show_label=False, placeholder="Type your message and press Enter...", lines=2)
        send = gr.Button("Send")
        clear = gr.Button("Clear")
    # canned quick-replies (populates the input box)
    with gr.Row():
        quick_hi = gr.Button("Hi")
        quick_joke = gr.Button("Tell me a joke")
        quick_help = gr.Button("What can you do?")
        quick_qora = gr.Button("Explain QLoRA")
    quick_hi.click(lambda: "Hi", outputs=txt)
    quick_joke.click(lambda: "Tell me a joke", outputs=txt)
    quick_help.click(lambda: "What can you do?", outputs=txt)
    quick_qora.click(lambda: "Explain QLoRA", outputs=txt)
    with gr.Row():
        temperature = gr.Slider(0.0, 1.5, value=0.8, step=0.01, label="Temperature")
        top_p = gr.Slider(0.0, 1.0, value=0.9, step=0.01, label="Top-p")
        top_k = gr.Slider(0, 200, value=50, step=1, label="Top-k")
        max_new_tokens = gr.Slider(16, 1024, value=256, step=1, label="Max new tokens")
    with gr.Row():
        stream_toggle = gr.Checkbox(value=True, label="Stream responses (recommended)")
        force_cpu = gr.Checkbox(value=False, label="Force CPU generation (not recommended)")
        max_history = gr.Slider(1, 12, value=6, step=1, label="Max history turns")
        regen = gr.Button("Regenerate")
    with gr.Row():
        adapter_box = gr.Textbox(value="", label="Optional LoRA adapter repo (HF) — leave blank if none")
        apply_adapter = gr.Button("Apply LoRA adapter")

    # Events. The chat history input is the Chatbot itself: the handlers only
    # output to `chatbot`, so a separate gr.State would never see past turns.
    chat_inputs = [txt, chatbot, system_prompt, temperature, top_p, top_k, max_new_tokens, stream_toggle, max_history, force_cpu]
    load_btn.click(fn=load_model_ui, inputs=model_input, outputs=[model_status, send])
    repair_btn.click(fn=repair_tokenizer_on_hub, inputs=model_input, outputs=model_status)
    send.click(fn=submit_message, inputs=chat_inputs, outputs=[chatbot, txt])
    txt.submit(fn=submit_message, inputs=chat_inputs, outputs=[chatbot, txt])
    # two outputs -> the handler must return two values
    clear.click(fn=lambda: (clear_chat(), clear_chat()), inputs=None, outputs=[chatbot, state])
    regen.click(
        fn=regenerate,
        inputs=[chatbot, system_prompt, temperature, top_p, top_k, max_new_tokens, stream_toggle, max_history, force_cpu],
        outputs=[chatbot, txt],
    )
    apply_adapter.click(fn=apply_lora_adapter, inputs=[adapter_box], outputs=[model_status])
    # Send stays enabled while the background load runs; submit_message guards
    # MODEL_LOADING. Refresh the status whenever the page is (re)loaded.
    demo.load(fn=_current_status, inputs=None, outputs=model_status)

    # CPU warning / demo hint (visible in UI)
    gr.Markdown(f"""
**⚠️ If this Space is running on CPU, `{DEFAULT_MODEL}` will be extremely slow.**
- Enable GPU in Space Settings for real-time use.
- Or click **Load fast CPU demo ({CPU_DEMO_MODEL})** for an immediate, low-cost demo reply.
""")
    # wire demo button (load_model_ui returns status text + Send button update)
    model_demo_btn.click(fn=lambda: load_model_ui(CPU_DEMO_MODEL), inputs=None, outputs=[model_status, send])
    gr.Markdown("---\n**Tips:** select GPU hardware for smoother streaming and enable 4-bit bitsandbytes by installing `bitsandbytes` in `requirements.txt`.")
# Entry point when run as a script (e.g. `python app.py`): serve the UI on
# all network interfaces.
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0")