UnMelow committed (verified)
Commit c2f90f5 · 1 Parent(s): abb4539

Update app.py

Files changed (1):
  1. app.py +54 -13
app.py CHANGED
@@ -1,7 +1,7 @@
 import os
 import re
 import gc
-from typing import List, Dict, Tuple, Iterable, Optional, Any  # <-- FIX: added Any
+from typing import List, Dict, Tuple, Iterable, Optional, Any

 import gradio as gr
 import numpy as np
@@ -53,7 +53,8 @@ def _load_encoder(name: str):


 def _load_seq2seq(name: str):
-    tok = AutoTokenizer.from_pretrained(name, use_fast=False)  # important for T5/MT5
+    # critical for T5/MT5: use_fast=False
+    tok = AutoTokenizer.from_pretrained(name, use_fast=False)
     model = AutoModelForSeq2SeqLM.from_pretrained(
         name,
         torch_dtype=DTYPE,
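The `use_fast=False` flag is the conservative choice for T5/MT5-family checkpoints: the slow tokenizer drives the original SentencePiece model directly, while the fast (Rust) conversion has produced diverging ids for some checkpoints. A minimal sketch for checking whether the two backends agree, assuming the `transformers` library; the checkpoint name is illustrative, not necessarily PRIMARY_GEN_MODEL:

    from transformers import AutoTokenizer

    name = "google/mt5-small"  # illustrative checkpoint
    slow = AutoTokenizer.from_pretrained(name, use_fast=False)
    fast = AutoTokenizer.from_pretrained(name, use_fast=True)

    sample = "короткий тестовый текст"  # any representative input
    print(slow(sample)["input_ids"])
    print(fast(sample)["input_ids"])  # if these differ, keep use_fast=False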
@@ -67,6 +68,7 @@ print("Loading models...")
 emb_tok, emb_model = _load_encoder(EMB_MODEL_NAME)
 gen_tok, gen_model = _load_seq2seq(PRIMARY_GEN_MODEL)

+# load the fallback lazily (saves memory)
 fb_tok = None
 fb_model = None
 print("Models loaded.")
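Keeping `fb_tok`/`fb_model` as `None` at startup implies a load-on-first-use helper elsewhere in app.py (outside this diff). A sketch of the usual pattern; the helper name and the `FALLBACK_GEN_MODEL` constant are assumptions for illustration, not identifiers confirmed by this commit:

    def _get_fallback():
        # load the fallback seq2seq model only on first use
        global fb_tok, fb_model
        if fb_model is None:
            fb_tok, fb_model = _load_seq2seq(FALLBACK_GEN_MODEL)  # hypothetical constant
        return fb_tok, fb_model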
@@ -224,7 +226,7 @@ def retrieve_topk(query: str, embeddings_f16: np.ndarray, top_k: int = 4) -> Tup


 # =======================
-# GENERATION + SANITIZATION
+# GENERATION: guard against out-of-range (OOR) token ids
 # =======================
 BANNED = [
     "контекст", "вопрос:", "ответ:", "правила", "требования",
@@ -265,26 +267,64 @@ def looks_bad(text: str) -> bool:
     return False


+def _ensure_embeddings_size(tokenizer, model, required_size: int):
+    emb = model.get_input_embeddings()
+    cur = int(emb.num_embeddings)
+    if required_size > cur:
+        model.resize_token_embeddings(required_size)
+        # just in case, move the model back to the target device after the resize
+        model.to(DEVICE)
+
+
 @torch.inference_mode()
 def seq2seq_generate(tokenizer, model, prompt: str, max_new_tokens: int = 220, max_input_tokens: int = 512) -> str:
     prompt = (prompt or "").strip()
     if not prompt:
         return ""
+
     batch = tokenizer(
         prompt,
         return_tensors="pt",
         truncation=True,
         max_length=max_input_tokens,
     )
+
+    # IMPORTANT: compute max_id on CPU, before .to(DEVICE)
+    input_ids = batch["input_ids"]
+    max_id = int(input_ids.max().item()) if input_ids.numel() else 0
+    needed = max(int(len(tokenizer)), max_id + 1)
+
+    _ensure_embeddings_size(tokenizer, model, needed)
+
+    # after the possible resize, move the batch to the device
     batch = {k: v.to(DEVICE) for k, v in batch.items()}
-    out_ids = model.generate(
-        **batch,
-        max_new_tokens=max_new_tokens,
-        num_beams=4,
-        do_sample=False,
-        no_repeat_ngram_size=3,
-        early_stopping=True,
-    )
+
+    # extra safety net: if the ids are somehow still out of range, clamp them
+    emb_size = int(model.get_input_embeddings().num_embeddings)
+    if int(batch["input_ids"].max().item()) >= emb_size:
+        batch["input_ids"] = batch["input_ids"].clamp_max(emb_size - 1)
+
+    try:
+        out_ids = model.generate(
+            **batch,
+            max_new_tokens=max_new_tokens,
+            num_beams=4,
+            do_sample=False,
+            no_repeat_ngram_size=3,
+            early_stopping=True,
+        )
+    except IndexError:
+        # retry: resync the embedding table to len(tokenizer) and run generate again
+        _ensure_embeddings_size(tokenizer, model, int(len(tokenizer)))
+        out_ids = model.generate(
+            **batch,
+            max_new_tokens=max_new_tokens,
+            num_beams=4,
+            do_sample=False,
+            no_repeat_ngram_size=3,
+            early_stopping=True,
+        )
+
     return tokenizer.decode(out_ids[0], skip_special_tokens=True).strip()

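The failure mode this hunk guards against: when `len(tokenizer)` exceeds `model.get_input_embeddings().num_embeddings` (typical after tokens are added to a tokenizer without resizing the model), any out-of-range id makes the embedding lookup inside `generate()` fail with `IndexError: index out of range in self`. The same invariant can be asserted once at startup; a sketch using only calls that already appear in the diff:

    # startup check (sketch): keep the embedding table at least as large as
    # the tokenizer vocabulary, so no valid input id can index past it
    vocab_size = len(gen_tok)
    table_rows = gen_model.get_input_embeddings().num_embeddings
    if vocab_size > table_rows:
        # ids in [table_rows, vocab_size) would otherwise be out of range
        gen_model.resize_token_embeddings(vocab_size)

Note that the clamp in `seq2seq_generate` is lossy (it silently rewrites offending ids to the last vocabulary entry), which is why it is kept as a last resort behind the resize and the `IndexError` retry.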
 
@@ -302,7 +342,7 @@ def generate_clean(primary_prompt: str, fallback_prompt: str) -> str:


 # =======================
-# PER-PART SUMMARY
+# SUMMARY
 # =======================
 def summarize_part(part_1based: int, state: Dict[str, Any]) -> Tuple[str, str, str]:
     chunks: List[str] = state.get("chunks", [])
@@ -601,4 +641,5 @@ with gr.Blocks(title="EduMultiSpace") as demo:


 if __name__ == "__main__":
-    demo.launch()
+    # avoid concurrent generations, which can blow up memory on Spaces
+    demo.queue(concurrency_count=1, max_size=16).launch()
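One compatibility caveat: `concurrency_count` is the Gradio 3.x parameter name. On Gradio 4.x, `queue()` no longer accepts it; assuming the Space were upgraded to Gradio 4, the equivalent would be:

    # Gradio 4.x spelling of the same single-worker queue (assumption: Gradio 4)
    demo.queue(default_concurrency_limit=1, max_size=16).launch()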
 
 