skatzR committed on
Commit
041a905
·
verified ·
1 Parent(s): 1415982

Create inference.py

Browse files
Files changed (1) hide show
  1. inference.py +264 -0
inference.py ADDED
@@ -0,0 +1,264 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ============================================================
2
+ # RQA UX Inference — FINAL
3
+ # Google Colab + CLI friendly
4
+ # ============================================================
5
+
6
+ import os
7
+ import sys
8
+ import json
9
+ import argparse
10
+ import csv
11
+ import torch
12
+ from typing import List, Union
13
+ from transformers import AutoTokenizer, AutoModel
14
+
15
+ # ============================================================
16
+ # Константы
17
+ # ============================================================
18
+
19
# Error labels in the order of the model's per-error output logits:
# RQAJudge.infer maps logit index i to ERROR_TYPES[i].
ERROR_TYPES = [
    "false_causality",
    "unsupported_claim",
    "overgeneralization",
    "missing_premise",
    "contradiction",
    "circular_reasoning",
]

# Human-readable (Russian) display names, keyed by error label.
ERROR_NAMES_RU = {
    "false_causality": "Ложная причинно-следственная связь",
    "unsupported_claim": "Неподкреплённое утверждение",
    "overgeneralization": "Чрезмерное обобщение",
    "missing_premise": "Отсутствующая предпосылка",
    "contradiction": "Противоречие",
    "circular_reasoning": "Круговое рассуждение",
}

