skatzR commited on
Commit
21aa3b6
·
verified ·
1 Parent(s): 5d2179c

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +270 -27
inference.py CHANGED
@@ -1,13 +1,34 @@
1
  # requirements
2
- !pip install torch==2.8.0 torchvision==0.17.2
3
  !pip install transformers==4.48.3 tokenizers sentencepiece accelerate
4
 
5
 
 
 
 
 
 
 
 
 
6
  import torch
7
  from typing import List, Optional
8
  from transformers import AutoTokenizer, AutoModel
9
 
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  ERROR_NAMES_RU = {
12
  "false_causality": "Ложная причинно-следственная связь",
13
  "unsupported_claim": "Неподкрепленное утверждение",
@@ -18,6 +39,10 @@ ERROR_NAMES_RU = {
18
  }
19
 
20
 
 
 
 
 
21
  class RQAJudge:
22
  def __init__(self, model_name="skatzR/RQA-R2", device=None, max_length: int = 512):
23
  self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
@@ -35,16 +60,24 @@ class RQAJudge:
35
  self.model.eval()
36
 
37
  cfg = self.model.config
38
- self.error_types = list(cfg.error_types)
39
 
40
- self.temp_issue = float(cfg.temperature_has_issue)
41
- self.temp_hidden = float(cfg.temperature_is_hidden)
42
- self.temp_errors = list(cfg.temperature_errors)
 
 
43
 
44
- self.threshold_issue = float(cfg.threshold_has_issue)
45
- self.threshold_hidden = float(cfg.threshold_is_hidden)
46
- self.threshold_error = float(cfg.threshold_error)
47
- self.threshold_errors = list(cfg.threshold_errors)
 
 
 
 
 
 
48
 
49
  @torch.no_grad()
50
  def infer(
@@ -73,20 +106,15 @@ class RQAJudge:
73
 
74
  outputs = self.model(**inputs)
75
 
 
76
  issue_logit = outputs["has_issue_logits"] / self.temp_issue
77
- hidden_logit = outputs["is_hidden_logits"] / self.temp_hidden
78
-
79
- error_logits = outputs["errors_logits"][0].clone()
80
- for i in range(len(self.error_types)):
81
- error_logits[i] = error_logits[i] / self.temp_errors[i]
82
-
83
  issue_prob = torch.sigmoid(issue_logit).item()
84
  has_issue = issue_prob >= issue_threshold
85
 
86
  result = {
87
  "text": text,
88
- "class": None,
89
- "status": "ok",
90
  "review_required": False,
91
  "has_issue": has_issue,
92
  "issue_probability": issue_prob,
@@ -94,21 +122,24 @@ class RQAJudge:
94
  "hidden_probability": None,
95
  "errors": [],
96
  "num_errors": 0,
 
97
  "threshold_issue": issue_threshold,
98
  "threshold_hidden": hidden_threshold,
99
  "threshold_error": error_threshold,
100
  "threshold_errors": error_thresholds,
101
- "schema_version": getattr(self.model.config, "schema_version", "unknown"),
102
  }
103
 
104
  if abs(issue_prob - issue_threshold) <= issue_uncertain_margin:
105
  result["status"] = "uncertain"
106
  result["review_required"] = True
107
 
 
108
  if not has_issue:
109
  result["class"] = "logical"
110
  return result
111
 
 
 
112
  hidden_prob = torch.sigmoid(hidden_logit).item()
113
  is_hidden = hidden_prob >= hidden_threshold
114
 
@@ -119,14 +150,23 @@ class RQAJudge:
119
  result["status"] = "uncertain"
120
  result["review_required"] = True
121
 
 
122
  if is_hidden:
123
  result["class"] = "hidden"
124
  return result
125
 
126
- error_probs = torch.sigmoid(error_logits).tolist()
127
- detected = []
 
 
 
 
 
 
 
 
128
  for i, err_name in enumerate(self.error_types):
129
- prob = float(error_probs[i])
130
  threshold_i = float(error_thresholds[i] if i < len(error_thresholds) else error_threshold)
131
 
132
  if abs(prob - threshold_i) <= error_uncertain_margin:
@@ -134,15 +174,19 @@ class RQAJudge:
134
  result["review_required"] = True
135
 
136
  if prob >= threshold_i:
137
- detected.append((err_name, prob))
138
 
139
- detected.sort(key=lambda x: x[1], reverse=True)
140
 
141
  result["class"] = "explicit"
142
- result["errors"] = detected
143
- result["num_errors"] = len(detected)
144
  return result
145
 
 
 
 
 
146
  def pretty_print(self, r):
147
  print("\n" + "=" * 72)
148
  print("📄 Текст:")
@@ -155,11 +199,11 @@ class RQAJudge:
155
  print(f"🧠 Класс: {r['class']}")
156
 
157
  if r["status"] == "uncertain":
158
- print("⚠️ Статус: uncertain")
159
 
160
  if r["hidden_probability"] is not None:
161
  print(
162
- f"🟡 Hidden: {'ДА' if r['hidden_problem'] else 'НЕТ'} "
163
  f"({r['hidden_probability'] * 100:.2f}%)"
164
  )
165
 
@@ -171,3 +215,202 @@ class RQAJudge:
171
  print("\n✅ Явных логических ошибок не обнаружено")
172
 
173
  print("=" * 72)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # requirements
2
+ # Для inference в Colab достаточно этого стека.
3
  !pip install transformers==4.48.3 tokenizers sentencepiece accelerate
4
 
5
 
6
+ # ============================================================
7
+ # RQA UX Inference — R2 Interactive Version
8
+ # Google Colab + CLI friendly
9
+ # ============================================================
10
+
11
+ import os
12
+ import json
13
+ import csv
14
  import torch
15
  from typing import List, Optional
16
  from transformers import AutoTokenizer, AutoModel
17
 
18
 
19
+ # ============================================================
20
+ # Константы
21
+ # ============================================================
22
+
23
+ ERROR_TYPES = [
24
+ "false_causality",
25
+ "unsupported_claim",
26
+ "overgeneralization",
27
+ "missing_premise",
28
+ "contradiction",
29
+ "circular_reasoning",
30
+ ]
31
+
32
  ERROR_NAMES_RU = {
33
  "false_causality": "Ложная причинно-следственная связь",
34
  "unsupported_claim": "Неподкрепленное утверждение",
 
39
  }
40
 
41
 
42
+ # ============================================================
43
+ # RQA Judge
44
+ # ============================================================
45
+
46
  class RQAJudge:
47
  def __init__(self, model_name="skatzR/RQA-R2", device=None, max_length: int = 512):
48
  self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
 
60
  self.model.eval()
61
 
62
  cfg = self.model.config
63
+ self.error_types = list(getattr(cfg, "error_types", ERROR_TYPES))
64
 
65
+ self.temp_issue = float(getattr(cfg, "temperature_has_issue", 1.0))
66
+ self.temp_hidden = float(getattr(cfg, "temperature_is_hidden", 1.0))
67
+ self.temp_errors = list(
68
+ getattr(cfg, "temperature_errors", [1.0] * len(self.error_types))
69
+ )
70
 
71
+ self.threshold_issue = float(getattr(cfg, "threshold_has_issue", 0.5))
72
+ self.threshold_hidden = float(getattr(cfg, "threshold_is_hidden", 0.5))
73
+ self.threshold_error = float(getattr(cfg, "threshold_error", 0.5))
74
+ self.threshold_errors = list(
75
+ getattr(cfg, "threshold_errors", [self.threshold_error] * len(self.error_types))
76
+ )
77
+
78
+ # ----------------------
79
+ # Core inference
80
+ # ----------------------
81
 
82
  @torch.no_grad()
83
  def infer(
 
106
 
107
  outputs = self.model(**inputs)
108
 
109
+ # ----- has_issue -----
110
  issue_logit = outputs["has_issue_logits"] / self.temp_issue
 
 
 
 
 
 
111
  issue_prob = torch.sigmoid(issue_logit).item()
112
  has_issue = issue_prob >= issue_threshold
113
 
114
  result = {
115
  "text": text,
116
+ "class": None, # logical / hidden / explicit
117
+ "status": "ok", # ok / uncertain
118
  "review_required": False,
119
  "has_issue": has_issue,
120
  "issue_probability": issue_prob,
 
122
  "hidden_probability": None,
123
  "errors": [],
124
  "num_errors": 0,
125
+ "schema_version": getattr(self.model.config, "schema_version", "unknown"),
126
  "threshold_issue": issue_threshold,
127
  "threshold_hidden": hidden_threshold,
128
  "threshold_error": error_threshold,
129
  "threshold_errors": error_thresholds,
 
130
  }
131
 
132
  if abs(issue_prob - issue_threshold) <= issue_uncertain_margin:
133
  result["status"] = "uncertain"
134
  result["review_required"] = True
135
 
136
+ # ----- Gate 1: logical -----
137
  if not has_issue:
138
  result["class"] = "logical"
139
  return result
140
 
141
+ # ----- hidden -----
142
+ hidden_logit = outputs["is_hidden_logits"] / self.temp_hidden
143
  hidden_prob = torch.sigmoid(hidden_logit).item()
144
  is_hidden = hidden_prob >= hidden_threshold
145
 
 
150
  result["status"] = "uncertain"
151
  result["review_required"] = True
152
 
153
+ # ----- Gate 2: hidden -----
154
  if is_hidden:
155
  result["class"] = "hidden"
156
  return result
157
 
158
+ # ----- explicit errors -----
159
+ raw_error_logits = outputs["errors_logits"][0].clone()
160
+ error_probs = {}
161
+
162
+ for i, logit in enumerate(raw_error_logits):
163
+ calibrated = logit / self.temp_errors[i]
164
+ prob = torch.sigmoid(calibrated).item()
165
+ error_probs[self.error_types[i]] = prob
166
+
167
+ explicit_errors = []
168
  for i, err_name in enumerate(self.error_types):
169
+ prob = float(error_probs[err_name])
170
  threshold_i = float(error_thresholds[i] if i < len(error_thresholds) else error_threshold)
171
 
172
  if abs(prob - threshold_i) <= error_uncertain_margin:
 
174
  result["review_required"] = True
175
 
176
  if prob >= threshold_i:
177
+ explicit_errors.append((err_name, prob))
178
 
179
+ explicit_errors.sort(key=lambda x: x[1], reverse=True)
180
 
181
  result["class"] = "explicit"
182
+ result["errors"] = explicit_errors
183
+ result["num_errors"] = len(explicit_errors)
184
  return result
185
 
186
+ # ============================================================
187
+ # UX output
188
+ # ============================================================
189
+
190
  def pretty_print(self, r):
191
  print("\n" + "=" * 72)
192
  print("📄 Текст:")
 
199
  print(f"🧠 Класс: {r['class']}")
200
 
201
  if r["status"] == "uncertain":
202
+ print("⚠️ Пограничный случай: review recommended")
203
 
204
  if r["hidden_probability"] is not None:
205
  print(
206
+ f"🟡 Hidden-проблема: {'ДА' if r['hidden_problem'] else 'НЕТ'} "
207
  f"({r['hidden_probability'] * 100:.2f}%)"
208
  )
209
 
 
215
  print("\n✅ Явных логических ошибок не обнаружено")
216
 
217
  print("=" * 72)
218
+
219
+
220
+ # ============================================================
221
+ # Loaders
222
+ # ============================================================
223
+
224
+ def load_texts_from_file(path: str) -> List[str]:
225
+ ext = os.path.splitext(path)[1].lower()
226
+
227
+ if ext == ".txt":
228
+ with open(path, encoding="utf-8") as f:
229
+ return [line.strip() for line in f if line.strip()]
230
+
231
+ if ext == ".csv":
232
+ with open(path, encoding="utf-8") as f:
233
+ reader = csv.DictReader(f)
234
+ return [row["text"] for row in reader if row.get("text")]
235
+
236
+ if ext == ".json":
237
+ with open(path, encoding="utf-8") as f:
238
+ data = json.load(f)
239
+ if isinstance(data, list):
240
+ if all(isinstance(item, str) for item in data):
241
+ return data
242
+ texts = []
243
+ for item in data:
244
+ if isinstance(item, dict) and "text" in item:
245
+ texts.append(str(item["text"]))
246
+ return texts
247
+
248
+ raise ValueError("Неподдерживаемый формат файла")
249
+
250
+
251
+ # ============================================================
252
+ # Interactive CLI Interface
253
+ # ============================================================
254
+
255
+ class InteractiveCLI:
256
+ def __init__(self, model_name="skatzR/RQA-R2"):
257
+ self.judge = RQAJudge(model_name=model_name)
258
+
259
+ def clear_screen(self):
260
+ print("\n" * 2)
261
+
262
+ def show_mode_menu(self):
263
+ self.clear_screen()
264
+ print("=" * 60)
265
+ print("🤖 RQA-R2 — АНАЛИЗ ЛОГИЧЕСКИХ ОШИБОК")
266
+ print("=" * 60)
267
+ print("\nВыберите режим работы:")
268
+ print("1. 📝 Одиночный ввод (одна фраза для анализа)")
269
+ print("2. 📄 Множественный ввод (несколько фраз, каждая с новой строки)")
270
+ print("3. 📂 Загрузка из файла (.txt, .csv, .json)")
271
+ print("\nНажмите Enter без ввода для выхода.")
272
+ print("-" * 60)
273
+
274
+ def process_single_mode(self):
275
+ self.clear_screen()
276
+ print("[📝 РЕЖИМ: ОДИНОЧНЫЙ ВВОД]")
277
+ print("Введите текст для анализа:")
278
+ print("(Нажмите Enter без ввода для возврата в меню)")
279
+ print("-" * 40)
280
+
281
+ text = input("> ").strip()
282
+ if not text:
283
+ return True
284
+
285
+ result = self.judge.infer(text)
286
+ self.judge.pretty_print(result)
287
+
288
+ print("\n" + "-" * 40)
289
+ input("Нажмите Enter для продолжения...")
290
+ return False
291
+
292
+ def process_multiline_mode(self):
293
+ self.clear_screen()
294
+ print("[📄 РЕЖИМ: МНОЖЕСТВЕННЫЙ ВВОД]")
295
+ print("Введите тексты для анализа (каждый с новой строки).")
296
+ print("Оставьте строку пустой для завершения ввода.")
297
+ print("(Нажмите Enter без ввода для возврата в меню)")
298
+ print("-" * 40)
299
+
300
+ texts = []
301
+ print("Ввод текстов:")
302
+ while True:
303
+ line = input("> ").strip()
304
+ if not line:
305
+ if not texts:
306
+ return True
307
+ break
308
+ texts.append(line)
309
+
310
+ self.clear_screen()
311
+ print(f"[📄 РЕЖИМ: МНОЖЕСТВЕННЫЙ ВВОД] — найдено {len(texts)} текстов")
312
+ print("-" * 40)
313
+
314
+ for i, text in enumerate(texts, 1):
315
+ print(f"\n🔍 Текст #{i}:")
316
+ result = self.judge.infer(text)
317
+ self.judge.pretty_print(result)
318
+
319
+ print("\n" + "=" * 60)
320
+ input("Нажмите Enter для продолжения...")
321
+ return False
322
+
323
+ def process_file_mode(self):
324
+ self.clear_screen()
325
+ print("[📂 РЕЖИМ: ЗАГРУЗКА ИЗ ФАЙЛА]")
326
+ print("Поддерживаемые форматы: .txt, .csv, .json")
327
+ print("Укажите путь к файлу:")
328
+ print("(Нажмите Enter без ввода для возврата в меню)")
329
+ print("-" * 40)
330
+
331
+ file_path = input("Путь к файлу> ").strip()
332
+ if not file_path:
333
+ return True
334
+
335
+ try:
336
+ if not os.path.exists(file_path):
337
+ print(f"\n❌ Ошибка: Файл '{file_path}' не найден!")
338
+ input("\nНажмите Enter для продолжения...")
339
+ return False
340
+
341
+ texts = load_texts_from_file(file_path)
342
+ if not texts:
343
+ print(f"\n⚠️ Файл '{file_path}' пуст или не содержит текстов!")
344
+ input("\nНажмите Enter для продолжения...")
345
+ return False
346
+
347
+ self.clear_screen()
348
+ print(f"[📂 РЕЖИМ: ЗАГРУЗКА ИЗ ФАЙЛА] — загружено {len(texts)} текстов")
349
+ print(f"Файл: {file_path}")
350
+ print("-" * 40)
351
+
352
+ for i, text in enumerate(texts, 1):
353
+ print(f"\n🔍 Текст #{i}:")
354
+ result = self.judge.infer(text)
355
+ self.judge.pretty_print(result)
356
+
357
+ print("\n" + "=" * 60)
358
+ input("Нажмите Enter для продолжения...")
359
+
360
+ except Exception as e:
361
+ print(f"\n❌ Ошибка при обработке файла: {str(e)}")
362
+ input("\nНажмите Enter для продолжения...")
363
+
364
+ return False
365
+
366
+ def run_interactive(self):
367
+ current_mode = None
368
+
369
+ while True:
370
+ if not current_mode:
371
+ self.show_mode_menu()
372
+ choice = input("Ваш выбор (1-3)> ").strip()
373
+
374
+ if not choice:
375
+ print("\n👋 Выход из программы...")
376
+ break
377
+
378
+ if choice == "1":
379
+ current_mode = "single"
380
+ elif choice == "2":
381
+ current_mode = "multiline"
382
+ elif choice == "3":
383
+ current_mode = "file"
384
+ else:
385
+ print("\n❌ Неверный выбор! Попробуйте снова.")
386
+ input("Нажмите Enter для продолжения...")
387
+ continue
388
+
389
+ should_return_to_menu = False
390
+
391
+ if current_mode == "single":
392
+ should_return_to_menu = self.process_single_mode()
393
+ elif current_mode == "multiline":
394
+ should_return_to_menu = self.process_multiline_mode()
395
+ elif current_mode == "file":
396
+ should_return_to_menu = self.process_file_mode()
397
+
398
+ if should_return_to_menu:
399
+ current_mode = None
400
+
401
+
402
+ # ============================================================
403
+ # Точка входа
404
+ # ============================================================
405
+
406
+ def main():
407
+ cli = InteractiveCLI()
408
+ cli.run_interactive()
409
+
410
+
411
+ # ============================================================
412
+ # Запуск
413
+ # ============================================================
414
+
415
+ if __name__ == "__main__":
416
+ main()