Spaces:
Running
Running
| # RP-AI — Multi-model Gradio backend with lazy loading & model switching | |
| # | |
| # Loads models on demand. Switching models unloads the old one first. | |
| # Original architecture preserved: Gradio Server + plain HTML frontend. | |
| import os | |
| import gc | |
| import logging | |
| import threading | |
| from contextlib import nullcontext | |
| from typing import Generator, List, Dict, Optional | |
| import torch | |
| from fastapi.responses import HTMLResponse | |
| from gradio import Server | |
| from huggingface_hub import login | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer | |
| from utils_chatbot import organize_messages | |
| from web_search import search as web_search_fn | |
# ── Logging ──
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Model used whenever no other model id has been selected.
DEFAULT_MODEL = "DavidAU/LFM2.5-1.2B-Thinking-Claude-4.6-Opus-Heretic-Uncensored-DISTILL"

# ── Device detection ──
HAS_CUDA = torch.cuda.is_available()
DEVICE = "cuda" if HAS_CUDA else "cpu"
logger.info("Running on device: %s", DEVICE.upper())

# The optional `spaces` package is only probed when CUDA is present; any
# failure while importing it simply leaves the integration disabled.
_spaces_available = False
if HAS_CUDA:
    try:
        import spaces  # noqa: F401
    except Exception:
        _spaces_available = False
    else:
        _spaces_available = True
# Authenticate against the Hugging Face Hub when a token is provided through
# the HF_TOKEN environment variable; otherwise continue anonymously.
hf_token = os.environ.get("HF_TOKEN")
if not hf_token:
    logger.warning("HF_TOKEN not set — private/gated models will be inaccessible")
else:
    login(token=hf_token)
    logger.info("Logged in to Hugging Face Hub")
# Precision and generation budget depend on the detected device.
if HAS_CUDA:
    _dtype = torch.bfloat16
    _MAX_NEW_TOKENS = 4096
else:
    _dtype = torch.float32
    _MAX_NEW_TOKENS = 1024

# ── Lazy-loaded model state ──
# Nothing is loaded at import time; `_load_model` fills these in on demand.
_tokenizer = None
_model = None
_current_model_id = None

# Serialises model loading; `_load_in_progress` is exposed through `status()`.
_load_lock = threading.Lock()
_load_in_progress = False
def _unload_model():
    """Drop the current model/tokenizer and release the memory they held.

    Resets the module-level model state to "nothing loaded", runs a garbage
    collection pass and, when running on CUDA, empties the CUDA cache.
    """
    global _tokenizer, _model, _current_model_id

    # Rebinding to None drops this module's references to the old objects.
    _model = None
    _tokenizer = None
    _current_model_id = None

    gc.collect()
    if HAS_CUDA:
        torch.cuda.empty_cache()
    logger.info("Previous model unloaded.")
def _load_model(model_id: str):
    """Return ``(tokenizer, model)`` for *model_id*, loading them on demand.

    If the requested model is already resident it is returned immediately.
    Otherwise the load is serialised through ``_load_lock``; a different
    previously loaded model is unloaded first.

    The module-level state is only updated once both the tokenizer and the
    model have loaded successfully, and ``_load_in_progress`` is always reset,
    even when loading raises, so ``status()`` never reports a stuck load.

    Args:
        model_id: Hugging Face Hub repository id (or local path) to load.

    Returns:
        A ``(tokenizer, model)`` tuple; the model is on ``DEVICE`` in eval mode.

    Raises:
        Whatever ``from_pretrained`` raises (missing repo, auth failure, OOM, ...).
    """
    global _tokenizer, _model, _current_model_id, _load_in_progress

    # Fast path: already loaded, no need to take the lock.
    if _model is not None and _current_model_id == model_id:
        return _tokenizer, _model

    with _load_lock:
        # Re-check under the lock: another thread may have finished the load.
        if _model is not None and _current_model_id == model_id:
            return _tokenizer, _model

        _load_in_progress = True
        try:
            # Unload previous model if different
            if _current_model_id and _current_model_id != model_id:
                _unload_model()

            # NOTE(review): trust_remote_code=True executes code shipped with
            # the model repository — make sure model ids come from a trusted
            # source.
            logger.info("Loading tokenizer from %s ...", model_id)
            tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

            logger.info("Loading model from %s on %s (%s) ...", model_id, DEVICE, _dtype)
            model = AutoModelForCausalLM.from_pretrained(
                model_id,
                torch_dtype=_dtype,
                trust_remote_code=True,
                low_cpu_mem_usage=True,
            ).to(DEVICE)
            model.eval()

            # Publish only after everything succeeded so a failed load cannot
            # leave a tokenizer without its matching model behind.
            _tokenizer, _model = tokenizer, model
            _current_model_id = model_id
        finally:
            _load_in_progress = False

        logger.info("Model %s loaded successfully.", model_id)
        return tokenizer, model
def _maybe_gpu(duration: int):
    """Build a decorator that conditionally applies ``spaces.GPU(duration=...)``.

    The wrapped function is returned unchanged unless CUDA is available and
    the ``spaces`` package was importable at start-up.
    """

    def decorator(fn):
        gpu_enabled = HAS_CUDA and _spaces_available
        if not gpu_enabled:
            return fn

        import spaces

        gpu_wrapper = spaces.GPU(duration=duration)
        return gpu_wrapper(fn)

    return decorator
