|
|
import threading |
|
|
import time |
|
|
import os |
|
|
from typing import List, Tuple, Optional |
|
|
|
|
|
import gradio as gr |
|
|
import torch |
|
|
from transformers import ( |
|
|
AutoModelForCausalLM, |
|
|
AutoTokenizer, |
|
|
TextIteratorStreamer, |
|
|
) |
|
|
|
|
|
|
|
|
# Optional 4-bit quantization support: BitsAndBytesConfig import fails on
# installs without bitsandbytes/GPU wheels, so degrade to a feature flag.
try:
    from transformers import BitsAndBytesConfig
    HAS_BNB = True
except Exception:
    HAS_BNB = False
|
|
|
|
|
# Optional LoRA support: peft may not be installed in every environment;
# apply_lora_adapter() checks this flag before using PeftModel.
try:
    from peft import PeftModel
    HAS_PEFT = True
except Exception:
    HAS_PEFT = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Model repos: the merged fine-tune loaded by default, plus a tiny model used
# as an instant, low-cost demo when the Space only has a CPU.
DEFAULT_MODEL = "PioTio/nanbeige-4.1-aiman-merged"
CPU_DEMO_MODEL = "distilgpt2"
DEFAULT_SYSTEM_PROMPT = "You are a helpful, honest assistant. Answer succinctly unless asked otherwise."

# Process-wide model state shared by all Gradio callbacks.
MODEL = None  # loaded causal-LM, or None until load_model() succeeds
TOKENIZER = None  # matching tokenizer, or None
MODEL_NAME = None  # repo id of the currently loaded model
DEVICE = "cpu"  # "cuda" when a GPU is visible, else "cpu"
MODEL_LOCK = threading.Lock()  # guards the already-loaded fast-path check in load_model()

# True while a (possibly background) load is in progress; submit_message()
# refuses to generate while this is set.
MODEL_LOADING = False

# True when the loaded tokenizer exposes apply_chat_template().
USE_CHAT_TEMPLATE = False
|
|
|
|
|
|
|
|
|
|
|
def _get_tok_vocab_size(tok: AutoTokenizer) -> Optional[int]: |
|
|
try: |
|
|
return int(getattr(tok, "vocab_size")) |
|
|
except Exception: |
|
|
try: |
|
|
return int(tok.get_vocab_size()) |
|
|
except Exception: |
|
|
return len(tok.get_vocab()) if hasattr(tok, "get_vocab") else None |
|
|
|
|
|
|
|
|
def _diagnose_and_fix_tokenizer_model(tok: AutoTokenizer, mdl: AutoModelForCausalLM):
    """Fix common tokenizer<->model mismatches (SentencePiece piece-id edge-cases).

    This mirrors the notebook fixes so Spaces will not hit `piece id out of range`.

    Mutates *tok* and *mdl* in place; every repair step is best-effort and
    swallows its own failures so a partially-fixable pair still loads.
    """
    # Determine the largest id space either side believes in.
    tok_vs = _get_tok_vocab_size(tok) or 0
    try:
        # Number of rows in the model's input embedding matrix.
        emb_rows = mdl.get_input_embeddings().weight.shape[0]
    except Exception:
        emb_rows = 0

    special_ids = getattr(tok, "all_special_ids", []) or []
    max_special_id = max(special_ids) if special_ids else 0

    # Both sides must cover every special-token id as well.
    required = max(tok_vs, emb_rows, max_special_id + 1)

    # Bump the tokenizer's reported vocab size if it is too small
    # (some tokenizers expose vocab_size as a read-only property — ignore).
    if getattr(tok, "vocab_size", 0) < required:
        try:
            tok.vocab_size = required
        except Exception:
            pass

    # Grow the embedding matrix so every token id has a row.
    if emb_rows < required:
        try:
            mdl.resize_token_embeddings(required)
            mdl.config.vocab_size = required
        except Exception:
            pass

    # Ensure a pad token exists (fall back to EOS, then a literal "[PAD]").
    if getattr(tok, "pad_token", None) is None:
        tok.pad_token = getattr(tok, "eos_token", "[PAD]")

    # Register the pad token. Different tokenizer implementations accept
    # different argument shapes, hence the cascade of signatures.
    try:
        tok.add_special_tokens({"pad_token": tok.pad_token})
    except TypeError as e:
        try:
            tok.add_special_tokens([tok.pad_token])
        except Exception:
            try:
                tok.add_tokens([tok.pad_token])
            except Exception:
                pass
    except Exception:
        pass
    # Propagate the resolved pad id to both tokenizer and model config.
    try:
        pad_id = tok.convert_tokens_to_ids(tok.pad_token)
        tok.pad_token_id = pad_id
        mdl.config.pad_token_id = pad_id
    except Exception:
        pass
