UnMelow committed
Commit e22a543 · verified · 1 Parent(s): 5c95655

Update app.py

Files changed (1):
  1. app.py +623 -615
app.py CHANGED
@@ -1,673 +1,681 @@
  import os
- import random
  import math
  from dataclasses import dataclass
- from typing import List, Tuple, Dict, Optional

- import gradio as gr
  import torch
- from PIL import Image, ImageDraw, ImageFont

  from transformers import (
      AutoTokenizer,
      AutoModel,
-     AutoModelForSeq2SeqLM,
-     AutoModelForCausalLM,
  )

- # ============================================================
- # CPU setup
- # ============================================================
  os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
  DEVICE = torch.device("cpu")
- torch.set_num_threads(int(os.getenv("TORCH_NUM_THREADS", "4")))
-
- # ============================================================
- # 3 Transformers (minimum)
- # 1) Coach (Seq2Seq)
- # 2) Opponent (Causal LM)
- # 3) Embeddings (Encoder)
- # ============================================================
- COACH_MODEL_NAME = os.getenv("COACH_MODEL", "google/flan-t5-small")
- OPP_MODEL_NAME = os.getenv("OPP_MODEL", "distilgpt2")
- EMB_MODEL_NAME = os.getenv("EMB_MODEL", "sentence-transformers/paraphrase-MiniLM-L3-v2")
-
- coach_tok = AutoTokenizer.from_pretrained(COACH_MODEL_NAME)
- coach_model = AutoModelForSeq2SeqLM.from_pretrained(COACH_MODEL_NAME).eval().to(DEVICE)
-
- opp_tok = AutoTokenizer.from_pretrained(OPP_MODEL_NAME)
- opp_model = AutoModelForCausalLM.from_pretrained(OPP_MODEL_NAME).eval().to(DEVICE)
-
- emb_tok = AutoTokenizer.from_pretrained(EMB_MODEL_NAME)
- emb_model = AutoModel.from_pretrained(EMB_MODEL_NAME).eval().to(DEVICE)
-
-
- # ============================================================
- # Checkers engine (English draughts-like)
- # Pieces:
- #   '.' empty
- #   'w' white man (user)
- #   'W' white king
- #   'b' black man (bot)
- #   'B' black king
- #
- # Coordinates:
- #   internal: r=0..7 top->bottom, c=0..7 left->right
- #   dark squares: (r+c)%2==1
- # Move string:
- #   "b6-a5" or "c3-e5-g7" using a-h and 1-8 (1 is bottom row).
- # ============================================================
-
- def inside(r: int, c: int) -> bool:
-     return 0 <= r < 8 and 0 <= c < 8
-
- def is_dark(r: int, c: int) -> bool:
-     return (r + c) % 2 == 1
-
- def rc_to_alg(r: int, c: int) -> str:
-     # a1 bottom-left => internal (7,0)
-     file_ = chr(ord("a") + c)
-     rank = str(8 - r)
-     return f"{file_}{rank}"
-
- def alg_to_rc(s: str) -> Tuple[int, int]:
-     s = s.strip().lower()
-     c = ord(s[0]) - ord("a")
-     r = 8 - int(s[1])
-     return r, c
-
- def move_seq_to_str(seq: List[Tuple[int, int]]) -> str:
-     return "-".join(rc_to_alg(r, c) for r, c in seq)
-
- def move_str_to_seq(s: str) -> List[Tuple[int, int]]:
-     parts = [p.strip() for p in s.split("-") if p.strip()]
-     return [alg_to_rc(p) for p in parts]
-
- def piece_color(p: str) -> Optional[str]:
-     if p in ("w", "W"):
-         return "w"
-     if p in ("b", "B"):
-         return "b"
-     return None
-
- def is_king(p: str) -> bool:
-     return p in ("W", "B")


- @dataclass
- class GameState:
-     board: List[List[str]]
-     turn: str  # "w" user, "b" bot
-     history: List[str]
-     last_analysis: str
-
-
- def initial_board() -> List[List[str]]:
-     b = [["." for _ in range(8)] for _ in range(8)]
-     # Black at top rows 0-2 on dark squares
-     for r in range(0, 3):
-         for c in range(8):
-             if is_dark(r, c):
-                 b[r][c] = "b"
-     # White at bottom rows 5-7 on dark squares
-     for r in range(5, 8):
-         for c in range(8):
-             if is_dark(r, c):
-                 b[r][c] = "w"
-     return b
-
- def clone_board(board: List[List[str]]) -> List[List[str]]:
-     return [row[:] for row in board]
-
- def board_to_ascii(board: List[List[str]]) -> str:
-     # compact representation for prompting
-     lines = []
-     for r in range(8):
-         lines.append("".join(board[r]))
-     return "\n".join(lines)
-
- def count_material(board: List[List[str]]) -> Dict[str, float]:
-     score = {"w": 0.0, "b": 0.0}
-     for r in range(8):
-         for c in range(8):
-             p = board[r][c]
-             if p == "w":
-                 score["w"] += 1.0
-             elif p == "W":
-                 score["w"] += 1.6
-             elif p == "b":
-                 score["b"] += 1.0
-             elif p == "B":
-                 score["b"] += 1.6
-     return score
-
- def promote_if_needed(p: str, r: int) -> str:
-     if p == "w" and r == 0:
-         return "W"
-     if p == "b" and r == 7:
-         return "B"
-     return p
-
-
- # ----------------------------
- # Move generation
- # ----------------------------
- def move_dirs(p: str) -> List[Tuple[int, int]]:
-     # movement directions (step)
-     if p == "w":
-         return [(-1, -1), (-1, +1)]
-     if p == "b":
-         return [(+1, -1), (+1, +1)]
-     # kings
-     if p in ("W", "B"):
-         return [(-1, -1), (-1, +1), (+1, -1), (+1, +1)]
-     return []
-
- def capture_dirs(p: str) -> List[Tuple[int, int]]:
-     # English draughts: men capture forward only; kings both ways
-     return move_dirs(p)
-
- def gen_simple_moves(board: List[List[str]], color: str) -> List[List[Tuple[int, int]]]:
-     moves = []
-     for r in range(8):
-         for c in range(8):
-             p = board[r][c]
-             if piece_color(p) != color:
-                 continue
-             for dr, dc in move_dirs(p):
-                 r2, c2 = r + dr, c + dc
-                 if inside(r2, c2) and board[r2][c2] == ".":
-                     moves.append([(r, c), (r2, c2)])
-     return moves

- def gen_captures_from(board: List[List[str]], r: int, c: int, p: str) -> List[List[Tuple[int, int]]]:
      """
-     Returns capture sequences starting at (r,c), including start and landings.
-     If man reaches king row during capture, we stop (promotion at end of move).
      """
-     color = piece_color(p)
-     assert color in ("w", "b")

-     sequences = []
-     found_any = False

-     for dr, dc in capture_dirs(p):
-         r_mid, c_mid = r + dr, c + dc
-         r2, c2 = r + 2 * dr, c + 2 * dc
-         if not (inside(r2, c2) and inside(r_mid, c_mid)):
-             continue
-         mid_piece = board[r_mid][c_mid]
-         if mid_piece == ".":
-             continue
-         if piece_color(mid_piece) == color:
-             continue
-         if board[r2][c2] != ".":
-             continue

-         # perform capture on a cloned board
-         nb = clone_board(board)
-         nb[r][c] = "."
-         nb[r_mid][c_mid] = "."
-         nb[r2][c2] = p  # promotion deferred

-         # stop extending if this is a man that reaches king row
-         if (p == "w" and r2 == 0) or (p == "b" and r2 == 7):
-             sequences.append([(r, c), (r2, c2)])
-             found_any = True
-             continue

-         tails = gen_captures_from(nb, r2, c2, p)
-         if tails:
-             for t in tails:
-                 sequences.append([(r, c)] + t[1:])
-             found_any = True
-         else:
-             sequences.append([(r, c), (r2, c2)])
-             found_any = True

-     return sequences if found_any else []

- def gen_legal_moves(board: List[List[str]], color: str) -> List[List[Tuple[int, int]]]:
-     captures = []
-     for r in range(8):
-         for c in range(8):
-             p = board[r][c]
-             if piece_color(p) != color:
                  continue
-             caps = gen_captures_from(board, r, c, p)
-             captures.extend(caps)
-
-     # forced capture rule
-     if captures:
-         # remove duplicates (can arise via different recursion paths)
-         uniq = {}
-         for seq in captures:
-             key = tuple(seq)
-             uniq[key] = seq
-         return list(uniq.values())
-
-     return gen_simple_moves(board, color)
-
- def apply_move(board: List[List[str]], seq: List[Tuple[int, int]]) -> List[List[str]]:
-     nb = clone_board(board)
-     (r0, c0) = seq[0]
-     p = nb[r0][c0]
-     nb[r0][c0] = "."
-
-     for i in range(1, len(seq)):
-         (r1, c1) = seq[i - 1]
-         (r2, c2) = seq[i]
-         # capture if jump
-         if abs(r2 - r1) == 2 and abs(c2 - c1) == 2:
-             rm = (r1 + r2) // 2
-             cm = (c1 + c2) // 2
-             nb[rm][cm] = "."
-
-     (rf, cf) = seq[-1]
-     p2 = promote_if_needed(p, rf)
-     nb[rf][cf] = p2
-     return nb
-
- def winner(board: List[List[str]]) -> Optional[str]:
-     # winner if opponent has no pieces or no moves
-     w_cnt = 0
-     b_cnt = 0
-     for r in range(8):
-         for c in range(8):
-             if board[r][c] in ("w", "W"):
-                 w_cnt += 1
-             elif board[r][c] in ("b", "B"):
-                 b_cnt += 1
-     if w_cnt == 0:
-         return "b"
-     if b_cnt == 0:
-         return "w"
-     if not gen_legal_moves(board, "w"):
-         return "b"
-     if not gen_legal_moves(board, "b"):
-         return "w"
-     return None
-
-
- # ============================================================
- # Simple engine for analysis (not a transformer):
- # minimax on material + mobility, small depth for CPU.
- # ============================================================
- def eval_board(board: List[List[str]]) -> float:
-     m = count_material(board)
-     # positive => good for white
-     score = (m["w"] - m["b"])
-     # mobility bonus
-     score += 0.04 * (len(gen_legal_moves(board, "w")) - len(gen_legal_moves(board, "b")))
-     return score
-
- def minimax(board: List[List[str]], color: str, depth: int, alpha: float, beta: float) -> Tuple[float, Optional[List[Tuple[int, int]]]]:
-     win = winner(board)
-     if win == "w":
-         return 10_000.0, None
-     if win == "b":
-         return -10_000.0, None
-
-     if depth == 0:
-         return eval_board(board), None
-
-     moves = gen_legal_moves(board, color)
-     if not moves:
-         # no moves => lose
-         return (-10_000.0 if color == "w" else 10_000.0), None
-
-     best_move = None
-
-     if color == "w":
-         best = -math.inf
-         for mv in moves:
-             nb = apply_move(board, mv)
-             val, _ = minimax(nb, "b", depth - 1, alpha, beta)
-             if val > best:
-                 best = val
-                 best_move = mv
-             alpha = max(alpha, best)
-             if beta <= alpha:
                  break
-         return best, best_move
      else:
-         best = math.inf
-         for mv in moves:
-             nb = apply_move(board, mv)
-             val, _ = minimax(nb, "w", depth - 1, alpha, beta)
-             if val < best:
-                 best = val
-                 best_move = mv
-             beta = min(beta, best)
-             if beta <= alpha:
-                 break
-         return best, best_move
-
-
- # ============================================================
- # Embeddings (transformer #3) for retrieving tips
- # ============================================================
- TIPS = [
-     "Всегда проверяй обязательный бой: если есть взятие, обычный ход запрещён.",
-     "Старайся сохранять дамочную линию: не открывай край без причины.",
-     "Не меняйся, если это приводит к потере темпа и отдаёт центр.",
-     "Центр важен: контроль диагоналей увеличивает мобильность и шансы на многоходовые взятия.",
-     "Перед ходом оцени ответ соперника: что он берёт или чем отвечает на диагонали?",
-     "Если видишь возможность мультибоя, считай траекторию до конца — важно, где ты остановишься.",
-     "Дамка сильнее: иногда стоит пожертвовать шашку ради прохода в дамки.",
-     "Не оставляй одиночные шашки без поддержки — их легко поймать взятием.",
-     "Думай про 'вилку' (двойную угрозу) и про то, чтобы не подставлять шашку под обязательный бой.",
- ]

- @torch.no_grad()
- def embed_text(text: str) -> torch.Tensor:
-     toks = emb_tok(text, return_tensors="pt", truncation=True, max_length=128, padding=True)
-     toks = {k: v.to(DEVICE) for k, v in toks.items()}
-     out = emb_model(**toks)
-     # mean pooling
-     last = out.last_hidden_state  # [B,T,H]
-     mask = toks["attention_mask"].unsqueeze(-1)  # [B,T,1]
-     pooled = (last * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
-     pooled = torch.nn.functional.normalize(pooled, p=2, dim=1)
-     return pooled[0].cpu()
-
- TIP_EMBS = torch.stack([embed_text(t) for t in TIPS], dim=0)  # [N,H]
-
- def retrieve_tips(query: str, k: int = 3) -> List[str]:
-     q = embed_text(query)
-     sims = (TIP_EMBS @ q.unsqueeze(1)).squeeze(1)  # [N]
-     top = torch.topk(sims, k=min(k, len(TIPS))).indices.tolist()
-     return [TIPS[i] for i in top]
-
-
- # ============================================================
- # Coach (transformer #1): generates explanation/feedback
- # ============================================================
- @torch.no_grad()
- def coach_generate(prompt: str, max_new_tokens: int = 160) -> str:
-     inp = coach_tok(prompt, return_tensors="pt", truncation=True, max_length=512)
-     inp = {k: v.to(DEVICE) for k, v in inp.items()}
-     out = coach_model.generate(
-         **inp,
          max_new_tokens=max_new_tokens,
-         do_sample=False,
-         num_beams=1,
      )
-     text = coach_tok.decode(out[0], skip_special_tokens=True)
-     return text.strip()
-
-
- # ============================================================
- # Opponent (transformer #2): chooses a legal move
- # ============================================================
- @torch.no_grad()
- def opponent_choose_move(board: List[List[str]], legal_moves: List[str]) -> str:
-     # distilgpt2 is not instruction-tuned, so we keep it extremely constrained and parse output.
-     board_ascii = board_to_ascii(board)
-     moves_block = "\n".join([f"- {m}" for m in legal_moves[:40]])  # cap list
-     prompt = (
-         "You are playing checkers as Black.\n"
-         "Choose ONE move exactly from the list. Output only that move.\n"
-         f"Board:\n{board_ascii}\n"
-         f"Moves:\n{moves_block}\n"
-         "Move:"
-     )
-     inp = opp_tok(prompt, return_tensors="pt", truncation=True, max_length=512)
-     inp = {k: v.to(DEVICE) for k, v in inp.items()}
-     gen = opp_model.generate(
-         **inp,
-         max_new_tokens=24,
-         do_sample=True,
-         top_p=0.85,
-         temperature=0.7,
-         pad_token_id=opp_tok.eos_token_id,
-     )
-     out = opp_tok.decode(gen[0], skip_special_tokens=True)
-     tail = out.split("Move:")[-1].strip()
-
-     # parse: pick the first legal move that appears in the generated tail
-     for m in legal_moves:
-         if m in tail:
-             return m
-
-     # fallback: try extract token pattern like a1-b2
-     cand = re.findall(r"[a-h][1-8](?:-[a-h][1-8])+", tail.lower())
-     if cand:
-         for c in cand:
-             if c in legal_moves:
-                 return c
-
-     # final fallback: random legal
-     return random.choice(legal_moves)
-
-
- # ============================================================
- # Rendering board
- # ============================================================
- def render_board(board: List[List[str]], size: int = 520) -> Image.Image:
-     pad = 20
-     cell = (size - 2 * pad) // 8
-     img = Image.new("RGB", (size, size), (245, 245, 245))
-     d = ImageDraw.Draw(img)
-
-     dark = (150, 110, 80)
-     light = (235, 220, 200)
-
-     # grid
-     for r in range(8):
-         for c in range(8):
-             x0 = pad + c * cell
-             y0 = pad + r * cell
-             x1 = x0 + cell
-             y1 = y0 + cell
-             d.rectangle([x0, y0, x1, y1], fill=(dark if is_dark(r, c) else light))
-
-     # pieces
-     for r in range(8):
-         for c in range(8):
-             p = board[r][c]
-             if p == ".":
-                 continue
-             cx = pad + c * cell + cell // 2
-             cy = pad + r * cell + cell // 2
-             rad = int(cell * 0.38)

-             if p in ("w", "W"):
-                 fill = (245, 245, 245)
-                 outline = (30, 30, 30)
-             else:
-                 fill = (40, 40, 40)
-                 outline = (230, 230, 230)

-             d.ellipse([cx - rad, cy - rad, cx + rad, cy + rad], fill=fill, outline=outline, width=3)

-             if is_king(p):
-                 # crown marker
-                 d.ellipse([cx - rad // 2, cy - rad // 2, cx + rad // 2, cy + rad // 2], outline=(255, 215, 0), width=4)

-     # coordinates
-     try:
-         font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 14)
-     except Exception:
-         font = None

-     for c in range(8):
-         d.text((pad + c * cell + 3, pad + 8 * cell + 2), chr(ord("a") + c), fill=(30, 30, 30), font=font)
-     for r in range(8):
-         d.text((3, pad + r * cell + 3), str(8 - r), fill=(30, 30, 30), font=font)

-     return img


- # ============================================================
- # Game logic wrapper
- # ============================================================
- def new_game() -> GameState:
-     return GameState(
-         board=initial_board(),
-         turn="w",
-         history=[],
-         last_analysis="",
-     )

- def legal_moves_str(board: List[List[str]], color: str) -> List[str]:
-     moves = gen_legal_moves(board, color)
-     ms = [move_seq_to_str(mv) for mv in moves]
-     # stable ordering: captures first (longer sequences first), then lexicographic
-     ms.sort(key=lambda s: (-s.count("-"), s))
-     return ms
-
- def analyze_user_move(board_before: List[List[str]], user_move_str: str) -> str:
-     # engine "best move" as baseline (not a transformer)
-     depth = int(os.getenv("ANALYSIS_DEPTH", "3"))
-     best_val, best_mv = minimax(board_before, "w", depth=depth, alpha=-math.inf, beta=math.inf)
-     best_str = move_seq_to_str(best_mv) if best_mv else "(none)"
-
-     tips = retrieve_tips("шашки: как улучшить ход и не подставиться", k=3)
-
-     prompt = (
-         "Ты тренер по шашкам. Коротко и по делу.\n"
-         f"Ход игрока: {user_move_str}\n"
-         f"Рекомендованный ход (по анализу): {best_str}\n"
-         "Дай объяснение: почему рекомендованный лучше, и какая ошибка/риск в ходе игрока.\n"
-         "Добавь 2-3 практических совета.\n"
-         "Подсказки:\n"
-         + "\n".join(f"- {t}" for t in tips)
-     )
-     return coach_generate(prompt, max_new_tokens=180)
-
-
- def step_user_and_bot(state: GameState, user_move: str) -> Tuple[GameState, str]:
-     if winner(state.board) is not None:
-         return state, "Game already finished."
-
-     if state.turn != "w":
-         return state, "Not your turn."
-
-     leg = legal_moves_str(state.board, "w")
-     if user_move not in leg:
-         return state, "Invalid move (not in legal list)."
-
-     board_before = clone_board(state.board)
-     seq = move_str_to_seq(user_move)
-     state.board = apply_move(state.board, seq)
-     state.history.append(f"White: {user_move}")
-     state.turn = "b"
-
-     # analysis (coach transformer)
-     state.last_analysis = analyze_user_move(board_before, user_move)
-
-     win = winner(state.board)
-     if win is not None:
-         state.history.append("Result: " + ("White wins" if win == "w" else "Black wins"))
-         return state, ("White wins." if win == "w" else "Black wins.")
-
-     # bot move
-     bot_leg = legal_moves_str(state.board, "b")
-     if not bot_leg:
-         state.history.append("Result: White wins")
-         return state, "White wins."
-
-     bot_move = opponent_choose_move(state.board, bot_leg)
-     bot_seq = move_str_to_seq(bot_move)
-     state.board = apply_move(state.board, bot_seq)
-     state.history.append(f"Black: {bot_move}")
-     state.turn = "w"
-
-     win = winner(state.board)
-     if win is not None:
-         state.history.append("Result: " + ("White wins" if win == "w" else "Black wins"))
-         return state, ("White wins." if win == "w" else "Black wins.")
-
-     return state, f"Bot played: {bot_move}"
-
-
- # ============================================================
- # Coach chat (transformer #1 + embeddings #3)
- # ============================================================
- def coach_chat(state: GameState, message: str, chat_hist: List[Tuple[str, str]]):
-     msg = (message or "").strip()
-     if not msg:
-         return chat_hist, ""
-
-     # Retrieve tips relevant to the question
-     tips = retrieve_tips(msg, k=3)
-
-     # Provide board context
-     context = board_to_ascii(state.board)
-     last = state.history[-6:] if state.history else []
-
-     prompt = (
-         "Ты тренер по шашкам. Отвечай кратко, но конкретно.\n"
-         f"Вопрос игрока: {msg}\n"
-         "Контекст партии (последние ходы):\n"
-         + ("\n".join(last) if last else "(нет)")
-         + "\n"
-         "Доска (ASCII):\n"
-         + context
-         + "\n"
-         "Полезные подсказки:\n"
-         + "\n".join(f"- {t}" for t in tips)
-         + "\n"
-         "Ответ:"
      )

-     answer = coach_generate(prompt, max_new_tokens=180)
-     chat_hist = chat_hist + [(msg, answer)]
-     return chat_hist, ""


- # ============================================================
- # UI
- # ============================================================
- theme = gr.themes.Monochrome(font=[gr.themes.GoogleFont("Inter"), "system-ui"])

- with gr.Blocks(theme=theme, title="Checkers Coach (CPU, 3 Transformers)") as demo:
-     state = gr.State(new_game())

-     with gr.Row():
-         with gr.Column(scale=1, min_width=360):
-             board_img = gr.Image(label="Board", type="pil", height=520)
-             status = gr.Textbox(label="Status", value="", interactive=False)

-             move_dd = gr.Dropdown(label="Your move (White)", choices=[], value=None)
-             play_btn = gr.Button("Play move", variant="primary")
-             new_btn = gr.Button("New game")

-             analysis = gr.Textbox(label="Coach analysis", lines=10, interactive=False)

-         with gr.Column(scale=1, min_width=360):
-             hist = gr.Markdown("")
-             gr.Markdown("### Coach chat")
-             chat = gr.Chatbot(height=360)
-             msg = gr.Textbox(label="Message", placeholder="Ask about strategy, mistakes, next plan…")
-             send = gr.Button("Send")

-     def refresh_ui(gs: GameState):
-         img = render_board(gs.board)
-         leg = legal_moves_str(gs.board, "w") if winner(gs.board) is None else []
-         h = "### History\n" + ("\n".join([f"- {x}" for x in gs.history]) if gs.history else "- (empty)")
-         return img, ("" if gs.turn == "w" else "Bot thinking / waiting…"), gr.update(choices=leg, value=(leg[0] if leg else None)), gs.last_analysis, h

-     def on_new():
-         gs = new_game()
-         return (gs, ) + refresh_ui(gs) + ([], "")

-     def on_play(gs: GameState, mv: str):
-         gs, st = step_user_and_bot(gs, mv or "")
-         img, _, dd, an, h = refresh_ui(gs)
-         return gs, img, st, dd, an, h

-     def on_send(gs: GameState, m: str, ch: List[Tuple[str, str]]):
-         ch, cleared = coach_chat(gs, m, ch or [])
-         return ch, cleared

-     demo.load(lambda gs: refresh_ui(gs), inputs=[state], outputs=[board_img, status, move_dd, analysis, hist])

-     new_btn.click(on_new, inputs=[], outputs=[state, board_img, status, move_dd, analysis, hist, chat, msg])
-     play_btn.click(on_play, inputs=[state, move_dd], outputs=[state, board_img, status, move_dd, analysis, hist])

-     send.click(on_send, inputs=[state, msg, chat], outputs=[chat, msg])

  if __name__ == "__main__":
-     demo.queue(max_size=32).launch(server_name="0.0.0.0", server_port=7860, ssr_mode=False)
  import os
+ import re
  import math
+ import time
+ import threading
  from dataclasses import dataclass
+ from typing import Dict, List, Tuple, Optional, Any

+ import numpy as np
  import torch
+ import gradio as gr

+ from huggingface_hub import HfApi
  from transformers import (
      AutoTokenizer,
      AutoModel,
+     AutoModelForQuestionAnswering,
+     T5ForConditionalGeneration,
  )
+ from transformers.utils import logging as hf_logging

+
+ # ---------------------------
+ # Runtime / logging hygiene
+ # ---------------------------
  os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
+ os.environ.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "1")
+ os.environ.setdefault("TRANSFORMERS_NO_ADVISORY_WARNINGS", "1")
+
+ hf_logging.set_verbosity_error()
+
  DEVICE = torch.device("cpu")
+ torch.set_grad_enabled(False)
+
+ # Hard safety limits (RAM + CPU time)
+ MAX_INPUT_CHARS = 60_000      # user text max
+ MAX_CHUNKS = 100              # max chunks to index
+ CHUNK_TARGET_CHARS = 900      # chunk target size
+ EMBED_BATCH = 16              # embedding batch size
+ GEN_MAX_NEW_TOKENS = 200      # generation cap
+ QA_MAX_LENGTH = 384           # QA tokens
+ QA_STRIDE = 128               # QA stride for long contexts
+ MAX_CONTEXT_CHARS = 3_500     # context cap before QA
+
+
+ # ---------------------------
+ # Model selection (availability + fallback)
+ # ---------------------------
+ GEN_CANDIDATES = [
+     "cointegrated/rut5-base-multitask",
+     "cointegrated/rut5-small",
+     "google/flan-t5-small",
+ ]

+ EMB_CANDIDATES = [
+     "intfloat/multilingual-e5-small",
+     "intfloat/e5-small-v2",
+     "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
+ ]

+ QA_CANDIDATES = [
+     "mrm8488/bert-multi-cased-finetuned-xquadv1",
+     "timopixel/bert-base-multilingual-cased-finetuned-squad",
+     "distilbert-base-cased-distilled-squad",  # english fallback
+ ]


+ def _hf_exists(model_id: str) -> bool:
      """
+     Best-effort online check. If offline/blocked, return True (we'll try to load).
      """
+     try:
+         api = HfApi()
+         api.model_info(model_id)
+         return True
+     except Exception:
+         # If we cannot check (offline), do not fail early.
+         return True


+ def pick_first_available(candidates: List[str]) -> str:
+     for mid in candidates:
+         if _hf_exists(mid):
+             return mid
+     return candidates[0]
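+
+ # Note on the selection above: _hf_exists returns True both on a successful hub
+ # lookup and on any exception, so pick_first_available effectively resolves to
+ # the first candidate; the list mainly documents the intended fallback order.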


+ SELECTED_GEN = pick_first_available(GEN_CANDIDATES)
+ SELECTED_EMB = pick_first_available(EMB_CANDIDATES)
+ SELECTED_QA = pick_first_available(QA_CANDIDATES)


+ # ---------------------------
+ # Lazy model loaders
+ # ---------------------------
+ _load_lock = threading.Lock()
+
+
+ @torch.inference_mode()
+ def _to_numpy(x: torch.Tensor) -> np.ndarray:
+     return x.detach().cpu().numpy()
+
+
+ def _safe_truncate_text(s: str, max_chars: int) -> str:
+     s = (s or "").strip()
+     if len(s) > max_chars:
+         return s[:max_chars].rstrip() + "\n\n[Текст обрезан по лимиту длины]"
+     return s
+
+
+ def _clean_spaces(s: str) -> str:
+     return re.sub(r"\s+", " ", (s or "")).strip()
+

+ @torch.inference_mode()
+ def average_pool(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
+     # E5-style mean pooling
+     mask = attention_mask.unsqueeze(-1).bool()
+     masked = last_hidden_states.masked_fill(~mask, 0.0)
+     summed = masked.sum(dim=1)
+     denom = attention_mask.sum(dim=1).clamp(min=1).unsqueeze(-1)
+     return summed / denom
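+
+ # Note on average_pool: last_hidden_states is [B, T, H] and attention_mask is
+ # [B, T]. Padding positions are zeroed before the sum and the divisor is the
+ # count of real tokens, e.g. a mask of [1, 1, 0] averages exactly two vectors.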
+
+
+ @dataclass
+ class Models:
+     gen_id: str
+     emb_id: str
+     qa_id: str
+
+
+ MODELS = Models(gen_id=SELECTED_GEN, emb_id=SELECTED_EMB, qa_id=SELECTED_QA)
+
+
+ # Cached objects (per process)
+ _GEN_TOK: Optional[Any] = None
+ _GEN_MODEL: Optional[Any] = None
+
+ _EMB_TOK: Optional[Any] = None
+ _EMB_MODEL: Optional[Any] = None
+
+ _QA_TOK: Optional[Any] = None
+ _QA_MODEL: Optional[Any] = None
+
+
+ def load_generator() -> Tuple[Any, Any]:
+     global _GEN_TOK, _GEN_MODEL
+     with _load_lock:
+         if _GEN_TOK is not None and _GEN_MODEL is not None:
+             return _GEN_TOK, _GEN_MODEL
+
+         tok = AutoTokenizer.from_pretrained(MODELS.gen_id, use_fast=True)
+         # rut5 models are T5-compatible; flan-t5 too
+         model = T5ForConditionalGeneration.from_pretrained(
+             MODELS.gen_id,
+             torch_dtype=torch.float32,
+             low_cpu_mem_usage=True,
+         )
+         model.eval()
+         _GEN_TOK, _GEN_MODEL = tok, model
+         return tok, model
+
+
+ def load_embedder() -> Tuple[Any, Any]:
+     global _EMB_TOK, _EMB_MODEL
+     with _load_lock:
+         if _EMB_TOK is not None and _EMB_MODEL is not None:
+             return _EMB_TOK, _EMB_MODEL
+
+         tok = AutoTokenizer.from_pretrained(MODELS.emb_id, use_fast=True)
+         model = AutoModel.from_pretrained(
+             MODELS.emb_id,
+             torch_dtype=torch.float32,
+             low_cpu_mem_usage=True,
+         )
+         model.eval()
+         _EMB_TOK, _EMB_MODEL = tok, model
+         return tok, model
+
+
+ def load_qa() -> Tuple[Any, Any]:
+     global _QA_TOK, _QA_MODEL
+     with _load_lock:
+         if _QA_TOK is not None and _QA_MODEL is not None:
+             return _QA_TOK, _QA_MODEL
+
+         tok = AutoTokenizer.from_pretrained(MODELS.qa_id, use_fast=True)
+         model = AutoModelForQuestionAnswering.from_pretrained(
+             MODELS.qa_id,
+             torch_dtype=torch.float32,
+             low_cpu_mem_usage=True,
+         )
+         model.eval()
+         _QA_TOK, _QA_MODEL = tok, model
+         return tok, model
+
+
+ # ---------------------------
+ # Text chunking / indexing
+ # ---------------------------
+ _SENT_SPLIT = re.compile(r"(?<=[\.\!\?…])\s+|\n+")
+
+
+ def split_into_chunks(text: str,
+                       target_chars: int = CHUNK_TARGET_CHARS,
+                       max_chunks: int = MAX_CHUNKS) -> List[str]:
+     text = _safe_truncate_text(text, MAX_INPUT_CHARS)
+
+     # Prefer paragraph-based chunks first
+     paras = [p.strip() for p in re.split(r"\n\s*\n+", text) if p.strip()]
+     chunks: List[str] = []
+
+     buf = ""
+     for p in paras:
+         if not buf:
+             buf = p
+             continue
+         if len(buf) + 2 + len(p) <= target_chars:
+             buf = buf + "\n\n" + p
+         else:
+             chunks.append(buf.strip())
+             buf = p
+         if len(chunks) >= max_chunks:
+             break
+
+     if buf and len(chunks) < max_chunks:
+         chunks.append(buf.strip())
+
+     # If still too big chunks (single huge para), split by sentences
+     fixed: List[str] = []
+     for c in chunks:
+         if len(c) <= target_chars * 1.6:
+             fixed.append(c)
+             continue
+         sents = [s.strip() for s in _SENT_SPLIT.split(c) if s.strip()]
+         b = ""
+         for s in sents:
+             if not b:
+                 b = s
                  continue
+             if len(b) + 1 + len(s) <= target_chars:
+                 b = b + " " + s
+             else:
+                 fixed.append(b.strip())
+                 b = s
+             if len(fixed) >= max_chunks:
                  break
+         if b and len(fixed) < max_chunks:
+             fixed.append(b.strip())
+         if len(fixed) >= max_chunks:
+             break
+
+     return fixed[:max_chunks]
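+
+ # Illustration of the two-stage chunking: paragraphs are packed greedily into
+ # buffers of up to ~target_chars joined by blank lines; any chunk still longer
+ # than 1.6 * target_chars is re-split on sentence boundaries via _SENT_SPLIT.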
+
+
+ @torch.inference_mode()
+ def embed_texts(texts: List[str], is_query: bool) -> np.ndarray:
+     tok, model = load_embedder()
+
+     # E5 expects prefixes
+     prefix = "query: " if is_query else "passage: "
+     texts = [prefix + _clean_spaces(t) for t in texts]
+
+     all_vecs: List[np.ndarray] = []
+     for i in range(0, len(texts), EMBED_BATCH):
+         batch = texts[i:i + EMBED_BATCH]
+         enc = tok(batch, padding=True, truncation=True, max_length=512, return_tensors="pt")
+         out = model(**enc)
+         pooled = average_pool(out.last_hidden_state, enc["attention_mask"])
+         pooled = torch.nn.functional.normalize(pooled, p=2, dim=1)
+         all_vecs.append(_to_numpy(pooled))
+
+     return np.vstack(all_vecs).astype(np.float32)
+
+
+ def cosine_topk(query_vec: np.ndarray, matrix: np.ndarray, k: int) -> List[Tuple[int, float]]:
+     # query_vec: [d], matrix: [n,d], both normalized
+     scores = matrix @ query_vec.reshape(-1, 1)
+     scores = scores.squeeze(1)
+     if len(scores) == 0:
+         return []
+     k = max(1, min(k, len(scores)))
+     idx = np.argpartition(-scores, k - 1)[:k]
+     idx = idx[np.argsort(-scores[idx])]
+     return [(int(i), float(scores[i])) for i in idx]
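+
+ # Note on cosine_topk: embed_texts L2-normalizes every row, so these dot
+ # products are exact cosine similarities; np.argpartition picks the k best in
+ # O(n), and the final argsort orders only those k hits by descending score.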
+
+
+ def build_index(text: str) -> Dict[str, Any]:
+     text = _safe_truncate_text(text, MAX_INPUT_CHARS)
+     chunks = split_into_chunks(text)
+     if not chunks:
+         return {"text": text, "chunks": [], "emb": None}
+     emb = embed_texts(chunks, is_query=False)
+     return {"text": text, "chunks": chunks, "emb": emb}
+
+
+ def ensure_index(state: Optional[Dict[str, Any]], text: str) -> Dict[str, Any]:
+     text = _safe_truncate_text(text, MAX_INPUT_CHARS)
+     if not state or state.get("text") != text:
+         return build_index(text)
+     return state
+
+
+ def retrieve(state: Dict[str, Any], query: str, k: int = 5) -> List[Tuple[float, str]]:
+     query = (query or "").strip()
+     if not query or not state.get("chunks") or state.get("emb") is None:
+         return []
+     qv = embed_texts([query], is_query=True)[0]
+     top = cosine_topk(qv, state["emb"], k=k)
+     out: List[Tuple[float, str]] = []
+     for idx, score in top:
+         out.append((score, state["chunks"][idx]))
+     return out
+
+
+ # ---------------------------
+ # Generator (ruT5 multitask)
+ # ---------------------------
+ @torch.inference_mode()
+ def rut5(task: str, text: str, max_new_tokens: int = GEN_MAX_NEW_TOKENS,
+          do_sample: bool = False, temperature: float = 0.9, top_p: float = 0.95) -> str:
+     tok, model = load_generator()
+
+     task = (task or "").strip()
+     if task:
+         prompt = f"{task} | {text}"
      else:
+         prompt = text
+
+     enc = tok(prompt, return_tensors="pt", truncation=True, max_length=512)

+     gen_kwargs = dict(
          max_new_tokens=max_new_tokens,
+         num_beams=4 if not do_sample else 1,
+         do_sample=do_sample,
+         temperature=temperature if do_sample else None,
+         top_p=top_p if do_sample else None,
+         repetition_penalty=1.05,
+         no_repeat_ngram_size=3,
      )
+     # Remove None
+     gen_kwargs = {k: v for k, v in gen_kwargs.items() if v is not None}

+     out = model.generate(**enc, **gen_kwargs)
+     s = tok.decode(out[0], skip_special_tokens=True).strip()
+     return s


+ # ---------------------------
+ # Extractive QA (mBERT xquad)
+ # ---------------------------
+ @torch.inference_mode()
+ def extractive_qa(question: str, context: str) -> Tuple[str, str]:
+     """
+     Returns: (answer, evidence_snippet)
+     """
+     question = (question or "").strip()
+     context = (context or "").strip()
+     if not question or not context:
+         return "", ""
+
+     tok, model = load_qa()
+
+     context = _safe_truncate_text(context, MAX_CONTEXT_CHARS)
+
+     enc = tok(
+         question,
+         context,
+         truncation="only_second",
+         max_length=QA_MAX_LENGTH,
+         stride=QA_STRIDE,
+         return_overflowing_tokens=True,
+         return_offsets_mapping=True,
+         padding=True,
+         return_tensors="pt",
+     )

+     offset_mapping = enc.pop("offset_mapping")  # [features, seq]
+     # Fast tokenizers also return overflow bookkeeping; it is not a model input.
+     enc.pop("overflow_to_sample_mapping", None)
+     input_ids = enc["input_ids"]

+     outputs = model(**enc)
+     start_logits = outputs.start_logits
+     end_logits = outputs.end_logits

+     best_score = -1e9
+     best_span = (0, 0)
+     best_context = context

+     for i in range(input_ids.shape[0]):
+         # sequence ids: None, 0(question), 1(context)
+         seq_ids = tok.sequence_ids(i)
+         offsets = offset_mapping[i].tolist()

+         # valid context token indices
+         context_token_idxs = [j for j, sid in enumerate(seq_ids) if sid == 1 and offsets[j] != [0, 0]]
+         if not context_token_idxs:
+             continue

+         s_logits = start_logits[i].detach().cpu().numpy()
+         e_logits = end_logits[i].detach().cpu().numpy()
+
+         for s_idx in context_token_idxs:
+             for e_idx in context_token_idxs:
+                 if e_idx < s_idx:
+                     continue
+                 if e_idx - s_idx > 40:
+                     continue
+                 score = float(s_logits[s_idx] + e_logits[e_idx])
+                 if score > best_score:
+                     s_char, _ = offsets[s_idx]
+                     _, e_char = offsets[e_idx]
+                     if e_char > s_char:
+                         best_score = score
+                         best_span = (s_char, e_char)
+
+     ans = best_context[best_span[0]:best_span[1]].strip()
+     if not ans:
+         return "", ""
+
+     # Evidence snippet with small window
+     a, b = best_span
+     left = max(0, a - 120)
+     right = min(len(best_context), b + 120)
+     snippet = best_context[left:right].strip()
+
+     return ans, snippet
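+
+ # Note on the span search above: it follows the usual extractive-QA decoding
+ # scheme - each candidate span is scored as start_logit + end_logit, spans are
+ # capped at 40 tokens, and the offset mapping converts the best token span
+ # back to character positions in the (truncated) context string.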
+
+
+ # ---------------------------
+ # Product features
+ # ---------------------------
+ def make_summary(state: Dict[str, Any], level: str) -> str:
+     chunks = state.get("chunks") or []
+     emb = state.get("emb")
+     if not chunks or emb is None:
+         return "Нет текста для обработки."
+
+     # central chunks via centroid similarity
+     centroid = emb.mean(axis=0)
+     centroid = centroid / (np.linalg.norm(centroid) + 1e-12)
+     sims = emb @ centroid.reshape(-1, 1)
+     sims = sims.squeeze(1)
+
+     k = 3 if level == "Коротко" else 6
+     k = min(k, len(chunks))
+     idx = np.argpartition(-sims, k - 1)[:k]
+     idx = idx[np.argsort(-sims[idx])]
+     selected = "\n\n".join(chunks[i] for i in idx.tolist())
+
+     selected = _safe_truncate_text(selected, 3000)
+
+     # title + simplified digest
+     title = rut5("headline", selected, max_new_tokens=32)
+     digest = rut5("simplify", selected, max_new_tokens=GEN_MAX_NEW_TOKENS)
+
+     # if generator fails, return extractive selection
+     if not digest:
+         digest = selected
+
+     return f"### Заголовок\n{title}\n\n### Пересказ\n{digest}"
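+
+ # Note on make_summary: chunks whose embeddings lie closest to the normalized
+ # centroid are taken as most representative, so the digest is centroid-ranked
+ # extraction followed by ruT5 "headline"/"simplify" rewriting of the selection.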
+
+
+ def make_quiz(state: Dict[str, Any], n: int, difficulty: str) -> str:
+     chunks = state.get("chunks") or []
+     emb = state.get("emb")
+     if not chunks or emb is None:
+         return "Нет текста для генерации вопросов."
+
+     n = int(max(1, min(n, 12)))
+
+     # pick diverse chunks: take top by centrality, then spread
+     centroid = emb.mean(axis=0)
+     centroid = centroid / (np.linalg.norm(centroid) + 1e-12)
+     sims = (emb @ centroid.reshape(-1, 1)).squeeze(1)
+     order = np.argsort(-sims).tolist()
+
+     # take every step to diversify
+     step = max(1, len(order) // max(1, n))
+     chosen_idx = []
+     for i in range(0, len(order), step):
+         chosen_idx.append(order[i])
+         if len(chosen_idx) >= n:
+             break
+
+     questions: List[Tuple[str, str, str]] = []
+     seen = set()
+
+     for idx in chosen_idx:
+         ctx = chunks[idx]
+         ctx_short = _safe_truncate_text(ctx, 2000)
+
+         # generate question
+         q = rut5("ask", ctx_short, max_new_tokens=64, do_sample=True,
+                  temperature=0.85 if difficulty == "Легко" else 1.0,
+                  top_p=0.92)
+         q = q.strip()
+         q = q if q.endswith("?") else (q + "?") if q else ""
+
+         if not q or q.lower() in seen:
+             continue
+         seen.add(q.lower())
+
+         # answer from QA model (extractive) with evidence
+         ans, ev = extractive_qa(q, ctx_short)
+
+         # fallback to generative "comprehend" if extractive fails
+         if not ans:
+             ans = rut5("comprehend", f"{ctx_short} Вопрос: {q}", max_new_tokens=64).strip()
+             ev = ctx_short[:260].strip()
+
+         questions.append((q, ans, ev))
+         if len(questions) >= n:
+             break
+
+     if not questions:
+         return "Не удалось сгенерировать вопросы. Попробуйте увеличить текст или выбрать другой фрагмент."
+
+     # format
+     out = ["### Вопросы для самопроверки\n"]
+     for i, (q, a, ev) in enumerate(questions, 1):
+         out.append(f"**{i}. {q}**")
+         out.append(f"- Ответ: {a}")
+         out.append(f"- Фрагмент: {ev}")
+         out.append("")
+     return "\n".join(out).strip()
+
+
+ def answer_question(state: Dict[str, Any], question: str) -> str:
+     question = (question or "").strip()
+     if not question:
+         return "Введите вопрос."
+
+     hits = retrieve(state, question, k=4)
+     if not hits:
+         return "Нечего искать: сначала вставьте текст и нажмите «Проиндексировать»."
+
+     # Build context from top passages
+     context_parts = []
+     for score, chunk in hits:
+         context_parts.append(chunk)
+     context = "\n\n".join(context_parts)
+     context = _safe_truncate_text(context, MAX_CONTEXT_CHARS)
+
+     ans, ev = extractive_qa(question, context)
+     if not ans:
+         # fallback to ruT5 open-book QA
+         ans = rut5("comprehend", f"{context} Вопрос: {question}", max_new_tokens=96).strip()
+         ev = context[:320].strip()
+
+     return f"**Ответ:** {ans}\n\n**Доказательство (фрагмент текста):**\n{ev}"
+
+
+ def search_passages(state: Dict[str, Any], query: str, k: int) -> str:
+     query = (query or "").strip()
+     if not query:
+         return "Введите запрос."
+     hits = retrieve(state, query, k=int(max(1, min(k, 10))))
+     if not hits:
+         return "Ничего не найдено."
+
+     out = ["### Результаты семантического поиска\n"]
+     for i, (score, chunk) in enumerate(hits, 1):
+         out.append(f"**{i}. score={score:.3f}**")
+         out.append(chunk)
+         out.append("")
+     return "\n".join(out).strip()
+
+
+ # ---------------------------
+ # Gradio UI
+ # ---------------------------
+ def model_status_text() -> str:
+     return (
+         "Выбранные модели:\n"
+         f"- Генерация: {MODELS.gen_id}\n"
+         f"- Эмбеддинги: {MODELS.emb_id}\n"
+         f"- QA (extractive): {MODELS.qa_id}\n"
+         "\nПримечание: модели скачиваются при первом обращении."
      )

+ def on_index(text: str, state: Optional[Dict[str, Any]]) -> Tuple[str, Dict[str, Any]]:
+     text = _safe_truncate_text(text, MAX_INPUT_CHARS)
+     if not text.strip():
+         return "Пустой текст.", {"text": "", "chunks": [], "emb": None}

+     t0 = time.time()
+     st = build_index(text)
+     dt = time.time() - t0

+     chunks_n = len(st.get("chunks") or [])
+     return f"Готово: чанков={chunks_n}, индекс построен за {dt:.1f}с.", st

+ def on_summary(text: str, state: Optional[Dict[str, Any]], level: str) -> Tuple[str, Dict[str, Any]]:
+     st = ensure_index(state, text)
+     return make_summary(st, level), st
+
+ def on_quiz(text: str, state: Optional[Dict[str, Any]], n: int, difficulty: str) -> Tuple[str, Dict[str, Any]]:
+     st = ensure_index(state, text)
+     return make_quiz(st, n, difficulty), st

+ def on_search(text: str, state: Optional[Dict[str, Any]], query: str, k: int) -> Tuple[str, Dict[str, Any]]:
+     st = ensure_index(state, text)
+     return search_passages(st, query, k), st

+ def on_chat(text: str, state: Optional[Dict[str, Any]], chat: List[Tuple[str, str]], user_q: str) -> Tuple[List[Tuple[str, str]], Dict[str, Any], str]:
+     st = ensure_index(state, text)
+     user_q = (user_q or "").strip()
+     if not user_q:
+         return chat, st, ""
+     a = answer_question(st, user_q)
+     chat = (chat or []) + [(user_q, a)]
+     return chat, st, ""

+ with gr.Blocks(title="Text Study Assistant (CPU, 3 Transformers)") as demo:
+     gr.Markdown("## Text Study Assistant\nМини-помощник для конспекта, самопроверки и вопросов по тексту. CPU-only, без GPU.")

+     with gr.Row():
+         with gr.Column(scale=2):
+             src_text = gr.Textbox(
+                 label="Текст",
+                 lines=12,
+                 placeholder="Вставьте сюда текст для анализа (лекция, статья, конспект).",
+             )
+             with gr.Row():
+                 btn_index = gr.Button("Проиндексировать", variant="primary")
+                 index_status = gr.Textbox(label="Статус", value="Ожидаю текст…", interactive=False)
+
+             with gr.Accordion("Модели", open=False):
+                 gr.Textbox(value=model_status_text(), lines=6, interactive=False, show_label=False)
+
+         with gr.Column(scale=3):
+             state = gr.State({"text": "", "chunks": [], "emb": None})
+
+             with gr.Tabs():
+                 with gr.Tab("Пересказ"):
+                     level = gr.Radio(["Коротко", "Подробнее"], value="Коротко", label="Уровень")
+                     btn_sum = gr.Button("Сделать пересказ")
+                     sum_out = gr.Markdown()
+
+                 with gr.Tab("Вопросы"):
+                     with gr.Row():
+                         q_n = gr.Slider(1, 12, value=6, step=1, label="Количество вопросов")
+                         q_diff = gr.Radio(["Легко", "Сложнее"], value="Легко", label="Сложность")
+                     btn_quiz = gr.Button("Сгенерировать вопросы")
+                     quiz_out = gr.Markdown()
+
+                 with gr.Tab("Чат по тексту"):
+                     chat = gr.Chatbot(label="Диалог", height=380)
+                     with gr.Row():
+                         user_q = gr.Textbox(label="Вопрос", placeholder="Спросите что-то по тексту…", lines=1)
+                         btn_send = gr.Button("Отправить")
+                     gr.Markdown(
+                         "Ответ формируется так: семантический поиск по чанкам → extractive QA с фрагментом-доказательством → fallback на ruT5 при необходимости."
+                     )
+
+                 with gr.Tab("Семантический поиск"):
+                     with gr.Row():
+                         search_q = gr.Textbox(label="Запрос", placeholder="Например: 'основная гипотеза' или 'методика эксперимента'")
+                         topk = gr.Slider(1, 10, value=5, step=1, label="Топ-K")
+                     btn_search = gr.Button("Найти фрагменты")
+                     search_out = gr.Markdown()
+
+     # Wiring
+     btn_index.click(on_index, inputs=[src_text, state], outputs=[index_status, state])
+
+     btn_sum.click(on_summary, inputs=[src_text, state, level], outputs=[sum_out, state])
+     btn_quiz.click(on_quiz, inputs=[src_text, state, q_n, q_diff], outputs=[quiz_out, state])
+     btn_search.click(on_search, inputs=[src_text, state, search_q, topk], outputs=[search_out, state])
+
+     btn_send.click(on_chat, inputs=[src_text, state, chat, user_q], outputs=[chat, state, user_q])
+     user_q.submit(on_chat, inputs=[src_text, state, chat, user_q], outputs=[chat, state, user_q])

  if __name__ == "__main__":
+     demo.queue(max_size=32).launch(server_name="0.0.0.0", server_port=7860, show_error=True)