# Application object; it is started via `demo.launch(...)` when this file is
# run as a script.
demo = Server()
def search(query: str, num_results: int = 5) -> List[Dict[str, str]]:
    """Run a server-side web search (DuckDuckGo HTML, per the original note).

    Thin wrapper that forwards to ``web_search.search``.

    Args:
        query: Search query text.
        num_results: Number of results requested from the search helper.

    Returns:
        Whatever ``web_search.search`` returns for the query.
    """
    results = web_search_fn(query, num_results=num_results)
    return results
def status() -> Dict[str, str]:
    """Report the backend's current device and model state.

    Intended as a cheap readiness probe for the frontend: it only reads
    module-level state and never triggers a model load.
    """
    # Before any model has been selected, report the default model id.
    active_model = _current_model_id if _current_model_id else DEFAULT_MODEL
    is_loaded = _model is not None

    report = {
        "device": DEVICE,
        "model_id": active_model,
        "model_loaded": is_loaded,
        "load_in_progress": _load_in_progress,
        "max_new_tokens": str(_MAX_NEW_TOKENS),
    }
    return report
def switch_model(model_id: str) -> Dict[str, str]:
    """Select *model_id* as the active model and unload the current one.

    Only the selection is recorded here; the actual load happens lazily on the
    next ``predict`` call, which reads ``_current_model_id``.

    Args:
        model_id: Hugging Face Hub repository id (or local path) to switch to.

    Returns:
        A dict with ``status``, the requested ``new_model`` and
        ``model_loaded`` (always ``False`` right after a switch).
    """
    global _current_model_id

    # Hold the load lock so the switch cannot interleave with an in-flight
    # load, which would otherwise overwrite the selection when it finishes.
    with _load_lock:
        _unload_model()
        # `_unload_model` resets the id to None; record the requested model so
        # the next `predict` loads it instead of falling back to the default.
        _current_model_id = model_id

    logger.info("Model switch requested to: %s", model_id)
    return {"status": "ok", "new_model": model_id, "model_loaded": False}
def predict(
    message: str,
    history: list[list] | None = None,
    thinking_mode: bool = True,
    temperature: float = 0.9,
    top_p: float = 0.95,
    system_prompt: str = "",
    web_context: str = "",
) -> Generator[str, None, None]:
    """Stream a model reply for *message*, yielding the text accumulated so far.

    The active model (or ``DEFAULT_MODEL`` when none is selected) is loaded on
    demand, the conversation is rendered through the tokenizer's chat
    template, and generation runs on a worker thread while this generator
    drains the streamer.

    Args:
        message: The new user message.
        history: Previous turns, forwarded to ``organize_messages``.
        thinking_mode: Passed to the chat template as ``enable_thinking``.
        temperature: Sampling temperature; ``<= 0`` selects greedy decoding.
        top_p: Nucleus sampling threshold (used only when sampling).
        system_prompt: Optional system prompt, forwarded to ``organize_messages``.
        web_context: Optional web-search context, forwarded to ``organize_messages``.

    Yields:
        The full reply text generated so far, once per non-empty streamed chunk.

    Raises:
        Any exception raised by ``model.generate`` on the worker thread is
        re-raised here once the stream has been drained.
    """
    model_id = _current_model_id or DEFAULT_MODEL
    tokenizer, model = _load_model(model_id)

    messages = organize_messages(
        message,
        history,
        system_prompt=system_prompt,
        web_context=web_context,
    )

    # Try chat template with thinking support; fall back to basic template
    try:
        prompt_text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=thinking_mode,
        )
    except TypeError:
        # Model doesn't support enable_thinking kwarg
        prompt_text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

    model_inputs = tokenizer([prompt_text], return_tensors="pt").to(DEVICE)
    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,
        skip_special_tokens=False,
    )
    gen_kwargs = dict(
        **model_inputs,
        streamer=streamer,
        max_new_tokens=_MAX_NEW_TOKENS,
    )
    if temperature > 0:
        gen_kwargs.update(temperature=temperature, top_p=top_p, do_sample=True)
    else:
        gen_kwargs.update(do_sample=False)

    worker_errors: list[BaseException] = []

    def _run_generation():
        """Run ``model.generate`` on the worker thread and report failures."""
        try:
            # inference_mode is thread-local, so it has to be entered on the
            # thread that actually runs generate(). Entering it in the
            # generator body would not affect generation and would leak into
            # the consumer between yields.
            grad_ctx = nullcontext() if HAS_CUDA else torch.inference_mode()
            with grad_ctx:
                model.generate(**gen_kwargs)
        except Exception as exc:
            logger.exception("Generation failed")
            worker_errors.append(exc)
            # Without an end-of-stream signal the consumer loop below would
            # block forever on the streamer's queue.
            streamer.end()

    thread = threading.Thread(target=_run_generation)
    thread.start()

    full_text = ""
    for new_token_text in streamer:
        if not new_token_text:
            continue
        full_text += new_token_text
        yield full_text

    thread.join()
    if worker_errors:
        # Surface the worker failure to the caller instead of ending silently.
        raise worker_errors[0]
async def homepage():
    """Return the contents of ``index.html`` located next to this file."""
    base_dir = os.path.dirname(os.path.abspath(__file__))
    page_path = os.path.join(base_dir, "index.html")
    with open(page_path, "r", encoding="utf-8") as page:
        html = page.read()
    return html
# Start the server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch(show_error=True)