|
|
|
|
|
|
|
|
|
|
|
def _is_lfs_pointer_file(path: str) -> bool: |
|
|
try: |
|
|
with open(path, "rb") as f: |
|
|
start = f.read(128) |
|
|
return b"git-lfs.github.com/spec/v1" in start |
|
|
except Exception: |
|
|
return False |
|
|
|
|
|
|
|
|
def _download_tokenizer_model_from_hub(hf_repo: str, dest_path: str, hf_token: Optional[str] = None) -> bool:
    """Download tokenizer.model from HF Hub into dest_path. Returns True on success.

    Writes to ``dest_path + ".tmp"`` first and atomically renames on success;
    the temp file is cleaned up on failure.
    """
    try:
        import urllib.request

        url = f"https://huggingface.co/{hf_repo}/resolve/main/tokenizer.model"
        request = urllib.request.Request(url, headers={"User-Agent": "spaces-nanbeige-chat/1.0"})
        if hf_token:
            request.add_header("Authorization", f"Bearer {hf_token}")
        with urllib.request.urlopen(request, timeout=30) as response, open(dest_path + ".tmp", "wb") as sink:
            sink.write(response.read())
        # Atomic swap so a half-written file never masquerades as the real one.
        os.replace(dest_path + ".tmp", dest_path)
        return True
    except Exception as e:
        print("_download_tokenizer_model_from_hub failed:", e)
        # Best-effort cleanup of the partial download.
        try:
            if os.path.exists(dest_path + ".tmp"):
                os.remove(dest_path + ".tmp")
        except Exception:
            pass
        return False
|
|
|
|
|
|
|
|
def _ensure_local_tokenizer_model(repo_path: str, hf_token: Optional[str] = None) -> bool:
    """If tokenizer.model in repo_path is a Git-LFS pointer, try to download the real file from the Hub.

    Tries to infer a Hub repo id from the local git remote; falls back to `PioTio/<dirname>` for Nanbeige folders.

    Returns True when a usable tokenizer.model is present (already real, or
    successfully fetched); False otherwise.
    """
    tm = os.path.join(repo_path, "tokenizer.model")
    if not os.path.exists(tm):
        return False
    if not _is_lfs_pointer_file(tm):
        # Real file already present — nothing to do.
        return True

    # Try to infer the Hub repo id from the local git remote URL.
    repo_id = None
    try:
        import subprocess

        out = subprocess.check_output(["git", "-C", repo_path, "config", "--get", "remote.origin.url"], text=True).strip()
        if out and "huggingface.co" in out:
            # BUG FIX: the old code used out.rstrip(".git"), which strips any
            # trailing run of the characters '.', 'g', 'i', 't' — mangling repo
            # names that end in those letters. Strip the exact suffix instead.
            cleaned = out[: -len(".git")] if out.endswith(".git") else out
            parts = cleaned.split("/")
            repo_id = f"{parts[-2]}/{parts[-1]}"
    except Exception:
        repo_id = None

    # Fallback heuristic: Nanbeige checkpoints live under PioTio/<dirname>.
    if repo_id is None:
        guessed = os.path.basename(repo_path)
        if guessed.lower().startswith("nanbeige") or "nanbeige" in guessed.lower():
            repo_id = f"PioTio/{guessed}"

    if repo_id:
        return _download_tokenizer_model_from_hub(repo_id, tm, hf_token=hf_token)
    return False
|
|
|
|
|
|
|
|
|
|
|
def _upload_tokenizer_files_to_hub(repo_id: str, local_tokenizer_dir: str, hf_token: Optional[str] = None) -> bool:
    """Push the tokenizer files found in *local_tokenizer_dir* to *repo_id*.

    Uploads tokenizer.model, tokenizer_config.json, tokenizer.json,
    special_tokens_map.json and chat_template.jinja when present.
    Returns True if at least one file was uploaded successfully.
    """
    try:
        from huggingface_hub import HfApi

        api = HfApi()
        wanted = (
            "tokenizer.model",
            "tokenizer_config.json",
            "tokenizer.json",
            "special_tokens_map.json",
            "chat_template.jinja",
        )
        success_count = 0
        for fn in wanted:
            local_path = os.path.join(local_tokenizer_dir, fn)
            if not os.path.exists(local_path):
                # File not produced by save_pretrained — skip silently.
                continue
            try:
                api.upload_file(
                    path_or_fileobj=local_path,
                    path_in_repo=fn,
                    repo_id=repo_id,
                    token=hf_token,
                    commit_message=f"Auto-fix tokenizer: {fn}",
                )
                print(f"_upload_tokenizer_files_to_hub: uploaded {fn} to {repo_id}")
                success_count += 1
            except Exception as e:
                # One failed upload should not abort the remaining files.
                print(f"_upload_tokenizer_files_to_hub: failed to upload {fn}: {e}")
        return success_count > 0
    except Exception as e:
        print("_upload_tokenizer_files_to_hub failed:", e)
        return False
