Spaces:
Sleeping
Sleeping
| """ | |
| 🤖 Локальный чат‑бот (Transformers + Gradio) – без токена HF. | |
| * Все зависимости (torch, transformers, gradio) ставятся в начале скрипта, | |
| но фиксируются на совместимых версиях. | |
| * В списке только **публичные** модели (не‑gated). | |
| Если пользователь попытается выбрать закрытую модель – получит | |
| дружелюбное сообщение с предложением использовать одну из доступных. | |
| * Токен не нужен, модель загружается локально. | |
| * UI: явная кнопка «Send», кнопка «Stop», сохранение истории, прогресс‑спиннер. | |
| """ | |
| # ---------------------------------------------------------------------- | |
| # 0️⃣ Bootstrap‑установка зависимостей (фикс‑версии) | |
| # ---------------------------------------------------------------------- | |
| import sys, subprocess, importlib, logging | |
| def _install(pkg: str) -> None: | |
| subprocess.check_call([sys.executable, "-m", "pip", "install", pkg]) | |
| def _ensure(module: str, spec: str) -> None: | |
| """Импортировать `module`; если не найден – установить `spec` (может отличаться).""" | |
| try: | |
| importlib.import_module(module) | |
| except ModuleNotFoundError: | |
| print(f"[bootstrap] Installing {spec} …") | |
| _install(spec) | |
| importlib.import_module(module) | |
| # Выбираем версии, которые точно работают друг с другом | |
| _ensure("gradio", "gradio>=5.0") | |
| _ensure("torch", "torch") # CPU‑wheel (или GPU‑wheel, если Space в GPU) | |
| _ensure("huggingface_hub", "huggingface_hub==0.20.3") # версия, где ещё есть DryRunError | |
| _ensure("transformers", "transformers==4.39.2") # совместима с huggingface_hub‑0.20.3 | |
| # ---------------------------------------------------------------------- | |
| # 1️⃣ Обычные импорты | |
| # ---------------------------------------------------------------------- | |
| import gradio as gr | |
| import torch | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer | |
| from typing import Dict, Generator, List # ← без этого был NameError | |
| # ---------------------------------------------------------------------- | |
| # 2️⃣ Открытые модели (человекочитаемое имя → HF‑repo‑id) | |
| # ---------------------------------------------------------------------- | |
| MODEL_CHOICES = { | |
| # 2‑Б‑модели – легко помещаются в CPU / небольшие GPU | |
| "Mistral‑7B Instruct – fast": "mistralai/Mistral-7B-Instruct-v0.2", | |
| "Phi‑3‑mini‑4k – tiny & fast": "microsoft/Phi-3-mini-4k-instruct", | |
| "Llama‑3 8B Instruct – balanced": "meta-llama/Meta-Llama-3-8B-Instruct", | |
| # Если ваш Space имеет мощный GPU, вы можете добавить более крупные модели, | |
| # но они тоже должны быть **публичными** (не gated). | |
| } | |
| DEFAULT_MODEL = "Mistral‑7B Instruct – fast" | |
| # ---------------------------------------------------------------------- | |
| # 3️⃣ Кеш моделей и токенизаторов (чтобы не загружать их каждый запрос) | |
| # ---------------------------------------------------------------------- | |
| _model_cache: dict[str, AutoModelForCausalLM] = {} | |
| _tokenizer_cache: dict[str, AutoTokenizer] = {} | |
| logging.basicConfig(level=logging.INFO, | |
| format="%(asctime)s %(levelname)s %(message)s") | |
| _logger = logging.getLogger(__name__) | |
| DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| _logger.info("Запускаем на устройстве: %s", DEVICE) | |
| def load_model(model_id: str) -> tuple[AutoModelForCausalLM, AutoTokenizer]: | |
| """Загружает (или берёт из кеша) модель и токенизатор.""" | |
| if model_id not in _model_cache: | |
| _logger.info("Скачиваем модель %s …", model_id) | |
| try: | |
| model = AutoModelForCausalLM.from_pretrained( | |
| model_id, | |
| torch_dtype=torch.float16 if DEVICE.type == "cuda" else torch.float32, | |
| device_map="auto", | |
| ).eval() | |
| except Exception as exc: | |
| # Пробрасываем, чтобы вызывающий код мог отреагировать | |
| raise RuntimeError(f"Не удалось загрузить модель `{model_id}`: {exc}") from exc | |
| _model_cache[model_id] = model | |
| else: | |
| model = _model_cache[model_id] | |
| if model_id not in _tokenizer_cache: | |
| try: | |
| tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True) | |
| except Exception as exc: | |
| raise RuntimeError(f"Не удалось загрузить токенизатор `{model_id}`: {exc}") from exc | |
| _tokenizer_cache[model_id] = tokenizer | |
| else: | |
| tokenizer = _tokenizer_cache[model_id] | |
| return model, tokenizer | |
| def build_prompt( | |
| tokenizer: AutoTokenizer, | |
| system_msg: str, | |
| history: List[Dict[str, str]], | |
| user_msg: str, | |
| ) -> torch.LongTensor: | |
| """Создаёт `input_ids` из system‑prompt, истории и текущего сообщения.""" | |
| messages = [{"role": "system", "content": system_msg}] | |
| messages += history # уже в формате OpenAI‑style | |
| messages.append({"role": "user", "content": user_msg}) | |
| if hasattr(tokenizer, "apply_chat_template"): | |
| input_ids = tokenizer.apply_chat_template( | |
| messages, | |
| tokenize=True, | |
| add_generation_prompt=True, | |
| return_tensors="pt", | |
| ).to(DEVICE) | |
| else: | |
| # fallback‑вариант (подходит почти всем моделям) | |
| txt = f"{system_msg}\n\n" | |
| for m in messages: | |
| role, cont = m["role"], m["content"] | |
| prefix = "User: " if role == "user" else "Assistant: " | |
| txt += f"{prefix}{cont}\n" | |
| txt += "Assistant: " | |
| input_ids = tokenizer(txt, return_tensors="pt").input_ids.to(DEVICE) | |
| return input_ids | |
| def respond( | |
| user_msg: str, | |
| chat_history: List[Dict[str, str]], | |
| model_name: str, | |
| system_msg: str, | |
| max_new_tokens: int, | |
| temperature: float, | |
| top_p: float, | |
| ) -> Generator[str, None, None]: | |
| """Стримит ответ выбранной модели.""" | |
| model_id = MODEL_CHOICES.get(model_name) | |
| if not model_id: | |
| yield f"⚠️ Неизвестная модель: {model_name}" | |
| return | |
| # ----------------- загрузка модели ----------------- | |
| try: | |
| model, tokenizer = load_model(model_id) | |
| except RuntimeError as exc: | |
| # Здесь появляется «403/401 gated repo» или другие ошибки | |
| _logger.exception("Ошибка при загрузке модели") | |
| # Дружелюбное сообщение пользователю | |
| hint = ( | |
| "Эта модель требует авторизации (закрытый репозиторий). " | |
| "Вы можете выбрать одну из публичных моделей:\n" | |
| + "\n".join(f"• {name}" for name in MODEL_CHOICES) | |
| ) | |
| yield f"⚠️ {exc}\n\n{hint}" | |
| return | |
| # ----------------- построение prompt ----------------- | |
| try: | |
| input_ids = build_prompt(tokenizer, system_msg, chat_history, user_msg) | |
| except Exception as exc: | |
| _logger.exception("Ошибка формирования prompt") | |
| yield f"⚠️ Не удалось собрать запрос к модели: {exc}" | |
| return | |
| # ----------------- генерация (стрим) ----------------- | |
| streamer = TextIteratorStreamer( | |
| tokenizer, | |
| skip_prompt=True, | |
| skip_special_tokens=True, | |
| timeout=60.0, | |
| ) | |
| generate_kwargs = dict( | |
| input_ids=input_ids, | |
| max_new_tokens=max_new_tokens, | |
| temperature=temperature, | |
| top_p=top_p, | |
| do_sample=temperature > 0.0, | |
| streamer=streamer, | |
| ) | |
| import threading | |
| thread = threading.Thread(target=model.generate, kwargs=generate_kwargs) | |
| thread.start() | |
| answer = "" | |
| try: | |
| for new_text in streamer: | |
| answer += new_text | |
| yield answer | |
| except Exception as exc: | |
| _logger.exception("Ошибка генерации") | |
| yield f"⚠️ Ошибка модели: {exc}" | |
| finally: | |
| if thread.is_alive(): | |
| # Останавливаем стример, если пользователь нажал «Stop». | |
| streamer.end() | |
| thread.join(timeout=0.1) | |
| # ---------------------------------------------------------------------- | |
| # 4️⃣ UI‑сборка (Gradio) | |
| # ---------------------------------------------------------------------- | |
| with gr.Blocks() as demo: | |
| # Боковая панель только с описанием – токен не нужен | |
| with gr.Sidebar(): | |
| gr.Markdown( | |
| """ | |
| # Локальный чат‑бот | |
| Токен Hugging Face **не требуется** – модель загружается и работает | |
| полностью в контейнере Space. | |
| Выберите одну из **публичных** моделей и настройте параметры генерации. | |
| """ | |
| ) | |
| # Основной чат‑интерфейс | |
| chatbot = gr.ChatInterface( | |
| fn=respond, | |
| type="messages", | |
| additional_inputs=[ | |
| # 1️⃣ Модель | |
| gr.Dropdown( | |
| choices=list(MODEL_CHOICES.keys()), | |
| value=DEFAULT_MODEL, | |
| label="Модель", | |
| info="Выберите модель из списка (все публичные).", | |
| ), | |
| # 2️⃣ System‑prompt | |
| gr.Textbox( | |
| value="You are a friendly Chatbot.", | |
| label="System message", | |
| placeholder="Опишите роль ассистента (system prompt).", | |
| ), | |
| # 3️⃣ Max new tokens | |
| gr.Slider( | |
| minimum=1, | |
| maximum=2048, | |
| step=1, | |
| value=512, | |
| label="Max new tokens", | |
| ), | |
| # 4️⃣ Temperature | |
| gr.Slider( | |
| minimum=0.0, | |
| maximum=2.0, | |
| step=0.1, | |
| value=0.7, | |
| label="Temperature", | |
| ), | |
| # 5️⃣ Top‑p (nucleus sampling) | |
| gr.Slider( | |
| minimum=0.1, | |
| maximum=1.0, | |
| step=0.05, | |
| value=0.95, | |
| label="Top‑p (nucleus sampling)", | |
| ), | |
| ], | |
| title="🤖 Локальный чат‑бот (Transformers + Gradio)", | |
| description=( | |
| "История сохраняется в браузере, есть кнопка «Stop», а «Send» явно видна. " | |
| "Модель кешируется в памяти после первой загрузки." | |
| ), | |
| submit_btn="Send", | |
| stop_btn=True, | |
| save_history=True, | |
| show_progress="full", | |
| ) | |
| chatbot.render() | |
| # Показать полный traceback в UI, если что‑то сломается | |
| if __name__ == "__main__": | |
| demo.launch(show_error=True) | |