UnMelow commited on
Commit
a264b9a
·
verified ·
1 Parent(s): c2f90f5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -16
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import os
2
  import re
3
  import gc
 
4
  from typing import List, Dict, Tuple, Iterable, Optional, Any
5
 
6
  import gradio as gr
@@ -53,7 +54,7 @@ def _load_encoder(name: str):
53
 
54
 
55
  def _load_seq2seq(name: str):
56
- # критично для T5/MT5: use_fast=False
57
  tok = AutoTokenizer.from_pretrained(name, use_fast=False)
58
  model = AutoModelForSeq2SeqLM.from_pretrained(
59
  name,
@@ -68,7 +69,6 @@ print("Loading models...")
68
  emb_tok, emb_model = _load_encoder(EMB_MODEL_NAME)
69
  gen_tok, gen_model = _load_seq2seq(PRIMARY_GEN_MODEL)
70
 
71
- # fallback лениво (экономия памяти)
72
  fb_tok = None
73
  fb_model = None
74
  print("Models loaded.")
@@ -272,7 +272,6 @@ def _ensure_embeddings_size(tokenizer, model, required_size: int):
272
  cur = int(emb.num_embeddings)
273
  if required_size > cur:
274
  model.resize_token_embeddings(required_size)
275
- # на всякий случай вернём на нужное устройство после ресайза
276
  model.to(DEVICE)
277
 
278
 
@@ -289,17 +288,13 @@ def seq2seq_generate(tokenizer, model, prompt: str, max_new_tokens: int = 220, m
289
  max_length=max_input_tokens,
290
  )
291
 
292
- # ВАЖНО: before .to(DEVICE) можно посчитать max_id на CPU
293
  input_ids = batch["input_ids"]
294
  max_id = int(input_ids.max().item()) if input_ids.numel() else 0
295
  needed = max(int(len(tokenizer)), max_id + 1)
296
-
297
  _ensure_embeddings_size(tokenizer, model, needed)
298
 
299
- # После возможного resize — переносим на устройство
300
  batch = {k: v.to(DEVICE) for k, v in batch.items()}
301
 
302
- # Доп. страховка: если по какой-то причине всё ещё OOR — зажмём
303
  emb_size = int(model.get_input_embeddings().num_embeddings)
304
  if int(batch["input_ids"].max().item()) >= emb_size:
305
  batch["input_ids"] = batch["input_ids"].clamp_max(emb_size - 1)
@@ -314,7 +309,6 @@ def seq2seq_generate(tokenizer, model, prompt: str, max_new_tokens: int = 220, m
314
  early_stopping=True,
315
  )
316
  except IndexError:
317
- # retry: синхронизируем по len(tokenizer) и повторяем
318
  _ensure_embeddings_size(tokenizer, model, int(len(tokenizer)))
319
  out_ids = model.generate(
320
  **batch,
@@ -479,22 +473,42 @@ def generate_questions(difficulty: str, num_q: int, state: Dict[str, Any]) -> st
479
 
480
 
481
  # =======================
482
- # ЧАТ
483
  # =======================
484
- def chat_answer(message: str, chat_history: List[Tuple[str, str]], state: Dict[str, Any]):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
485
  q = (message or "").strip()
486
  if not q:
487
  return chat_history, ""
488
 
489
  if not state or not state.get("chunks") or state.get("embeddings") is None:
490
- return chat_history + [(q, "Сначала загрузите документ.")], ""
491
 
492
  chunks: List[str] = state["chunks"]
493
  emb: np.ndarray = state["embeddings"]
494
 
495
  top_idx, best_sim = retrieve_topk(q, emb, top_k=4)
496
  if best_sim < RETRIEVE_MIN_SIM:
497
- return chat_history + [(q, "В документе нет информации для ответа на этот вопрос.")], ""
498
 
499
  ctx_idx = []
500
  for i in top_idx:
@@ -520,7 +534,7 @@ def chat_answer(message: str, chat_history: List[Tuple[str, str]], state: Dict[s
520
  a = "В документе нет информации для ответа на этот вопрос."
521
 
522
  cleanup_memory()
523
- return chat_history + [(q, a)], ""
524
 
525
 
526
  def clear_chat():
@@ -631,7 +645,7 @@ with gr.Blocks(title="EduMultiSpace") as demo:
631
  q_btn.click(generate_questions, inputs=[diff, n_q, state], outputs=[q_out])
632
 
633
  with gr.Tab("Чат"):
634
- chat = gr.Chatbot(label="Чат")
635
  msg = gr.Textbox(lines=2, label="Вопрос")
636
  send = gr.Button("Отправить")
637
  clear = gr.Button("Очистить")
@@ -640,6 +654,33 @@ with gr.Blocks(title="EduMultiSpace") as demo:
640
  clear.click(clear_chat, inputs=None, outputs=[chat, msg])
641
 
642
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
643
  if __name__ == "__main__":
644
- # чтобы не было параллельных генераций, которые могут раздувать память на Spaces
645
- demo.queue(concurrency_count=1, max_size=16).launch()
 
1
  import os
2
  import re
3
  import gc
4
+ import inspect
5
  from typing import List, Dict, Tuple, Iterable, Optional, Any
6
 
7
  import gradio as gr
 
54
 
55
 
56
  def _load_seq2seq(name: str):
57
+ # критично для T5/MT5: use_fast=False (SentencePiece)
58
  tok = AutoTokenizer.from_pretrained(name, use_fast=False)
59
  model = AutoModelForSeq2SeqLM.from_pretrained(
60
  name,
 
69
  emb_tok, emb_model = _load_encoder(EMB_MODEL_NAME)
70
  gen_tok, gen_model = _load_seq2seq(PRIMARY_GEN_MODEL)
71
 
 
72
  fb_tok = None
73
  fb_model = None
74
  print("Models loaded.")
 
272
  cur = int(emb.num_embeddings)
273
  if required_size > cur:
274
  model.resize_token_embeddings(required_size)
 
275
  model.to(DEVICE)
276
 
277
 
 
288
  max_length=max_input_tokens,
289
  )
290
 
 
291
  input_ids = batch["input_ids"]
292
  max_id = int(input_ids.max().item()) if input_ids.numel() else 0
293
  needed = max(int(len(tokenizer)), max_id + 1)
 
294
  _ensure_embeddings_size(tokenizer, model, needed)
295
 
 
296
  batch = {k: v.to(DEVICE) for k, v in batch.items()}
297
 
 
298
  emb_size = int(model.get_input_embeddings().num_embeddings)
299
  if int(batch["input_ids"].max().item()) >= emb_size:
300
  batch["input_ids"] = batch["input_ids"].clamp_max(emb_size - 1)
 
309
  early_stopping=True,
310
  )
311
  except IndexError:
 
312
  _ensure_embeddings_size(tokenizer, model, int(len(tokenizer)))
313
  out_ids = model.generate(
314
  **batch,
 
473
 
474
 
475
  # =======================
476
+ # ЧАТ (messages)
477
  # =======================
478
+ def _append_messages(history: Any, user_text: str, assistant_text: str) -> List[Dict[str, str]]:
479
+ if not history:
480
+ history = []
481
+ # если вдруг пришли tuples — конвертируем
482
+ if isinstance(history, list) and history and isinstance(history[0], (tuple, list)) and len(history[0]) == 2:
483
+ msgs: List[Dict[str, str]] = []
484
+ for u, a in history:
485
+ msgs.append({"role": "user", "content": str(u)})
486
+ msgs.append({"role": "assistant", "content": str(a)})
487
+ history = msgs
488
+ # если уже messages
489
+ if isinstance(history, list) and (not history or isinstance(history[0], dict)):
490
+ history = list(history)
491
+ history.append({"role": "user", "content": user_text})
492
+ history.append({"role": "assistant", "content": assistant_text})
493
+ return history
494
+ # fallback
495
+ return [{"role": "user", "content": user_text}, {"role": "assistant", "content": assistant_text}]
496
+
497
+
498
+ def chat_answer(message: str, chat_history: List[Dict[str, str]], state: Dict[str, Any]):
499
  q = (message or "").strip()
500
  if not q:
501
  return chat_history, ""
502
 
503
  if not state or not state.get("chunks") or state.get("embeddings") is None:
504
+ return _append_messages(chat_history, q, "Сначала загрузите документ."), ""
505
 
506
  chunks: List[str] = state["chunks"]
507
  emb: np.ndarray = state["embeddings"]
508
 
509
  top_idx, best_sim = retrieve_topk(q, emb, top_k=4)
510
  if best_sim < RETRIEVE_MIN_SIM:
511
+ return _append_messages(chat_history, q, "В документе нет информации для ответа на этот вопрос."), ""
512
 
513
  ctx_idx = []
514
  for i in top_idx:
 
534
  a = "В документе нет информации для ответа на этот вопрос."
535
 
536
  cleanup_memory()
537
+ return _append_messages(chat_history, q, a), ""
538
 
539
 
540
  def clear_chat():
 
645
  q_btn.click(generate_questions, inputs=[diff, n_q, state], outputs=[q_out])
646
 
647
  with gr.Tab("Чат"):
648
+ chat = gr.Chatbot(label="Чат", type="messages")
649
  msg = gr.Textbox(lines=2, label="Вопрос")
650
  send = gr.Button("Отправить")
651
  clear = gr.Button("Очистить")
 
654
  clear.click(clear_chat, inputs=None, outputs=[chat, msg])
655
 
656
 
657
+ def _launch_compat(app: gr.Blocks):
658
+ """
659
+ Совместимо с разными версиями gradio:
660
+ - где есть concurrency_count
661
+ - где есть concurrency_limit / default_concurrency_limit
662
+ - где queue() без этих параметров
663
+ """
664
+ q_params = inspect.signature(app.queue).parameters
665
+ kwargs = {}
666
+
667
+ if "max_size" in q_params:
668
+ kwargs["max_size"] = 16
669
+
670
+ # разные версии gradio используют разные имена
671
+ if "concurrency_count" in q_params:
672
+ kwargs["concurrency_count"] = 1
673
+ elif "concurrency_limit" in q_params:
674
+ kwargs["concurrency_limit"] = 1
675
+ elif "default_concurrency_limit" in q_params:
676
+ kwargs["default_concurrency_limit"] = 1
677
+
678
+ try:
679
+ app.queue(**kwargs).launch()
680
+ except TypeError:
681
+ # если queue() совсем другой — просто launch()
682
+ app.launch()
683
+
684
+
685
  if __name__ == "__main__":
686
+ _launch_compat(demo)