|
|
|
|
|
|
|
|
def _repair_and_upload_tokenizer(repo_id: str, hf_token: Optional[str] = None) -> bool:
    """Fetch the correct base tokenizer (Nanbeige4.1 if detected, otherwise DEFAULT_MODEL),
    then upload tokenizer files to the target repo. Returns True on success.

    All failures are swallowed and reported as False so callers can treat
    this as a best-effort repair.
    """
    try:
        base = "Nanbeige/Nanbeige4.1-3B" if "4.1" in repo_id.lower() else DEFAULT_MODEL
        from transformers import AutoTokenizer
        import tempfile, shutil

        tmp = tempfile.mkdtemp(prefix="tokenizer_fix_")
        try:
            tok = AutoTokenizer.from_pretrained(base, use_fast=False, trust_remote_code=True)
            tok.save_pretrained(tmp)
            ok = _upload_tokenizer_files_to_hub(repo_id, tmp, hf_token=hf_token)
        finally:
            # BUG FIX: previously rmtree only ran on the success path, leaking
            # the scratch directory whenever download/save/upload raised.
            shutil.rmtree(tmp, ignore_errors=True)
        return ok
    except Exception as e:
        print("_repair_and_upload_tokenizer failed:", e)
        return False
|
|
|
|
|
|
|
|
def repair_tokenizer_on_hub(repo_id: str) -> str:
    """Public helper callable from the UI: attempts to upload a working base tokenizer to `repo_id`.

    Requires HF_TOKEN in the environment with write access to the target repo.
    Returns a human-readable status string for the status textbox.
    """
    hf_token = os.environ.get("HF_TOKEN")
    if not hf_token:
        return "HF_TOKEN not set — cannot upload tokenizer to Hub. Add HF_TOKEN and retry."
    try:
        if _repair_and_upload_tokenizer(repo_id, hf_token=hf_token):
            return "Uploaded tokenizer files to repo"
        return "Repair attempt failed (see logs)"
    except Exception as e:
        return f"Repair failed: {e}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _safe_model_from_pretrained(repo_id, *args, **kwargs):
    """Call AutoModelForCausalLM.from_pretrained, retrying without auth-token
    kwargs when the target class improperly forwards unexpected keyword
    arguments into its __init__ (a bug in some remote-code model classes).
    """
    try:
        return AutoModelForCausalLM.from_pretrained(repo_id, *args, **kwargs)
    except TypeError as err:
        message = str(err)
        retryable = "use_auth_token" in message or "unexpected keyword argument" in message
        if not retryable:
            raise
        # Drop both auth-token spellings and try once more.
        cleaned = {key: value for key, value in kwargs.items() if key not in ("use_auth_token", "token")}
        print(f"_safe_model_from_pretrained: retrying without auth-token due to: {err}")
        return AutoModelForCausalLM.from_pretrained(repo_id, *args, **cleaned)
