Kevinshh committed on
Commit
11452a6
·
verified ·
1 Parent(s): 2cfa897

Upload stability_data_extractor.py

Browse files
Files changed (1) hide show
  1. utils/stability_data_extractor.py +374 -0
utils/stability_data_extractor.py ADDED
@@ -0,0 +1,374 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Stability Data Extractor - LLM-Powered Format-Agnostic Extraction
3
+
4
+ This module uses LLM to understand and extract stability data from ANY file format.
5
+ No hardcoded patterns - the LLM interprets the data structure dynamically.
6
+ """
7
+
8
+ import json
9
+ import re
10
+ from typing import Dict, List, Any, Optional
11
+ from pathlib import Path
12
+
13
+
14
class StabilityDataExtractor:
    """
    LLM-powered stability data extractor.

    Extraction is attempted with an LLM first; when no invoker is available
    or the reply cannot be parsed into a dict with a non-empty ``batches``
    list, a regex-based heuristic scan of the text is used instead.
    """

    # Prompt for LLM to extract structured stability data.
    # Literal braces are doubled because the template goes through str.format().
    EXTRACTION_PROMPT = """你是药物稳定性数据提取专家。请从以下文本中提取稳定性数据。

【数据内容】
{text_content}

【用户分析目标】
{goal}

【任务】
请识别并提取:
1. 批次信息(批次名称/ID)
2. 存储条件(如25°C/60%RH长期, 40°C/75%RH加速等)
3. 时间点(月)
4. 质量指标数值(如杂质含量、含量等)

【输出格式】
请严格按以下JSON格式输出:
```json
{{
"batches": [
{{
"batch_id": "批次ID",
"batch_name": "批次名称",
"conditions": [
{{
"condition_id": "条件描述(如 25C_60RH)",
"condition_type": "longterm|accelerated|stress",
"timepoints": [0, 3, 6, 9],
"cqa_data": [
{{
"cqa_name": "指标名称(如总杂质)",
"values": [0.1, 0.12, 0.15, 0.18]
}}
]
}}
]
}}
],
"specification_limit": 0.5,
"primary_cqa": "主要质量指标名称"
}}
```

