Update app.py
app.py CHANGED
@@ -1,7 +1,7 @@
 import os
 import re
 import gc
-from typing import List, Dict, Tuple, Iterable, Optional, Any
+from typing import List, Dict, Tuple, Iterable, Optional, Any
 
 import gradio as gr
 import numpy as np
@@ -53,7 +53,8 @@ def _load_encoder(name: str):
 
 
 def _load_seq2seq(name: str):
-
+    # critical for T5/MT5: use_fast=False
+    tok = AutoTokenizer.from_pretrained(name, use_fast=False)
     model = AutoModelForSeq2SeqLM.from_pretrained(
         name,
         torch_dtype=DTYPE,
@@ -67,6 +68,7 @@ print("Loading models...")
 emb_tok, emb_model = _load_encoder(EMB_MODEL_NAME)
 gen_tok, gen_model = _load_seq2seq(PRIMARY_GEN_MODEL)
 
+# load the fallback model lazily (saves memory)
 fb_tok = None
 fb_model = None
 print("Models loaded.")
@@ -224,7 +226,7 @@ def retrieve_topk(query: str, embeddings_f16: np.ndarray, top_k: int = 4) -> Tup
 
 
 # =======================
-#
+# GENERATION: protection against OOR (out-of-range token ids)
 # =======================
 BANNED = [
     "контекст", "вопрос:", "ответ:", "правила", "требования",
@@ -265,26 +267,64 @@ def looks_bad(text: str) -> bool:
     return False
 
 
+def _ensure_embeddings_size(tokenizer, model, required_size: int):
+    emb = model.get_input_embeddings()
+    cur = int(emb.num_embeddings)
+    if required_size > cur:
+        model.resize_token_embeddings(required_size)
+        # just in case, move back to the right device after the resize
+        model.to(DEVICE)
+
+
 @torch.inference_mode()
 def seq2seq_generate(tokenizer, model, prompt: str, max_new_tokens: int = 220, max_input_tokens: int = 512) -> str:
     prompt = (prompt or "").strip()
     if not prompt:
         return ""
+
     batch = tokenizer(
         prompt,
         return_tensors="pt",
         truncation=True,
         max_length=max_input_tokens,
    )
+
+    # IMPORTANT: before .to(DEVICE), max_id can be computed on the CPU
+    input_ids = batch["input_ids"]
+    max_id = int(input_ids.max().item()) if input_ids.numel() else 0
+    needed = max(int(len(tokenizer)), max_id + 1)
+
+    _ensure_embeddings_size(tokenizer, model, needed)
+
+    # after a possible resize, move the batch to the device
     batch = {k: v.to(DEVICE) for k, v in batch.items()}
-
-
-
-
-
-
-
-
+
+    # extra safety net: if the ids are somehow still out of range, clamp them
+    emb_size = int(model.get_input_embeddings().num_embeddings)
+    if int(batch["input_ids"].max().item()) >= emb_size:
+        batch["input_ids"] = batch["input_ids"].clamp_max(emb_size - 1)
+
+    try:
+        out_ids = model.generate(
+            **batch,
+            max_new_tokens=max_new_tokens,
+            num_beams=4,
+            do_sample=False,
+            no_repeat_ngram_size=3,
+            early_stopping=True,
+        )
+    except IndexError:
+        # retry: sync the embedding size to len(tokenizer) and run again
+        _ensure_embeddings_size(tokenizer, model, int(len(tokenizer)))
+        out_ids = model.generate(
+            **batch,
+            max_new_tokens=max_new_tokens,
+            num_beams=4,
+            do_sample=False,
+            no_repeat_ngram_size=3,
+            early_stopping=True,
+        )
+
     return tokenizer.decode(out_ids[0], skip_special_tokens=True).strip()
 
 
@@ -302,7 +342,7 @@ def generate_clean(primary_prompt: str, fallback_prompt: str) -> str:
 
 
 # =======================
-# SUMMARY
+# SUMMARY
 # =======================
 def summarize_part(part_1based: int, state: Dict[str, Any]) -> Tuple[str, str, str]:
     chunks: List[str] = state.get("chunks", [])
@@ -601,4 +641,5 @@ with gr.Blocks(title="EduMultiSpace") as demo:
 
 
 if __name__ == "__main__":
-
+    # avoid parallel generations, which can blow up memory usage on Spaces
+    demo.queue(concurrency_count=1, max_size=16).launch()
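For reference, the guard added to seq2seq_generate boils down to two steps: grow the model's embedding table until it covers every id the tokenizer can emit, then clamp any id that would still index past it. Below is a minimal standalone sketch of that pattern; safe_generate is an illustrative helper and google/flan-t5-small is only a stand-in checkpoint, not the Space's PRIMARY_GEN_MODEL.

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

def safe_generate(model_name: str, prompt: str) -> str:
    # model_name is a placeholder; the Space configures its own checkpoints
    tok = AutoTokenizer.from_pretrained(model_name, use_fast=False)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

    batch = tok(prompt, return_tensors="pt", truncation=True, max_length=512)

    # 1) grow the embedding matrix if the tokenizer knows more ids than the model
    emb_size = model.get_input_embeddings().num_embeddings
    if len(tok) > emb_size:
        model.resize_token_embeddings(len(tok))
        emb_size = model.get_input_embeddings().num_embeddings

    # 2) last resort: clamp ids that would still index past the embedding table
    batch["input_ids"] = batch["input_ids"].clamp_max(emb_size - 1)

    with torch.inference_mode():
        out = model.generate(**batch, max_new_tokens=64, num_beams=4)
    return tok.decode(out[0], skip_special_tokens=True)

print(safe_generate("google/flan-t5-small", "summarize: transformers process token sequences"))

Resizing first keeps the real token ids intact; clamping only trades a slightly wrong token for avoiding the device-side IndexError during generation.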