|
|
|
|
|
|
|
|
def load_model(repo_id: str = DEFAULT_MODEL, force_reload: bool = False) -> str:
    """Load model + tokenizer from the Hub. Graceful fallbacks and HF-token support.

    Returns a human-readable status string for the UI.

    Behavior:
    - prefer slow tokenizer (use_fast=False)
    - accept HF token via env HF_TOKEN for private repos / higher rate limits
    - fallback to base tokenizer (`PioTio/Nanbeige2.5`) when tokenizer files are missing
    - pass auth token into from_pretrained calls where supported
    """
    global MODEL, TOKENIZER, MODEL_NAME, DEVICE

    # Fast path: nothing to do when the requested model is already loaded.
    with MODEL_LOCK:
        if MODEL is not None and MODEL_NAME == repo_id and not force_reload:
            return f"Model already loaded: {MODEL_NAME} (@ {DEVICE})"

    global MODEL_LOADING
    MODEL_LOADING = True
    print(f"Model load started: {repo_id}")

    MODEL = None
    TOKENIZER = None
    MODEL_NAME = repo_id

    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    hf_token = os.environ.get("HF_TOKEN")

    # ---------------- tokenizer ----------------
    try:
        TOKENIZER = AutoTokenizer.from_pretrained(
            repo_id,
            use_fast=False,
            trust_remote_code=True,
            use_auth_token=hf_token,
        )
        print(f"Tokenizer loaded from repo: {repo_id}")
        try:
            global USE_CHAT_TEMPLATE
            USE_CHAT_TEMPLATE = hasattr(TOKENIZER, "apply_chat_template")
            print(f"USE_CHAT_TEMPLATE={USE_CHAT_TEMPLATE}")
        except Exception:
            USE_CHAT_TEMPLATE = False
    except Exception as e_tok:
        print(f"Tokenizer load from {repo_id} failed: {e_tok}")

        # Case 1: known slow-tokenizer failure where special tokens are stored
        # as dicts — download the tokenizer files, normalize them, and retry.
        if "Input must be a List" in str(e_tok) or "Input must be a List[Union[str, AddedToken]]" in str(e_tok):
            try:
                print('Detected tokenizer add-tokens type error; attempting in-place normalization and retry...')
                try:
                    from huggingface_hub import hf_hub_download
                    import json, tempfile, shutil

                    tmp = tempfile.mkdtemp(prefix="tokfix_")
                    candidates = ["tokenizer.json", "tokenizer_config.json", "special_tokens_map.json", "tokenizer.model", "added_tokens.json"]
                    for fn in candidates:
                        try:
                            src = hf_hub_download(repo_id=repo_id, filename=fn, token=hf_token)
                            shutil.copy(src, tmp)
                        except Exception:
                            # File absent in the repo — skip it.
                            pass

                    # Rewrite dict-shaped special tokens into plain strings.
                    stm = os.path.join(tmp, "special_tokens_map.json")
                    if os.path.exists(stm):
                        try:
                            with open(stm, "r", encoding="utf-8") as f:
                                stm_j = json.load(f)
                            changed = False
                            if "additional_special_tokens" in stm_j:
                                new = []
                                for it in stm_j["additional_special_tokens"]:
                                    if isinstance(it, dict):
                                        new.append(it.get("content") or it.get("token") or str(it))
                                        changed = True
                                    else:
                                        new.append(it)
                                stm_j["additional_special_tokens"] = new
                            for k in ["bos_token", "eos_token", "pad_token", "unk_token"]:
                                if k in stm_j and isinstance(stm_j[k], dict):
                                    stm_j[k] = stm_j[k].get("content", stm_j[k])
                                    changed = True
                            if changed:
                                with open(stm, "w", encoding="utf-8") as f:
                                    json.dump(stm_j, f, ensure_ascii=False, indent=2)
                                print('Normalized special_tokens_map.json in temp dir')
                        except Exception:
                            pass

                    TOKENIZER = AutoTokenizer.from_pretrained(tmp, use_fast=False, trust_remote_code=True)
                    print('Tokenizer reloaded from normalized temp copy')
                    shutil.rmtree(tmp)
                except Exception as e_localnorm:
                    print('In-place normalization retry failed:', e_localnorm)
                    # Last resort: push a known-good base tokenizer to the repo.
                    if hf_token:
                        print('Attempting repo-side auto-repair/upload from base tokenizer...')
                        _repair_and_upload_tokenizer(repo_id, hf_token=hf_token)
                        TOKENIZER = AutoTokenizer.from_pretrained(repo_id, use_fast=False, trust_remote_code=True)
                        print('Tokenizer reloaded after repo repair')
                    else:
                        raise RuntimeError('Normalization + auto-repair could not proceed (no HF_TOKEN)')
            except Exception as e_retry:
                print('Repair/retry failed:', e_retry)
                # BUG FIX: MODEL_LOADING was left True on this early return,
                # permanently blocking submit_message() for the session.
                MODEL_LOADING = False
                return f"Tokenizer load failed: {e_retry}"
        else:
            # Case 2: generic tokenizer failure — try local LFS repair, then a
            # local workspace copy, then the base tokenizer from the Hub.
            try:
                if os.path.isdir(repo_id) and _ensure_local_tokenizer_model(repo_id, hf_token=hf_token):
                    print(f"Found LFS pointer at {repo_id}/tokenizer.model — fetched real tokenizer.model; retrying tokenizer load...")
                    TOKENIZER = AutoTokenizer.from_pretrained(
                        repo_id,
                        use_fast=False,
                        trust_remote_code=True,
                        use_auth_token=hf_token,
                    )
                    print(f"Tokenizer loaded from local repo after fetching LFS: {repo_id}")
                else:
                    # Sibling checkout used during local development.
                    local_fallback = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'models', 'Nanbeige4.1-3B'))
                    if os.path.isdir(local_fallback):
                        try:
                            print(f"Attempting local workspace tokenizer fallback: {local_fallback}")
                            TOKENIZER = AutoTokenizer.from_pretrained(local_fallback, use_fast=False, trust_remote_code=True)
                            print(f"Tokenizer loaded from local workspace: {local_fallback}")
                        except Exception as e_local:
                            print(f"Local tokenizer fallback failed: {e_local}")
                            raise e_local
                    else:
                        # Pick the base checkpoint matching the fine-tune family.
                        base = "Nanbeige/Nanbeige4.1-3B" if "4.1" in repo_id.lower() else "PioTio/Nanbeige2.5"
                        print(f"Falling back to base tokenizer: {base}")
                        TOKENIZER = AutoTokenizer.from_pretrained(base, use_fast=False, trust_remote_code=True, use_auth_token=hf_token)

                        # Best-effort: write the working tokenizer back to the repo.
                        if hf_token:
                            try:
                                uploaded = _repair_and_upload_tokenizer(repo_id, hf_token=hf_token)
                                print(f"Auto-repair attempt to {repo_id}: {'succeeded' if uploaded else 'no-change/failure'}")
                            except Exception as e_rep:
                                print(f"Auto-repair attempt failed: {e_rep}")
            except Exception as e_base:
                # Very last resort: default AutoTokenizer settings (fast path).
                try:
                    print(f"All fallbacks failed: {e_base}. Trying generic AutoTokenizer as last resort...")
                    TOKENIZER = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True, use_auth_token=hf_token)
                except Exception as e_final:
                    MODEL_LOADING = False
                    return f"Tokenizer load failed: {e_final}"

    # ---------------- model weights ----------------
    # Prefer 4-bit quantization when CUDA + bitsandbytes are available.
    if DEVICE == "cuda" and HAS_BNB:
        try:
            bnb_config = BitsAndBytesConfig(load_in_4bit=True)
            MODEL = _safe_model_from_pretrained(
                repo_id,
                device_map="auto",
                quantization_config=bnb_config,
                trust_remote_code=True,
                use_auth_token=hf_token,
            )
            MODEL.eval()
            _diagnose_and_fix_tokenizer_model(TOKENIZER, MODEL)
            MODEL_LOADING = False
            print(f"Model load finished (4-bit): {repo_id}")
            return f"Loaded {repo_id} (4-bit, device_map=auto)"
        except Exception as e:
            # Fall through to the full-precision paths below.
            print("bnb/4bit load failed - falling back:", e)

    # Full-precision fallback: fp16 on GPU, fp32 on CPU.
    try:
        if DEVICE == "cuda":
            MODEL = _safe_model_from_pretrained(
                repo_id,
                device_map="auto",
                torch_dtype=torch.float16,
                trust_remote_code=True,
                use_auth_token=hf_token,
            )
        else:
            MODEL = _safe_model_from_pretrained(
                repo_id,
                low_cpu_mem_usage=True,
                torch_dtype=torch.float32,
                trust_remote_code=True,
                use_auth_token=hf_token,
            )
            MODEL.to("cpu")

        MODEL.eval()
        _diagnose_and_fix_tokenizer_model(TOKENIZER, MODEL)
        MODEL_LOADING = False
        print(f"Model load finished: {repo_id} (@{DEVICE})")
        return f"Loaded {repo_id} (@{DEVICE})"
    except Exception as e:
        # Reset all state so the UI reports "not loaded" consistently.
        MODEL = None
        TOKENIZER = None
        MODEL_LOADING = False
        print(f"Model load failed: {repo_id} -> {e}")
        return f"Model load failed: {e} (hint: check HF_TOKEN, repo contents and ensure tokenizer.model is present)"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Alpaca-style prompt template used as a fallback when the tokenizer has no
