Update app.py
app.py (CHANGED)
@@ -1,6 +1,7 @@
 import os
 import re
 import gc
+import inspect
 from typing import List, Dict, Tuple, Iterable, Optional, Any

 import gradio as gr
@@ -53,7 +54,7 @@ def _load_encoder(name: str):


 def _load_seq2seq(name: str):
-    # critical for T5/MT5: use_fast=False
+    # critical for T5/MT5: use_fast=False (SentencePiece)
     tok = AutoTokenizer.from_pretrained(name, use_fast=False)
     model = AutoModelForSeq2SeqLM.from_pretrained(
         name,
@@ -68,7 +69,6 @@ print("Loading models...")
 emb_tok, emb_model = _load_encoder(EMB_MODEL_NAME)
 gen_tok, gen_model = _load_seq2seq(PRIMARY_GEN_MODEL)

-# load the fallback lazily (saves memory)
 fb_tok = None
 fb_model = None
 print("Models loaded.")
@@ -272,7 +272,6 @@ def _ensure_embeddings_size(tokenizer, model, required_size: int):
     cur = int(emb.num_embeddings)
     if required_size > cur:
         model.resize_token_embeddings(required_size)
-        # just in case, move back to the right device after the resize
         model.to(DEVICE)


@@ -289,17 +288,13 @@ def seq2seq_generate(tokenizer, model, prompt: str, max_new_tokens: int = 220, m
         max_length=max_input_tokens,
     )

-    # IMPORTANT: max_id can be computed on the CPU before .to(DEVICE)
     input_ids = batch["input_ids"]
     max_id = int(input_ids.max().item()) if input_ids.numel() else 0
     needed = max(int(len(tokenizer)), max_id + 1)
-
     _ensure_embeddings_size(tokenizer, model, needed)

-    # after a possible resize, move the batch to the device
     batch = {k: v.to(DEVICE) for k, v in batch.items()}

-    # extra safety: if IDs are somehow still out of range, clamp them
     emb_size = int(model.get_input_embeddings().num_embeddings)
     if int(batch["input_ids"].max().item()) >= emb_size:
         batch["input_ids"] = batch["input_ids"].clamp_max(emb_size - 1)
@@ -314,7 +309,6 @@ def seq2seq_generate(tokenizer, model, prompt: str, max_new_tokens: int = 220, m
             early_stopping=True,
         )
     except IndexError:
-        # retry: resync to len(tokenizer) and generate again
         _ensure_embeddings_size(tokenizer, model, int(len(tokenizer)))
         out_ids = model.generate(
             **batch,
@@ -479,22 +473,42 @@ def generate_questions(difficulty: str, num_q: int, state: Dict[str, Any]) -> st


 # =======================
-# CHAT
+# CHAT (messages)
 # =======================
-def chat_answer(message: str, chat_history: List[Tuple[str, str]], state: Dict[str, Any]):
+def _append_messages(history: Any, user_text: str, assistant_text: str) -> List[Dict[str, str]]:
+    if not history:
+        history = []
+    # if tuples came in, convert them to messages
+    if isinstance(history, list) and history and isinstance(history[0], (tuple, list)) and len(history[0]) == 2:
+        msgs: List[Dict[str, str]] = []
+        for u, a in history:
+            msgs.append({"role": "user", "content": str(u)})
+            msgs.append({"role": "assistant", "content": str(a)})
+        history = msgs
+    # if the history is already in messages format
+    if isinstance(history, list) and (not history or isinstance(history[0], dict)):
+        history = list(history)
+        history.append({"role": "user", "content": user_text})
+        history.append({"role": "assistant", "content": assistant_text})
+        return history
+    # fallback
+    return [{"role": "user", "content": user_text}, {"role": "assistant", "content": assistant_text}]
+
+
+def chat_answer(message: str, chat_history: List[Dict[str, str]], state: Dict[str, Any]):
     q = (message or "").strip()
     if not q:
         return chat_history, ""

     if not state or not state.get("chunks") or state.get("embeddings") is None:
-        return chat_history
+        return _append_messages(chat_history, q, "Сначала загрузите документ."), ""

     chunks: List[str] = state["chunks"]
     emb: np.ndarray = state["embeddings"]

     top_idx, best_sim = retrieve_topk(q, emb, top_k=4)
     if best_sim < RETRIEVE_MIN_SIM:
-        return chat_history
+        return _append_messages(chat_history, q, "В документе нет информации для ответа на этот вопрос."), ""

     ctx_idx = []
     for i in top_idx:
@@ -520,7 +534,7 @@ def chat_answer(message: str, chat_history: List[Tuple[str, str]], state: Dict[s
         a = "В документе нет информации для ответа на этот вопрос."

     cleanup_memory()
-    return chat_history
+    return _append_messages(chat_history, q, a), ""


 def clear_chat():
@@ -631,7 +645,7 @@ with gr.Blocks(title="EduMultiSpace") as demo:
         q_btn.click(generate_questions, inputs=[diff, n_q, state], outputs=[q_out])

     with gr.Tab("Чат"):
-        chat = gr.Chatbot(label="Чат")
+        chat = gr.Chatbot(label="Чат", type="messages")
         msg = gr.Textbox(lines=2, label="Вопрос")
         send = gr.Button("Отправить")
         clear = gr.Button("Очистить")
@@ -640,6 +654,33 @@
         clear.click(clear_chat, inputs=None, outputs=[chat, msg])


+def _launch_compat(app: gr.Blocks):
+    """
+    Compatible across gradio versions:
+    - versions that accept concurrency_count
+    - versions that accept concurrency_limit / default_concurrency_limit
+    - versions where queue() takes neither parameter
+    """
+    q_params = inspect.signature(app.queue).parameters
+    kwargs = {}
+
+    if "max_size" in q_params:
+        kwargs["max_size"] = 16
+
+    # different gradio versions use different parameter names
+    if "concurrency_count" in q_params:
+        kwargs["concurrency_count"] = 1
+    elif "concurrency_limit" in q_params:
+        kwargs["concurrency_limit"] = 1
+    elif "default_concurrency_limit" in q_params:
+        kwargs["default_concurrency_limit"] = 1
+
+    try:
+        app.queue(**kwargs).launch()
+    except TypeError:
+        # if queue() is something else entirely, just launch()
+        app.launch()
+
+
 if __name__ == "__main__":
-
-    demo.queue(concurrency_count=1, max_size=16).launch()
+    _launch_compat(demo)
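Note on the messages format: with type="messages", gr.Chatbot expects the history as a list of {"role": ..., "content": ...} dicts instead of (user, assistant) tuples; _append_messages produces exactly that shape, and chat_answer now returns it together with an empty string that clears the textbox. Below is a minimal wiring sketch, assuming the existing send.click / msg.submit / clear.click handlers (those lines sit outside this diff) pass [msg, chat, state] in and [chat, msg] out.

import gradio as gr

from app import chat_answer, clear_chat  # handlers from this repo's app.py

with gr.Blocks() as sketch:
    state = gr.State({})                             # filled with chunks/embeddings elsewhere
    chat = gr.Chatbot(label="Чат", type="messages")  # history is a list of role/content dicts
    msg = gr.Textbox(lines=2, label="Вопрос")
    send = gr.Button("Отправить")
    clear = gr.Button("Очистить")

    # chat_answer(message, history, state) -> (new_history, "") where new_history is
    # [{"role": "user", "content": ...}, {"role": "assistant", "content": ...}, ...]
    send.click(chat_answer, inputs=[msg, chat, state], outputs=[chat, msg])
    msg.submit(chat_answer, inputs=[msg, chat, state], outputs=[chat, msg])
    clear.click(clear_chat, inputs=None, outputs=[chat, msg])

sketch.launch()

The _launch_compat helper inspects app.queue's signature so the same file runs on Gradio 3.x (concurrency_count) as well as newer releases (concurrency_limit / default_concurrency_limit) without pinning a version, and falls back to a plain launch() if queue() rejects the keyword arguments.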