如果无法识别数据结构,请返回空的batches数组并在"extraction_notes"字段说明原因。
"""

    # Time header tokens such as 0M, 3M, 6M or 0月, 3月.
    # NOTE(review): hour/day/week suffixes are also accepted and the number is
    # recorded as-is, without unit conversion - confirm this is intended.
    _TIME_HEADER_RE = re.compile(r'(\d+)\s*[MmHhDd月周天]')
    _NUMBER_RE = re.compile(r'(\d+\.?\d*)')

    # Goal-text patterns for target timepoints. The Latin pattern rejects a
    # following letter so that "10mg" or "5 min" are not read as months.
    _TARGET_TIME_PATTERNS = (
        re.compile(r'(\d+)\s*[个]?月'),
        re.compile(r'(\d+)\s*(?:months?|mos?|m)(?![A-Za-z])', re.IGNORECASE),
    )

    # Display names for the row types produced by _identify_row_type.
    _CQA_NAMES = {
        "impurity": "总杂质",
        "assay": "含量",
        "moisture": "水分",
        "dissolution": "溶出度",
    }

    # (keywords, condition) pairs, listed in tie-break priority order.
    _CONDITION_KEYWORDS = (
        (('40°c', '40℃', '40c', '加速'),
         {"id": "40C_Accelerated", "type": "accelerated"}),
        (('25°c', '25℃', '25c', '长期'),
         {"id": "25C_LongTerm", "type": "longterm"}),
        (('60°c', '60℃', '60c', '高温'),
         {"id": "60C_Stress", "type": "stress"}),
        (('30°c', '30℃', '30c', '中间'),
         {"id": "30C_Intermediate", "type": "intermediate"}),
    )

    def __init__(self, model_invoker=None):
        """
        Initialize extractor.

        Args:
            model_invoker: LLM invoker instance (lazy-loaded if not provided)
        """
        self._model_invoker = model_invoker
        self.extracted_data = {}
        self.metadata = {}

    @property
    def model_invoker(self):
        """Return the LLM invoker, importing and creating it on first use.

        Returns None (after printing a warning) when the invoker cannot be
        imported or constructed; the next access tries again.
        """
        if self._model_invoker is None:
            try:
                from layers.model_invoker import ModelInvoker
                self._model_invoker = ModelInvoker()
            except Exception as e:
                print(f"Warning: Could not load ModelInvoker: {e}")
                self._model_invoker = None
        return self._model_invoker

    def extract_from_text(self, text_content: str, goal: str = "") -> Dict[str, Any]:
        """
        Extract stability data from text, LLM first and heuristics second.

        Args:
            text_content: Raw text from parsed files
            goal: Analysis goal from user (None is treated as empty)

        Returns:
            Dictionary with "batches", "specification_limit", "primary_cqa",
            "target_timepoints", "extraction_method" ("llm", "heuristic" or
            "none") and "extraction_notes".
        """
        goal = goal or ""

        # Default result structure
        result = {
            "batches": [],
            "specification_limit": 0.5,
            "primary_cqa": "总杂质",
            "target_timepoints": self._extract_target_timepoints(goal),
            "extraction_method": "none",
            "extraction_notes": ""
        }

        if not text_content or len(text_content.strip()) < 50:
            result["extraction_notes"] = "文本内容过短,无法提取数据"
            return result

        # Try LLM extraction first
        llm_result = self._extract_with_llm(text_content, goal)
        if llm_result and llm_result.get("batches"):
            self._merge_extraction(result, llm_result)
            result["extraction_method"] = "llm"
            return result

        # Fallback to heuristic extraction
        heuristic_result = self._extract_with_heuristics(text_content, goal)
        if heuristic_result and heuristic_result.get("batches"):
            self._merge_extraction(result, heuristic_result)
            result["extraction_method"] = "heuristic"
            return result

        result["extraction_notes"] = "无法识别数据格式,请确保文件包含时间点和数值数据"
        return result

    @staticmethod
    def _merge_extraction(result: Dict[str, Any], extracted: Dict[str, Any]) -> None:
        """Copy extracted fields into result, keeping defaults for None values."""
        # A JSON null from the LLM must not wipe out a usable default.
        result.update({k: v for k, v in extracted.items() if v is not None})

    @staticmethod
    def _parse_json_object(content: str) -> Optional[Dict]:
        """Parse the JSON object contained in an LLM reply.

        Accepts a fenced code block (with or without the ``json`` tag) or,
        failing that, the span from the first "{" to the last "}".

        Returns:
            The parsed dict, or None when no object is found, the payload is
            not a dict, or its "batches" entry is not a list.

        Raises:
            json.JSONDecodeError: when the candidate span is not valid JSON.
        """
        fenced = re.search(r'```(?:json)?\s*([\s\S]*?)\s*```', content, re.IGNORECASE)
        if fenced:
            json_str = fenced.group(1)
        else:
            start = content.find('{')
            end = content.rfind('}')
            if start == -1 or end <= start:
                return None
            json_str = content[start:end + 1]

        extracted = json.loads(json_str)
        # Callers use dict methods on the result; reject lists and scalars.
        if not isinstance(extracted, dict):
            return None
        if not isinstance(extracted.get("batches", []), list):
            return None
        return extracted

    def _extract_with_llm(self, text_content: str, goal: str) -> Optional[Dict]:
        """Use LLM to extract structured data.

        Returns:
            The dict parsed from the model reply, or None when no invoker is
            available, the reply has no usable JSON object, or any error is
            raised while invoking or parsing (the error is printed).
        """
        invoker = self.model_invoker
        if invoker is None:
            return None

        try:
            # Truncate text to avoid token limits
            max_chars = 8000
            truncated_text = text_content[:max_chars]
            if len(text_content) > max_chars:
                truncated_text += "\n... [文本已截断]"

            prompt = self.EXTRACTION_PROMPT.format(
                text_content=truncated_text,
                goal=goal or "分析稳定性数据"
            )

            response = invoker.invoke(
                system_prompt="你是专业的药物稳定性数据提取助手。请从文本中提取结构化的稳定性数据。",
                user_prompt=prompt,
                temperature=0.1
            )

            # The invoker may return a plain string or an object with .content
            if response and hasattr(response, 'content'):
                content = response.content
            elif isinstance(response, str):
                content = response
            else:
                return None

            return self._parse_json_object(content)

        except Exception as e:
            # Best-effort boundary: any failure falls through to heuristics.
            print(f"LLM extraction failed: {e}")
            return None

    def _extract_with_heuristics(self, text_content: str, goal: str) -> Optional[Dict]:
        """
        Fallback heuristic extraction using pattern recognition.

        Returns:
            {"batches": [...]} with one batch per detected table, or None
            when no table yields a batch. ``goal`` is currently unused.
        """
        batches = []

        # Find all tables with time-series data
        tables = self._find_time_series_tables(text_content)

        for i, table in enumerate(tables):
            batch = self._create_batch_from_table(table, i, text_content)
            if batch:
                batches.append(batch)

        if batches:
            return {"batches": batches}

        return None

    def _find_time_series_tables(self, text: str) -> List[Dict]:
        """Find patterns that look like time-series data tables.

        A line with at least two time tokens is treated as a header. Up to 14
        following lines are scanned for rows holding at least as many numbers
        as the header has timepoints; scanning stops at the next header so
        that one table never absorbs the header or rows of the next.

        Returns:
            List of dicts with "times", "rows", "context" (up to 10 lines
            before the header plus the header itself) and "line_number".
        """
        tables = []
        lines = text.split('\n')

        for i, line in enumerate(lines):
            time_matches = self._TIME_HEADER_RE.findall(line)
            if len(time_matches) < 2:
                continue

            # Found potential time header
            times = [int(t) for t in time_matches]

            # Look for data rows below
            data_rows = []
            for candidate in lines[i + 1:i + 15]:
                if len(self._TIME_HEADER_RE.findall(candidate)) >= 2:
                    # Next table starts here; its rows belong to it.
                    break

                numbers = self._NUMBER_RE.findall(candidate)
                if len(numbers) < len(times):
                    continue

                try:
                    values = [float(n) for n in numbers[:len(times)]]
                except ValueError:
                    continue

                # Check if numbers are in plausible range for stability data
                if all(0 <= v <= 200 for v in values):
                    data_rows.append({
                        "values": values,
                        "type": self._identify_row_type(candidate),
                        "raw": candidate
                    })

            if data_rows:
                # Find context for batch/condition identification
                context_start = max(0, i - 10)
                context = '\n'.join(lines[context_start:i + 1])

                tables.append({
                    "times": times,
                    "rows": data_rows,
                    "context": context,
                    "line_number": i
                })

        return tables

    def _identify_row_type(self, line: str) -> str:
        """Identify what type of measurement a row represents.

        Returns one of 'impurity', 'assay', 'moisture', 'dissolution' or
        'unknown', based on keyword matches in the lower-cased line.
        """
        line_lower = line.lower()

        if any(kw in line_lower for kw in ['杂质', 'impurity', '杂']):
            return 'impurity'
        elif any(kw in line_lower for kw in ['含量', 'assay', 'content']):
            return 'assay'
        elif any(kw in line_lower for kw in ['水分', 'moisture', 'water']):
            return 'moisture'
        elif any(kw in line_lower for kw in ['溶出', 'dissolution']):
            return 'dissolution'

        return 'unknown'

    def _create_batch_from_table(self, table: Dict, index: int, full_text: str) -> Optional[Dict]:
        """Create a batch structure from extracted table data.

        Args:
            table: Entry produced by _find_time_series_tables
            index: Position of the table, used for the fallback batch name
            full_text: Complete source text, forwarded to _extract_batch_name

        Returns:
            Batch dict with a single condition, or None when the table has
            no rows.
        """
        context = table.get("context", "")

        # Try to identify batch name from context
        batch_name = self._extract_batch_name(context, full_text, index)

        # Try to identify condition
        condition_info = self._extract_condition(context)

        # Build CQA data; unrecognised row types keep the generic label
        cqa_list = [
            {
                "cqa_name": self._CQA_NAMES.get(row["type"], "质量指标"),
                "values": row["values"]
            }
            for row in table.get("rows", [])
        ]

        if not cqa_list:
            return None

        return {
            "batch_id": batch_name.replace(" ", "_"),
            "batch_name": batch_name,
            "batch_type": "target",
            "conditions": [{
                "condition_id": condition_info["id"],
                "condition_type": condition_info["type"],
                "timepoints": table["times"],
                "cqa_data": cqa_list
            }]
        }

    def _extract_batch_name(self, context: str, full_text: str, index: int) -> str:
        """Extract batch name from context using various patterns.

        Patterns are tried in priority order. Within one pattern the match
        closest to the end of the context wins, because the context window
        can still contain the label of the preceding table.

        Args:
            context: Lines preceding (and including) the table header
            full_text: Unused; kept for interface compatibility
            index: Zero-based table index for the fallback name

        Returns:
            The matched name (at least 3 characters) or "批次<index + 1>".
        """
        patterns = [
            r'批[次号][::\s]*([A-Za-z0-9\-_]+)',
            r'Batch[::\s]*([A-Za-z0-9\-_]+)',
            r'([A-Z]{2,3}[-_]\d{4,}[-_]?[A-Z0-9]*)',  # Common batch ID format
            # SF-xxxx format; a bare [SF] class would drop the leading "S"
            r'((?:SF|[SF])[-_]?\d{4}[-_]?\d+)',
            r'样品[::\s]*(.{3,20})',
        ]

        for pattern in patterns:
            candidates = [
                match.group(1).strip()
                for match in re.finditer(pattern, context, re.IGNORECASE)
            ]
            for name in reversed(candidates):
                if len(name) >= 3:
                    return name

        # Fallback: use numbered batch
        return f"批次{index + 1}"

    def _extract_condition(self, context: str) -> Dict[str, str]:
        """Extract storage condition from context.

        When several conditions are mentioned, the one whose keyword occurs
        last (closest to the table header) is chosen, so a table is not
        labelled with the condition of the table above it.

        Returns:
            Dict with "id" and "type"; "Unknown_Condition"/"unknown" when no
            keyword is present.
        """
        context_lower = context.lower()

        best_pos = -1
        best_info = None
        for keywords, info in self._CONDITION_KEYWORDS:
            pos = max(context_lower.rfind(kw) for kw in keywords)
            if pos > best_pos:
                best_pos = pos
                best_info = info

        if best_info is None:
            return {"id": "Unknown_Condition", "type": "unknown"}

        return dict(best_info)

    def _extract_target_timepoints(self, goal: str) -> List[int]:
        """Extract target prediction timepoints from goal text.

        Recognises forms such as "24个月", "36月", "24M", "24 months".

        Returns:
            Sorted unique timepoints, or [24, 36] when the goal is empty or
            names none.
        """
        timepoints = set()

        for pattern in self._TARGET_TIME_PATTERNS:
            timepoints.update(int(m) for m in pattern.findall(goal or ""))

        if not timepoints:
            return [24, 36]

        return sorted(timepoints)
354
+
355
+
356
+ # Convenience function for backward compatibility
357
def extract_stability_data(file_paths: List[str], goal: str) -> Dict[str, Any]:
    """
    Parse every file, concatenate the text and run the extractor on it.

    Args:
        file_paths: Paths handed one by one to ``parse_file``
        goal: Analysis goal forwarded to ``extract_from_text``

    Returns:
        The dictionary produced by ``StabilityDataExtractor.extract_from_text``.
        A file that raises while being parsed is reported on stdout and
        skipped; files with empty content contribute nothing.
    """
    from utils.file_parsers import parse_file

    extractor = StabilityDataExtractor()

    # One labelled section per successfully parsed, non-empty file.
    sections = []
    for file_path in file_paths:
        try:
            parsed = parse_file(file_path)
            if not parsed:
                continue
            sections.append(f"\n=== File: {file_path} ===\n{parsed}\n")
        except Exception as err:
            print(f"Error parsing {file_path}: {err}")

    return extractor.extract_from_text("".join(sections), goal)