# chat template. Formatted via .format(instruction, input); callers pass ""
# for the Input section.
ALPACA_TMPL = (
    "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
    "### Instruction:\n{}\n\n"
    "### Input:\n{}\n\n"
    "### Response:\n"
)
|
|
|
|
|
|
|
|
def _normalize_history(raw_history) -> List[Tuple[str, str]]: |
|
|
"""Accept either: |
|
|
- List of (user_str, assistant_str) tuples (legacy Gradio Chatbot) |
|
|
- List of dicts {role: "user"|"assistant"|"system", content: str} (new messages API) |
|
|
and return a list of (user, assistant) pairs suitable for prompt construction. |
|
|
|
|
|
Behavior: pairs each user message with the next assistant message (assistant may be "" if not present). |
|
|
NOTE: For chat-first models (Nanbeige4.1) we prefer `tokenizer.apply_chat_template` later |
|
|
so this function only normalizes the history shape. |
|
|
""" |
|
|
if not raw_history: |
|
|
return [] |
|
|
|
|
|
|
|
|
try: |
|
|
|
|
|
if all(isinstance(x, (list, tuple)) and len(x) == 2 for x in raw_history): |
|
|
return [(str(u or ""), str(a or "")) for u, a in raw_history] |
|
|
except Exception: |
|
|
pass |
|
|
|
|
|
|
|
|
pairs: List[Tuple[str, str]] = [] |
|
|
pending_user: Optional[str] = None |
|
|
for item in raw_history: |
|
|
if isinstance(item, dict): |
|
|
role = item.get("role") or item.get("type") or "user" |
|
|
content = item.get("content") or item.get("value") or "" |
|
|
content = str(content or "") |
|
|
if role.lower() == "system": |
|
|
|
|
|
continue |
|
|
if role.lower() == "user": |
|
|
|
|
|
if pending_user is not None: |
|
|
pairs.append((pending_user, "")) |
|
|
pending_user = content |
|
|
elif role.lower() == "assistant": |
|
|
if pending_user is None: |
|
|
|
|
|
pairs.append(("", content)) |
|
|
else: |
|
|
pairs.append((pending_user, content)) |
|
|
pending_user = None |
|
|
else: |
|
|
|
|
|
s = str(item) |
|
|
if pending_user is not None: |
|
|
pairs.append((pending_user, "")) |
|
|
pending_user = s |
|
|
if pending_user is not None: |
|
|
pairs.append((pending_user, "")) |
|
|
return pairs |
|
|
|
|
|
|
|
|
def build_prompt(history, user_input: str, system_prompt: str, max_history: int = 6) -> str:
    """Build the generation prompt for the next turn.

    Keeps at most *max_history* prior (user, assistant) pairs. Prefers the
    tokenizer's chat template when a tokenizer exposing
    ``apply_chat_template`` is loaded; otherwise falls back to a simple
    Alpaca-style prompt with a leading "System:" line.
    """
    pairs = _normalize_history(history or [])
    pairs = pairs[-max_history:]

    # Look up the module-level TOKENIZER safely. BUG FIX: the previous code
    # did `from __main__ import TOKENIZER`, which silently yielded None
    # whenever this module was not the entry script (e.g. under `gradio`
    # auto-reload or when imported), disabling chat templates entirely.
    tok = globals().get("TOKENIZER")

    if tok is not None and hasattr(tok, "apply_chat_template"):
        # Chat-template path: rebuild the message list from the pairs.
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        for u, a in pairs:
            messages.append({"role": "user", "content": u})
            if a:
                messages.append({"role": "assistant", "content": a})
        messages.append({"role": "user", "content": user_input})
        try:
            return tok.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
        except Exception:
            # Template application failed — use the plain fallback below.
            pass

    # Alpaca-style fallback prompt.
    parts: List[str] = [f"System: {system_prompt}"]
    for u, a in pairs:
        parts.append(ALPACA_TMPL.format(u, "") + (a or ""))
    parts.append(ALPACA_TMPL.format(user_input, ""))
    return "\n\n".join(parts)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _generate_text(prompt: str, temperature: float, top_p: float, top_k: int, max_new_tokens: int) -> str:
    """Run one blocking generation and return the decoded, stripped completion."""
    global MODEL, TOKENIZER
    if MODEL is None or TOKENIZER is None:
        raise RuntimeError("Model is not loaded. Press 'Load model' first.")

    # A chat template already embeds special tokens, so skip re-adding them.
    wants_specials = not hasattr(TOKENIZER, "apply_chat_template")
    encoded = TOKENIZER(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=2048,
        add_special_tokens=wants_specials,
    )
    input_ids = encoded.input_ids.to(next(MODEL.parameters()).device)

    outputs = MODEL.generate(
        input_ids=input_ids,
        max_new_tokens=int(max_new_tokens),
        do_sample=True,
        temperature=float(temperature),
        top_p=float(top_p),
        top_k=int(top_k),
        pad_token_id=TOKENIZER.eos_token_id or 0,
        eos_token_id=TOKENIZER.eos_token_id or None,
    )

    # Decode only the newly generated suffix, not the echoed prompt.
    completion_ids = outputs[0][input_ids.shape[1]:]
    return TOKENIZER.decode(completion_ids, skip_special_tokens=True).strip()
