Sevaai2 / app.py
SEVAQWERTY's picture
Update app.py
7eac325 verified
"""
🤖 Локальный чат‑бот (Transformers + Gradio) – без токена HF.
* Все зависимости (torch, transformers, gradio) ставятся в начале скрипта,
но фиксируются на совместимых версиях.
* В списке только **публичные** модели (не‑gated).
Если пользователь попытается выбрать закрытую модель – получит
дружелюбное сообщение с предложением использовать одну из доступных.
* Токен не нужен, модель загружается локально.
* UI: явная кнопка «Send», кнопка «Stop», сохранение истории, прогресс‑спиннер.
"""
# ----------------------------------------------------------------------
# 0️⃣ Bootstrap‑установка зависимостей (фикс‑версии)
# ----------------------------------------------------------------------
import sys, subprocess, importlib, logging
def _install(pkg: str) -> None:
subprocess.check_call([sys.executable, "-m", "pip", "install", pkg])
def _ensure(module: str, spec: str) -> None:
"""Импортировать `module`; если не найден – установить `spec` (может отличаться)."""
try:
importlib.import_module(module)
except ModuleNotFoundError:
print(f"[bootstrap] Installing {spec} …")
_install(spec)
importlib.import_module(module)
# Выбираем версии, которые точно работают друг с другом
_ensure("gradio", "gradio>=5.0")
_ensure("torch", "torch") # CPU‑wheel (или GPU‑wheel, если Space в GPU)
_ensure("huggingface_hub", "huggingface_hub==0.20.3") # версия, где ещё есть DryRunError
_ensure("transformers", "transformers==4.39.2") # совместима с huggingface_hub‑0.20.3
# ----------------------------------------------------------------------
# 1️⃣ Обычные импорты
# ----------------------------------------------------------------------
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from typing import Dict, Generator, List # ← без этого был NameError
# ----------------------------------------------------------------------
# 2️⃣ Открытые модели (человекочитаемое имя → HF‑repo‑id)
# ----------------------------------------------------------------------
MODEL_CHOICES = {
# 2‑Б‑модели – легко помещаются в CPU / небольшие GPU
"Mistral‑7B Instruct – fast": "mistralai/Mistral-7B-Instruct-v0.2",
"Phi‑3‑mini‑4k – tiny & fast": "microsoft/Phi-3-mini-4k-instruct",
"Llama‑3 8B Instruct – balanced": "meta-llama/Meta-Llama-3-8B-Instruct",
# Если ваш Space имеет мощный GPU, вы можете добавить более крупные модели,
# но они тоже должны быть **публичными** (не gated).
}
DEFAULT_MODEL = "Mistral‑7B Instruct – fast"
# ----------------------------------------------------------------------
# 3️⃣ Кеш моделей и токенизаторов (чтобы не загружать их каждый запрос)
# ----------------------------------------------------------------------
_model_cache: dict[str, AutoModelForCausalLM] = {}
_tokenizer_cache: dict[str, AutoTokenizer] = {}
logging.basicConfig(level=logging.INFO,
format="%(asctime)s %(levelname)s %(message)s")
_logger = logging.getLogger(__name__)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
_logger.info("Запускаем на устройстве: %s", DEVICE)
def load_model(model_id: str) -> tuple[AutoModelForCausalLM, AutoTokenizer]:
"""Загружает (или берёт из кеша) модель и токенизатор."""
if model_id not in _model_cache:
_logger.info("Скачиваем модель %s …", model_id)
try:
model = AutoModelForCausalLM.from_pretrained(
model_id,
torch_dtype=torch.float16 if DEVICE.type == "cuda" else torch.float32,
device_map="auto",
).eval()
except Exception as exc:
# Пробрасываем, чтобы вызывающий код мог отреагировать
raise RuntimeError(f"Не удалось загрузить модель `{model_id}`: {exc}") from exc
_model_cache[model_id] = model
else:
model = _model_cache[model_id]
if model_id not in _tokenizer_cache:
try:
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
except Exception as exc:
raise RuntimeError(f"Не удалось загрузить токенизатор `{model_id}`: {exc}") from exc
_tokenizer_cache[model_id] = tokenizer
else:
tokenizer = _tokenizer_cache[model_id]
return model, tokenizer
def build_prompt(
tokenizer: AutoTokenizer,
system_msg: str,
history: List[Dict[str, str]],
user_msg: str,
) -> torch.LongTensor:
"""Создаёт `input_ids` из system‑prompt, истории и текущего сообщения."""
messages = [{"role": "system", "content": system_msg}]
messages += history # уже в формате OpenAI‑style
messages.append({"role": "user", "content": user_msg})
if hasattr(tokenizer, "apply_chat_template"):
input_ids = tokenizer.apply_chat_template(
messages,
tokenize=True,
add_generation_prompt=True,
return_tensors="pt",
).to(DEVICE)
else:
# fallback‑вариант (подходит почти всем моделям)
txt = f"{system_msg}\n\n"
for m in messages:
role, cont = m["role"], m["content"]
prefix = "User: " if role == "user" else "Assistant: "
txt += f"{prefix}{cont}\n"
txt += "Assistant: "
input_ids = tokenizer(txt, return_tensors="pt").input_ids.to(DEVICE)
return input_ids
def respond(
user_msg: str,
chat_history: List[Dict[str, str]],
model_name: str,
system_msg: str,
max_new_tokens: int,
temperature: float,
top_p: float,
) -> Generator[str, None, None]:
"""Стримит ответ выбранной модели."""
model_id = MODEL_CHOICES.get(model_name)
if not model_id:
yield f"⚠️ Неизвестная модель: {model_name}"
return
# ----------------- загрузка модели -----------------
try:
model, tokenizer = load_model(model_id)
except RuntimeError as exc:
# Здесь появляется «403/401 gated repo» или другие ошибки
_logger.exception("Ошибка при загрузке модели")
# Дружелюбное сообщение пользователю
hint = (
"Эта модель требует авторизации (закрытый репозиторий). "
"Вы можете выбрать одну из публичных моделей:\n"
+ "\n".join(f"• {name}" for name in MODEL_CHOICES)
)
yield f"⚠️ {exc}\n\n{hint}"
return
# ----------------- построение prompt -----------------
try:
input_ids = build_prompt(tokenizer, system_msg, chat_history, user_msg)
except Exception as exc:
_logger.exception("Ошибка формирования prompt")
yield f"⚠️ Не удалось собрать запрос к модели: {exc}"
return
# ----------------- генерация (стрим) -----------------
streamer = TextIteratorStreamer(
tokenizer,
skip_prompt=True,
skip_special_tokens=True,
timeout=60.0,
)
generate_kwargs = dict(
input_ids=input_ids,
max_new_tokens=max_new_tokens,
temperature=temperature,
top_p=top_p,
do_sample=temperature > 0.0,
streamer=streamer,
)
import threading
thread = threading.Thread(target=model.generate, kwargs=generate_kwargs)
thread.start()
answer = ""
try:
for new_text in streamer:
answer += new_text
yield answer
except Exception as exc:
_logger.exception("Ошибка генерации")
yield f"⚠️ Ошибка модели: {exc}"
finally:
if thread.is_alive():
# Останавливаем стример, если пользователь нажал «Stop».
streamer.end()
thread.join(timeout=0.1)
# ----------------------------------------------------------------------
# 4️⃣ UI‑сборка (Gradio)
# ----------------------------------------------------------------------
with gr.Blocks() as demo:
# Боковая панель только с описанием – токен не нужен
with gr.Sidebar():
gr.Markdown(
"""
# Локальный чат‑бот
Токен Hugging Face **не требуется** – модель загружается и работает
полностью в контейнере Space.
Выберите одну из **публичных** моделей и настройте параметры генерации.
"""
)
# Основной чат‑интерфейс
chatbot = gr.ChatInterface(
fn=respond,
type="messages",
additional_inputs=[
# 1️⃣ Модель
gr.Dropdown(
choices=list(MODEL_CHOICES.keys()),
value=DEFAULT_MODEL,
label="Модель",
info="Выберите модель из списка (все публичные).",
),
# 2️⃣ System‑prompt
gr.Textbox(
value="You are a friendly Chatbot.",
label="System message",
placeholder="Опишите роль ассистента (system prompt).",
),
# 3️⃣ Max new tokens
gr.Slider(
minimum=1,
maximum=2048,
step=1,
value=512,
label="Max new tokens",
),
# 4️⃣ Temperature
gr.Slider(
minimum=0.0,
maximum=2.0,
step=0.1,
value=0.7,
label="Temperature",
),
# 5️⃣ Top‑p (nucleus sampling)
gr.Slider(
minimum=0.1,
maximum=1.0,
step=0.05,
value=0.95,
label="Top‑p (nucleus sampling)",
),
],
title="🤖 Локальный чат‑бот (Transformers + Gradio)",
description=(
"История сохраняется в браузере, есть кнопка «Stop», а «Send» явно видна. "
"Модель кешируется в памяти после первой загрузки."
),
submit_btn="Send",
stop_btn=True,
save_history=True,
show_progress="full",
)
chatbot.render()
# Показать полный traceback в UI, если что‑то сломается
if __name__ == "__main__":
demo.launch(show_error=True)