# Minimum calibrated probability at which each error type counts as detected.
ERROR_THRESHOLDS = {
    "false_causality": 0.55,
    "unsupported_claim": 0.55,
    "overgeneralization": 0.60,
    # Diagnostic signal: RQAJudge.infer reports it through the
    # "hidden_problem" flag instead of the explicit error list.
    "missing_premise": 0.80,
    "contradiction": 0.60,
    "circular_reasoning": 0.60,
}
45
+
46
+ # ============================================================
47
+ # RQA Judge
48
+ # ============================================================
49
+
50
class RQAJudge:
    """Run the RQA model on a text and turn its raw outputs into a verdict.

    The tokenizer and model are loaded through ``transformers``'
    ``from_pretrained``. Calibration temperatures are read from the model
    config attributes ``temperature_has_issue`` and ``temperature_errors``.
    """

    def __init__(self, model_name="skatzR/RQA-X1.1", device=None):
        """Load the tokenizer and model and put the model in eval mode.

        Args:
            model_name: Identifier passed to ``from_pretrained``.
            device: Device to run on. When falsy, ``"cuda"`` is chosen if
                ``torch.cuda.is_available()`` and ``"cpu"`` otherwise.
        """
        if not device:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device

        # NOTE(review): trust_remote_code=True allows model code shipped with
        # the checkpoint to be executed - only use with trusted repositories.
        hub_kwargs = {"trust_remote_code": True}
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, **hub_kwargs)
        loaded_model = AutoModel.from_pretrained(model_name, **hub_kwargs)
        self.model = loaded_model.to(self.device)
        self.model.eval()

        # Temperatures used to rescale logits before the sigmoid.
        config = self.model.config
        self.temp_issue = float(config.temperature_has_issue)
        self.temp_errors = list(config.temperature_errors)

    # ----------------------
    # Core inference
    # ----------------------

    @torch.no_grad()
    def infer(
        self,
        text: str,
        issue_threshold: float = 0.6,
        disagreement_threshold: float = 0.4,
    ):
        """Analyse one text and return a structured verdict.

        Args:
            text: The text to analyse (truncated to 512 tokens).
            issue_threshold: Minimum calibrated probability of the binary
                head for ``has_issue`` to be true.
            disagreement_threshold: Minimum disagreement between the binary
                head and the per-error heads for a case to be ``borderline``.

        Returns:
            A dict with the keys ``text``, ``has_issue``,
            ``issue_probability``, ``errors`` (list of ``(label, prob)``
            pairs, highest probability first), ``hidden_problem``,
            ``borderline`` and ``disagreement``.
        """
        encoded = self.tokenizer(
            text,
            truncation=True,
            max_length=512,
            padding="max_length",
            return_tensors="pt",
        )
        encoded = encoded.to(self.device)
        outputs = self.model(**encoded)

        # Binary "has issue" head: temperature-scale, then sigmoid.
        # .item() assumes a single logit for the single input text.
        scaled_issue_logit = outputs["has_issue_logits"] / self.temp_issue
        issue_prob = torch.sigmoid(scaled_issue_logit).item()
        has_issue = issue_prob >= issue_threshold

        # Per-error heads: each logit has its own temperature.
        error_probs = {}
        for idx, logit in enumerate(outputs["errors_logits"][0]):
            scaled = logit / self.temp_errors[idx]
            error_probs[ERROR_TYPES[idx]] = torch.sigmoid(scaled).item()

        # Probability that at least one error is present, treating the
        # error heads as independent: 1 - prod(1 - p_i).
        p_no_error = 1.0
        for p in error_probs.values():
            p_no_error *= (1.0 - p)
        p_any_error = 1.0 - p_no_error

        # How far the binary head is from the combined error heads.
        disagreement = abs(issue_prob - p_any_error)

        # Error types whose probability reaches their own threshold.
        flagged = [
            (label, p)
            for label, p in error_probs.items()
            if p >= ERROR_THRESHOLDS[label]
        ]

        # "missing_premise" is only surfaced as a hidden-problem flag.
        hidden_problem = any(label == "missing_premise" for label, _ in flagged)

        # The binary head dominates: without has_issue no explicit errors
        # are reported, whatever the per-error heads say.
        if has_issue:
            explicit_errors = sorted(
                (pair for pair in flagged if pair[0] != "missing_premise"),
                key=lambda pair: pair[1],
                reverse=True,
            )
        else:
            explicit_errors = []

        borderline = (
            not has_issue
            and hidden_problem
            and disagreement >= disagreement_threshold
        )

        return {
            "text": text,
            "has_issue": has_issue,
            "issue_probability": issue_prob,
            "errors": explicit_errors,
            "hidden_problem": hidden_problem,
            "borderline": borderline,
            "disagreement": disagreement,
        }

    # ============================================================
    # UX output
    # ============================================================

    def pretty_print(self, r):
        """Print a human-readable report for a result dict from ``infer``.

        Args:
            r: Mapping with the keys ``text``, ``has_issue``,
                ``issue_probability``, ``borderline``, ``hidden_problem``,
                ``errors`` and ``disagreement``.
        """
        separator = "=" * 72
        verdict = "ДА" if r["has_issue"] else "НЕТ"
        issue_percent = r["issue_probability"] * 100

        print("\n" + separator)
        print("📄 Текст:")
        print(r["text"])
        print(f"\n🔎 Обнаружена проблема: {verdict} ({issue_percent:.2f}%)")

        if r["borderline"]:
            print("⚠️ Пограничный случай: аргументативный текст")
        if r["hidden_problem"]:
            print("🟡 Скрытая проблема: возможны неявные предпосылки")

        if not r["errors"]:
            print("\n✅ Явных логических ошибок не обнаружено")
        else:
            print("\n❌ Явные логические ошибки:")
            for name, prob in r["errors"]:
                print(f" • {ERROR_NAMES_RU[name]} — {prob*100:.2f}%")

        print(f"\n📊 Disagreement: {r['disagreement']:.3f}")
        print(separator)