|
|
|
|
|
|
|
|
def _generate_stream(prompt: str, temperature: float, top_p: float, top_k: int, max_new_tokens: int):
    """Yield the growing completion text while the model generates (uses TextIteratorStreamer)."""
    global MODEL, TOKENIZER
    if MODEL is None or TOKENIZER is None:
        raise RuntimeError("Model is not loaded. Press 'Load model' first.")

    streamer = TextIteratorStreamer(TOKENIZER, skip_prompt=True, skip_special_tokens=True)
    # A chat template already embeds special tokens, so skip re-adding them.
    wants_specials = not hasattr(TOKENIZER, "apply_chat_template")
    encoded = TOKENIZER(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=2048,
        add_special_tokens=wants_specials,
    )
    input_ids = encoded.input_ids.to(next(MODEL.parameters()).device)

    generation_args = dict(
        input_ids=input_ids,
        max_new_tokens=int(max_new_tokens),
        do_sample=True,
        temperature=float(temperature),
        top_p=float(top_p),
        top_k=int(top_k),
        pad_token_id=TOKENIZER.eos_token_id or 0,
        eos_token_id=TOKENIZER.eos_token_id or None,
        streamer=streamer,
    )

    # generate() runs on a worker thread; the streamer feeds us text chunks.
    worker = threading.Thread(target=MODEL.generate, kwargs=generation_args)
    worker.start()

    accumulated = ""
    for chunk in streamer:
        accumulated += chunk
        yield accumulated

    worker.join()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def submit_message(user_message: str, history, system_prompt: str, temperature: float, top_p: float, top_k: int, max_new_tokens: int, stream: bool, max_history: int, force_cpu: bool = False):
    """Generator callback driving the chat: yields updated history + cleared input.

    Accepts history in either tuple-list form or messages-dict form and
    always yields `([(user, assistant), ...], "")` tuples for Gradio's
    Chatbot plus an empty string to clear the input textbox.
    """
    raw_history = history or []
    pairs = _normalize_history(raw_history)

    # Append the new user turn with an empty assistant slot to fill in.
    pairs.append((str(user_message or ""), ""))

    # Guard: a background load is still running.
    if MODEL_LOADING:
        pairs[-1] = (user_message, "⚠️ Model is still loading — please wait and try again. Check 'Status' for progress.")
        yield pairs, ""
        return

    # Guard: nothing loaded at all.
    if MODEL is None:
        pairs[-1] = (user_message, "⚠️ Model is not loaded — click 'Load model' first.")
        yield pairs, ""
        return

    prompt = build_prompt(pairs[:-1], user_message, system_prompt, max_history)

    # Refuse to run the big default model on CPU unless explicitly forced.
    if MODEL_NAME == DEFAULT_MODEL and DEVICE == "cpu" and not force_cpu:
        warning = (
            "⚠️ **Nanbeige is too large for CPU inference and will be extremely slow.**\n\n"
            "Options:\n"
            "- Enable GPU in Space settings (recommended)\n"
            f"- Click **Load fast CPU demo ({CPU_DEMO_MODEL})** for a quick, low-cost demo\n"
            "- Or check 'Force CPU generation' to proceed on CPU (not recommended)")
        pairs[-1] = (user_message, warning)
        yield pairs, ""
        return

    if stream:
        # Incrementally re-yield the whole pair list as text streams in.
        for partial in _generate_stream(prompt, temperature, top_p, top_k, max_new_tokens):
            pairs[-1] = (user_message, partial)
            yield pairs, ""
        return

    try:
        out = _generate_text(prompt, temperature, top_p, top_k, max_new_tokens)
    except Exception as e:
        pairs[-1] = (user_message, f"<Error during generation: {e}>")
        # BUG FIX: this function is a generator, so the old `return pairs, ""`
        # here never reached Gradio — the error message was silently dropped.
        yield pairs, ""
        return

    pairs[-1] = (user_message, out)
    yield pairs, ""
|
|
|
|
|
|
|
|
def clear_chat() -> List[Tuple[str, str]]:
    """Reset the conversation: an empty pair list for the Chatbot/state."""
    fresh: List[Tuple[str, str]] = []
    return fresh
|
|
|
|
|
|
|
|
def regenerate(history, system_prompt: str, temperature: float, top_p: float, top_k: int, max_new_tokens: int, stream: bool, max_history: int, force_cpu: bool = False):
    """Re-run generation for the most recent user turn in *history*."""
    if not history:
        # Nothing to regenerate.
        return history, ""

    normalized = _normalize_history(history)
    last_user = normalized[-1][0] if normalized else ""
    # Drop the last pair; submit_message re-appends it with a fresh answer.
    trimmed = normalized[:-1]
    return submit_message(
        last_user,
        trimmed,
        system_prompt,
        temperature,
        top_p,
        top_k,
        max_new_tokens,
        stream,
        max_history,
        force_cpu,
    )
|
|
|
|
|
|
|
|
def load_model_ui(repo: str):
    """UI wrapper around load_model(): returns (status text, Send-button update).

    The Send button is enabled only when the status string reports success.
    """
    status = load_model(repo, force_reload=True)

    suffix = ""
    try:
        if USE_CHAT_TEMPLATE:
            suffix = " — chat-template detected"
    except NameError:
        suffix = ""

    from gradio import update as gr_update
    is_loaded = str(status).lower().startswith("loaded")
    return status + suffix, gr_update(interactive=is_loaded)
|
|
|
|
|
|
|
|
def apply_lora_adapter(adapter_repo: str):
    """Wrap the currently loaded MODEL with a PEFT LoRA adapter from the Hub.

    Returns a human-readable status string for the status textbox.
    """
    if not HAS_PEFT:
        return "peft not installed in this environment. Add `peft` to requirements.txt to enable LoRA loading."
    global MODEL
    if MODEL is None:
        return "Load base model first."

    hf_token = os.environ.get("HF_TOKEN")
    try:
        MODEL = PeftModel.from_pretrained(MODEL, adapter_repo, use_auth_token=hf_token)
    except Exception as e:
        return f"Failed to apply adapter: {e} (hint: check adapter name and HF_TOKEN)"
    return f"Applied LoRA adapter from {adapter_repo}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Gradio UI: layout, event wiring, and background auto-load on startup.