170
+
171
+ # ============================================================
172
+ # Loaders
173
+ # ============================================================
174
+
175
def load_texts_from_file(path: str) -> List[str]:
    """Load the texts to analyse from a ``.txt``, ``.csv`` or ``.json`` file.

    The format is chosen by the (case-insensitive) file extension:

    * ``.txt``  - one text per line; lines are stripped, blank lines skipped.
    * ``.csv``  - the ``text`` column of every row.
    * ``.json`` - a list whose items are strings, or objects with a string
      ``"text"`` field.

    Files are decoded as ``utf-8-sig`` so that a leading UTF-8 byte-order
    mark is dropped; files without a BOM decode exactly as plain UTF-8.

    Args:
        path: Path of the file to read.

    Returns:
        The list of texts found in the file.

    Raises:
        ValueError: If the extension is unsupported, or a JSON file does not
            contain a list of strings / objects with a string ``"text"``.
        KeyError: If a CSV file has rows but no ``text`` column.
    """
    ext = os.path.splitext(path)[1].lower()

    if ext == ".txt":
        with open(path, encoding="utf-8-sig") as f:
            stripped_lines = (line.strip() for line in f)
            return [line for line in stripped_lines if line]

    if ext == ".csv":
        # newline="" lets the csv module handle line endings itself, which
        # keeps newlines embedded in quoted fields intact.
        with open(path, encoding="utf-8-sig", newline="") as f:
            reader = csv.DictReader(f)
            # DictReader fills missing trailing fields of a short row with
            # None; such rows carry no text to analyse and are skipped.
            return [row["text"] for row in reader if row["text"] is not None]

    if ext == ".json":
        with open(path, encoding="utf-8-sig") as f:
            data = json.load(f)
        if not isinstance(data, list):
            raise ValueError("JSON-файл должен содержать список текстов")

        texts = []
        for item in data:
            # Accept {"text": "..."} objects, mirroring the CSV "text" column.
            if isinstance(item, dict) and isinstance(item.get("text"), str):
                item = item["text"]
            if not isinstance(item, str):
                raise ValueError(
                    "Элементы JSON-списка должны быть строками "
                    "или объектами с полем 'text'"
                )
            texts.append(item)
        return texts

    raise ValueError("Неподдерживаемый формат файла")
194
+
195
+ # ============================================================
196
+ # CLI / Colab entrypoint
197
+ # ============================================================
198
+
199
def main():
    """Command-line / notebook entry point: collect texts, then analyse them.

    Input source priority: ``--file``, then ``--text``, then ``--multiline``;
    when none is given, a single text is read interactively from stdin.
    Unrecognised arguments are ignored (``parse_known_args``) - presumably so
    the script also runs inside notebook kernels that inject their own
    arguments (the file header describes it as Colab-friendly).

    The model is loaded only after the texts have been collected, so a
    missing file or empty input ends the run without paying for model
    loading. End-of-file on stdin is treated as end of input.

    Raises:
        FileNotFoundError: If ``--file`` names a path that does not exist.
        ValueError: Propagated from ``load_texts_from_file``.
    """
    parser = argparse.ArgumentParser(
        description="RQA — анализ логических ошибок"
    )
    parser.add_argument(
        "--text",
        type=str,
        help="Один текст для анализа",
    )
    parser.add_argument(
        "--file",
        type=str,
        help="Файл с текстами (.txt, .csv, .json)",
    )
    parser.add_argument(
        "--multiline",
        action="store_true",
        help="Ввод нескольких строк (каждая строка — отдельный текст)",
    )

    args, _unknown = parser.parse_known_args()

    texts = []

    # ---------- FILE MODE ----------
    if args.file:
        if not os.path.exists(args.file):
            raise FileNotFoundError(args.file)
        texts = load_texts_from_file(args.file)

    # ---------- SINGLE TEXT ----------
    elif args.text:
        texts = [args.text]

    # ---------- MULTILINE ----------
    elif args.multiline:
        print("Введите тексты (пустая строка — конец ввода):")
        while True:
            try:
                line = input()
            except EOFError:
                # Closed or piped stdin: end of input, same as a blank line.
                break
            if not line.strip():
                break
            texts.append(line.strip())

    # ---------- INTERACTIVE FALLBACK ----------
    else:
        print("Введите текст для анализа:")
        try:
            line = input().strip()
        except EOFError:
            line = ""
        if line:
            texts = [line]

    # Nothing to analyse: stop before the model is loaded.
    if not texts:
        print("❌ Пустой ввод — выхожу")
        return

    # ---------- RUN ----------
    judge = RQAJudge()
    for t in texts:
        result = judge.infer(t)
        judge.pretty_print(result)
261
+
262
+
263
+ if __name__ == "__main__":
264
+ main()