# ---------------------------------------------------------------------------
with gr.Blocks(title="Nanbeige2.5 — Chat UI") as demo:
    gr.Markdown("# 🦙 Nanbeige2.5 — Chat (Hugging Face Space)\nA lightweight, streaming chat UI with tokenizer/model sanity checks and optional LoRA support.")

    # Model selection / loading / repair controls.
    with gr.Row():
        model_input = gr.Textbox(value=DEFAULT_MODEL, label="Model repo (HF)", interactive=True)
        load_btn = gr.Button("Load model")
        repair_btn = gr.Button("Repair tokenizer on Hub")
        model_demo_btn = gr.Button(f"Load fast CPU demo ({CPU_DEMO_MODEL})")
        model_status = gr.Textbox(value="Model not loaded", label="Status", interactive=False)

    with gr.Row():
        system_prompt = gr.Textbox(value=DEFAULT_SYSTEM_PROMPT, label="System prompt (applies to all turns)", lines=2)

    # Chat display plus the hidden state holding history pairs.
    chatbot = gr.Chatbot(label="Conversation")
    state = gr.State([])

    with gr.Row():
        txt = gr.Textbox(show_label=False, placeholder="Type your message and press Enter...", lines=2)
        send = gr.Button("Send")
        clear = gr.Button("Clear")

    # Quick-prompt buttons: each just fills the input textbox.
    with gr.Row():
        quick_hi = gr.Button("Hi")
        quick_joke = gr.Button("Tell me a joke")
        quick_help = gr.Button("What can you do?")
        quick_qora = gr.Button("Explain QLoRA")

    quick_hi.click(lambda: "Hi", outputs=txt)
    quick_joke.click(lambda: "Tell me a joke", outputs=txt)
    quick_help.click(lambda: "What can you do?", outputs=txt)
    quick_qora.click(lambda: "Explain QLoRA", outputs=txt)

    # Sampling parameters forwarded to generate().
    with gr.Row():
        temperature = gr.Slider(0.0, 1.5, value=0.8, step=0.01, label="Temperature")
        top_p = gr.Slider(0.0, 1.0, value=0.9, step=0.01, label="Top-p")
        top_k = gr.Slider(0, 200, value=50, step=1, label="Top-k")
        max_new_tokens = gr.Slider(16, 1024, value=256, step=1, label="Max new tokens")

    with gr.Row():
        stream_toggle = gr.Checkbox(value=True, label="Stream responses (recommended)")
        force_cpu = gr.Checkbox(value=False, label="Force CPU generation (not recommended)")
        max_history = gr.Slider(1, 12, value=6, step=1, label="Max history turns")
        regen = gr.Button("Regenerate")

    with gr.Row():
        adapter_box = gr.Textbox(value="", label="Optional LoRA adapter repo (HF) — leave blank if none")
        apply_adapter = gr.Button("Apply LoRA adapter")

    # Event wiring.
    load_btn.click(fn=load_model_ui, inputs=model_input, outputs=[model_status, send])
    repair_btn.click(fn=repair_tokenizer_on_hub, inputs=model_input, outputs=model_status)

    send.click(
        fn=submit_message,
        inputs=[txt, state, system_prompt, temperature, top_p, top_k, max_new_tokens, stream_toggle, max_history, force_cpu],
        outputs=[chatbot, txt],
    )
    txt.submit(
        fn=submit_message,
        inputs=[txt, state, system_prompt, temperature, top_p, top_k, max_new_tokens, stream_toggle, max_history, force_cpu],
        outputs=[chatbot, txt],
    )

    # NOTE(review): clear_chat returns a single list but two outputs are
    # declared here — confirm Gradio maps this as intended for both
    # components (otherwise return ([], []) from clear_chat).
    clear.click(fn=clear_chat, inputs=None, outputs=[chatbot, state])

    regen.click(
        fn=regenerate,
        inputs=[state, system_prompt, temperature, top_p, top_k, max_new_tokens, stream_toggle, max_history, force_cpu],
        outputs=[chatbot, txt],
    )

    apply_adapter.click(fn=apply_lora_adapter, inputs=[adapter_box], outputs=[model_status])

    def _bg_initial_load():
        # Kick off a daemon thread loading DEFAULT_MODEL so the UI comes up
        # immediately; returns the placeholder status string shown meanwhile.
        def _worker():
            res = load_model(DEFAULT_MODEL, force_reload=False)
            try:
                # NOTE(review): calling component.update() from a background
                # thread (outside an event handler) may be a no-op in newer
                # Gradio versions — confirm the Send button actually toggles.
                from gradio import update as gr_update
                interactive = str(res).lower().startswith("loaded")
                send.update(interactive=interactive)
            except Exception:
                pass
            return res

        t = threading.Thread(target=_worker, daemon=True)
        t.start()
        return "Loading model in background..."

    # Optionally skip the startup auto-load (useful for local debugging).
    # NOTE(review): assigning model_status.value after construction only sets
    # the component's initial value at render time — verify this is intended.
    if os.environ.get("SKIP_AUTOLOAD", "0") == "1":
        model_status.value = "Auto-load skipped (SKIP_AUTOLOAD=1)"
    else:
        model_status.value = _bg_initial_load()

    # Start with Send disabled until a model has loaded.
    try:
        send.update(interactive=False)
    except Exception:
        pass

    gr.Markdown("""
**⚠️ If this Space is running on CPU, `PioTio/Nanbeige2.5` will be extremely slow.**
- Enable GPU in Space Settings for real-time use.
- Or click **Load fast CPU demo (distilgpt2)** for an immediate, low-cost demo reply.
""")

    model_demo_btn.click(fn=lambda: load_model_ui(CPU_DEMO_MODEL), inputs=None, outputs=model_status)

    gr.Markdown("---\n**Tips:** select GPU hardware for smoother streaming and enable 4-bit bitsandbytes by installing `bitsandbytes` in `requirements.txt`.")
|
|
|
|
|
|
|
|
# Launch the Gradio server when executed directly; bind on all interfaces so
# the Hugging Face Space proxy can reach it.
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0")
|
|
|