improvements

- .gitattributes +0 -35
- README.md +49 -5
- configuration/parameters.py +19 -8
- content_analyzer/document_parser.py +18 -50
- content_analyzer/visual_detector.py +45 -28
- core/diagnostics.py +0 -125
- core/lifecycle.py +0 -160
- intelligence/accuracy_verifier.py +28 -10
- intelligence/orchestrator.py +63 -130
- main.py +107 -192
- packages.txt +1 -0
- requirements.txt +1 -1
- search_engine/indexer.py +3 -3
- tests/conftest.py +0 -71
- tests/test_accuracy_verifier.py +0 -110
- tests/test_context_validator.py +0 -120
- tests/test_knowledge_synthesizer.py +0 -50
- tests/test_visual_extraction.py +0 -169
- vector_store/33eccd62-a7fc-4b0d-a118-02552f5cad42/data_level0.bin +0 -3
- vector_store/33eccd62-a7fc-4b0d-a118-02552f5cad42/header.bin +0 -3
- vector_store/33eccd62-a7fc-4b0d-a118-02552f5cad42/index_metadata.pickle +0 -3
- vector_store/33eccd62-a7fc-4b0d-a118-02552f5cad42/length.bin +0 -3
- vector_store/33eccd62-a7fc-4b0d-a118-02552f5cad42/link_lists.bin +0 -3
.gitattributes
DELETED
```diff
@@ -1,35 +0,0 @@
-*.7z filter=lfs diff=lfs merge=lfs -text
-*.arrow filter=lfs diff=lfs merge=lfs -text
-*.bin filter=lfs diff=lfs merge=lfs -text
-*.bz2 filter=lfs diff=lfs merge=lfs -text
-*.ckpt filter=lfs diff=lfs merge=lfs -text
-*.ftz filter=lfs diff=lfs merge=lfs -text
-*.gz filter=lfs diff=lfs merge=lfs -text
-*.h5 filter=lfs diff=lfs merge=lfs -text
-*.joblib filter=lfs diff=lfs merge=lfs -text
-*.lfs.* filter=lfs diff=lfs merge=lfs -text
-*.mlmodel filter=lfs diff=lfs merge=lfs -text
-*.model filter=lfs diff=lfs merge=lfs -text
-*.msgpack filter=lfs diff=lfs merge=lfs -text
-*.npy filter=lfs diff=lfs merge=lfs -text
-*.npz filter=lfs diff=lfs merge=lfs -text
-*.onnx filter=lfs diff=lfs merge=lfs -text
-*.ot filter=lfs diff=lfs merge=lfs -text
-*.parquet filter=lfs diff=lfs merge=lfs -text
-*.pb filter=lfs diff=lfs merge=lfs -text
-*.pickle filter=lfs diff=lfs merge=lfs -text
-*.pkl filter=lfs diff=lfs merge=lfs -text
-*.pt filter=lfs diff=lfs merge=lfs -text
-*.pth filter=lfs diff=lfs merge=lfs -text
-*.rar filter=lfs diff=lfs merge=lfs -text
-*.safetensors filter=lfs diff=lfs merge=lfs -text
-saved_model/**/* filter=lfs diff=lfs merge=lfs -text
-*.tar.* filter=lfs diff=lfs merge=lfs -text
-*.tar filter=lfs diff=lfs merge=lfs -text
-*.tflite filter=lfs diff=lfs merge=lfs -text
-*.tgz filter=lfs diff=lfs merge=lfs -text
-*.wasm filter=lfs diff=lfs merge=lfs -text
-*.xz filter=lfs diff=lfs merge=lfs -text
-*.zip filter=lfs diff=lfs merge=lfs -text
-*.zst filter=lfs diff=lfs merge=lfs -text
-*tfevents* filter=lfs diff=lfs merge=lfs -text
```
README.md
CHANGED
|
@@ -1,6 +1,50 @@
|
|
| 1 |
# SmartDoc AI
|
| 2 |
|
| 3 |
-
SmartDoc AI is an advanced document analysis and question answering system
|
| 4 |
|
| 5 |
---
|
| 6 |
|
|
@@ -27,8 +71,8 @@ SmartDoc AI is an advanced document analysis and question answering system. It a
|
|
| 27 |
|
| 28 |
1. Clone the repository:
|
| 29 |
```bash
|
| 30 |
-
git clone https://github.com/TilanTAB/Intelligent-Document-Analysis-
|
| 31 |
-
cd Intelligent-Document-Analysis-
|
| 32 |
```
|
| 33 |
|
| 34 |
2. Activate the virtual environment:
|
|
@@ -43,7 +87,7 @@ activate_venv.bat
|
|
| 43 |
|
| 44 |
3. Install dependencies (if needed):
|
| 45 |
```bash
|
| 46 |
-
pip install -r
|
| 47 |
```
|
| 48 |
|
| 49 |
4. Configure environment variables:
|
|
@@ -114,4 +158,4 @@ This project is licensed under the MIT License.
|
|
| 114 |
|
| 115 |
---
|
| 116 |
|
| 117 |
-
SmartDoc AI is actively maintained and designed for real-world document analysis and Q&A. For updates and support, visit the [GitHub repository](https://github.com/TilanTAB/Intelligent-Document-Analysis-
|
|
|
|
| 1 |
# SmartDoc AI
|
| 2 |
|
| 3 |
+
SmartDoc AI is an advanced document analysis and question answering system, designed for source-grounded Q&A over complex business and scientific reports, especially where key evidence lives in tables and charts.
|
| 4 |
+
|
| 5 |
+
---
|
| 6 |
+
|
| 7 |
+
## Personal Research Update
|
| 8 |
+
|
| 9 |
+
**SmartDoc AI: Document Q&A + Selective Chart Understanding**
|
| 10 |
+
|
| 11 |
+
I've been developing SmartDoc AI as a technical experiment to improve question answering over complex business/scientific reports, especially where key evidence lives in tables and charts.
|
| 12 |
+
|
| 13 |
+
### Technical highlights:
|
| 14 |
+
|
| 15 |
+
- **Multi-format ingestion:** PDF, DOCX, TXT, Markdown
|
| 16 |
+
- **LLM-assisted query decomposition:** breaks complex prompts into clearer sub-questions for retrieval + answering
|
| 17 |
+
- **Selective chart pipeline (cost-aware):**
|
| 18 |
+
- Local OpenCV heuristics flag pages that likely contain charts
|
| 19 |
+
- Gemini Vision is invoked only for chart pages to generate structured chart analysis (reduces unnecessary vision calls)
|
| 20 |
+
- **Table extraction + robust PDF parsing:** pdfplumber strategies for bordered and borderless tables
|
| 21 |
+
- **Parallelized processing:** concurrent PDF parsing + chart detection; batch chart analysis where enabled
|
| 22 |
+
- **Hybrid retrieval:** BM25 + vector search combined via an ensemble retriever
|
| 23 |
+
- **Multi-agent answering:** answer drafting + verification pass, with retrieved context available for inspection (page/source metadata)
|
| 24 |
+
|
| 25 |
+
**Runtime note:** Large PDFs (many pages/charts) can take minutes depending on DPI, chart volume, and available memory/CPU (HF Spaces limits can be a factor).
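The selective chart pipeline above is essentially a cost gate: a cheap local detector runs on every page, and the paid vision model is only invoked when that detector's confidence clears a threshold. A minimal illustrative sketch, using hypothetical stand-in callables (`local_detect`, `vision_analyze`) rather than the repository's actual OpenCV/Gemini functions:

```python
# Cost-aware gating sketch: run a cheap local detector first and only
# invoke the expensive vision model when confidence clears a threshold.
CHART_MIN_CONFIDENCE = 0.4  # mirrors the repo's CHART_MIN_CONFIDENCE setting

def analyze_page(page, local_detect, vision_analyze, threshold=CHART_MIN_CONFIDENCE):
    detection = local_detect(page)          # free local heuristic
    if not detection["has_chart"] or detection["confidence"] < threshold:
        return None                         # skip the paid vision call entirely
    return vision_analyze(page)             # only likely chart pages reach the model

# Usage with stub callables standing in for OpenCV / Gemini Vision:
calls = []
detector = lambda p: {"has_chart": p["charty"], "confidence": p["conf"]}
vision = lambda p: calls.append(p) or {"analysis": "structured chart summary"}

pages = [{"charty": False, "conf": 0.9},   # no chart -> skipped
         {"charty": True, "conf": 0.2},    # low confidence -> skipped
         {"charty": True, "conf": 0.8}]    # confident chart -> analyzed
results = [analyze_page(p, detector, vision) for p in pages]
# Only the third page triggers the vision call.
```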
|
| 26 |
+
|
| 27 |
+
---
|
| 28 |
+
|
| 29 |
+
## Demo Videos
|
| 30 |
+
|
| 31 |
+
- [SmartDoc AI technical demo #1](https://youtu.be/uVU_sLiJU4w)
|
| 32 |
+
- [SmartDoc AI technical demo #2](https://youtu.be/c8CF7-OaKmQ)
|
| 33 |
+
- [SmartDoc AI technical demo #3](https://youtu.be/P17SZSQJ6Wc)
|
| 34 |
+
|
| 35 |
+
---
|
| 36 |
+
|
| 37 |
+
## Repository
|
| 38 |
+
https://github.com/TilanTAB/Intelligent-Document-Analysis-SmartDoc-AI
|
| 39 |
+
|
| 40 |
+
---
|
| 41 |
+
|
| 42 |
+
## Use Cases
|
| 43 |
+
|
| 44 |
+
- Source-grounded Q&A for business/research documents
|
| 45 |
+
- Automated extraction and summarization from tables/charts
|
| 46 |
+
|
| 47 |
+
If you're interested in architecture tradeoffs (cost, latency, memory limits, retrieval quality), feel free to connect.
|
| 48 |
|
| 49 |
---
|
| 50 |
|
|
|
|
| 71 |
|
| 72 |
1. Clone the repository:
|
| 73 |
```bash
|
| 74 |
+
git clone https://github.com/TilanTAB/Intelligent-Document-Analysis-SmartDoc-AI.git
|
| 75 |
+
cd Intelligent-Document-Analysis-SmartDoc-AI
|
| 76 |
```
|
| 77 |
|
| 78 |
2. Activate the virtual environment:
|
|
|
|
| 87 |
|
| 88 |
3. Install dependencies (if needed):
|
| 89 |
```bash
|
| 90 |
+
pip install -r requirements.txt
|
| 91 |
```
|
| 92 |
|
| 93 |
4. Configure environment variables:
|
|
|
|
| 158 |
|
| 159 |
---
|
| 160 |
|
| 161 |
+
SmartDoc AI is actively maintained and designed for real-world document analysis and Q&A. For updates and support, visit the [GitHub repository](https://github.com/TilanTAB/Intelligent-Document-Analysis-SmartDoc-AI).
|
configuration/parameters.py
CHANGED
|
@@ -5,6 +5,16 @@ import os
|
|
| 5 |
from .definitions import MAX_FILE_SIZE, MAX_TOTAL_SIZE, ALLOWED_TYPES
|
| 6 |
|
| 7 |
|
| 8 |
class Settings(BaseSettings):
|
| 9 |
"""
|
| 10 |
Application parameters loaded from environment variables.
|
|
@@ -35,7 +45,7 @@ class Settings(BaseSettings):
|
|
| 35 |
)
|
| 36 |
|
| 37 |
# Database parameters
|
| 38 |
-
CHROMA_DB_PATH: str =
|
| 39 |
|
| 40 |
# Chunking parameters
|
| 41 |
CHUNK_SIZE: int = 2000
|
|
@@ -51,18 +61,19 @@ class Settings(BaseSettings):
|
|
| 51 |
CHROMA_COLLECTION_NAME: str = "documents"
|
| 52 |
|
| 53 |
# Workflow parameters
|
| 54 |
-
MAX_RESEARCH_ATTEMPTS: int =
|
| 55 |
ENABLE_QUERY_REWRITING: bool = True
|
| 56 |
MAX_QUERY_REWRITES: int = 1
|
| 57 |
RELEVANCE_CHECK_K: int = 20
|
| 58 |
|
| 59 |
# Research agent parameters
|
| 60 |
RESEARCH_TOP_K: int = 15
|
| 61 |
-
RESEARCH_MAX_CONTEXT_CHARS: int = 8000000000
|
| 62 |
RESEARCH_MAX_OUTPUT_TOKENS: int = 500
|
|
|
|
| 63 |
|
| 64 |
# Verification parameters
|
| 65 |
-
VERIFICATION_MAX_CONTEXT_CHARS: int = 800000000
|
| 66 |
VERIFICATION_MAX_OUTPUT_TOKENS: int = 300
|
| 67 |
|
| 68 |
# Logging parameters
|
|
@@ -86,12 +97,12 @@ class Settings(BaseSettings):
|
|
| 86 |
ENABLE_CHART_EXTRACTION: bool = True
|
| 87 |
CHART_VISION_MODEL: str = "gemini-2.5-flash-lite"
|
| 88 |
CHART_MAX_TOKENS: int = 1500
|
| 89 |
-
CHART_DPI: int =
|
| 90 |
-
CHART_BATCH_SIZE: int =
|
| 91 |
-
CHART_MAX_IMAGE_SIZE: int =
|
| 92 |
|
| 93 |
# Local chart detection parameters (cost optimization)
|
| 94 |
-
CHART_USE_LOCAL_DETECTION: bool = True # Use OpenCV first (FREE)
|
| 95 |
CHART_MIN_CONFIDENCE: float = 0.4 # Only analyze charts with confidence > 40%
|
| 96 |
CHART_SKIP_GEMINI_DETECTION: bool = True # Skip Gemini for detection, only use for analysis
|
| 97 |
CHART_GEMINI_FALLBACK_ENABLED: bool = False # Optional: Use Gemini if local fails
|
|
|
|
| 5 |
from .definitions import MAX_FILE_SIZE, MAX_TOTAL_SIZE, ALLOWED_TYPES
|
| 6 |
|
| 7 |
|
| 8 |
+
def _default_chroma_path() -> str:
|
| 9 |
+
if os.environ.get("SPACE_ID"): # Hugging Face Spaces
|
| 10 |
+
return os.environ.get("CHROMA_DB_PATH", "/tmp/chroma_db")
|
| 11 |
+
return os.environ.get("CHROMA_DB_PATH", "./chroma_db")
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def _is_hf() -> bool:
|
| 15 |
+
return os.environ.get("SPACE_ID") is not None
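The helpers above key everything off the `SPACE_ID` variable that Hugging Face Spaces injects into the container. The same resolution logic as a standalone sketch (the names mirror the diff, but the function takes an explicit `env` dict for clarity; it is not an import from the repo):

```python
def default_chroma_path(env: dict) -> str:
    # On HF Spaces (SPACE_ID present) fall back to the writable /tmp;
    # locally fall back to a relative directory. An explicit
    # CHROMA_DB_PATH always wins in either environment.
    if env.get("SPACE_ID"):
        return env.get("CHROMA_DB_PATH", "/tmp/chroma_db")
    return env.get("CHROMA_DB_PATH", "./chroma_db")

print(default_chroma_path({}))                           # ./chroma_db
print(default_chroma_path({"SPACE_ID": "user/space"}))   # /tmp/chroma_db
print(default_chroma_path({"SPACE_ID": "user/space",
                           "CHROMA_DB_PATH": "/data/db"}))  # /data/db
```

Using `Field(default_factory=...)` (rather than a module-level constant) means the environment is consulted each time a `Settings` instance is created, which matters in tests that patch `os.environ`.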
|
| 16 |
+
|
| 17 |
+
|
| 18 |
class Settings(BaseSettings):
|
| 19 |
"""
|
| 20 |
Application parameters loaded from environment variables.
|
|
|
|
| 45 |
)
|
| 46 |
|
| 47 |
# Database parameters
|
| 48 |
+
CHROMA_DB_PATH: str = Field(default_factory=_default_chroma_path)
|
| 49 |
|
| 50 |
# Chunking parameters
|
| 51 |
CHUNK_SIZE: int = 2000
|
|
|
|
| 61 |
CHROMA_COLLECTION_NAME: str = "documents"
|
| 62 |
|
| 63 |
# Workflow parameters
|
| 64 |
+
MAX_RESEARCH_ATTEMPTS: int = 5
|
| 65 |
ENABLE_QUERY_REWRITING: bool = True
|
| 66 |
MAX_QUERY_REWRITES: int = 1
|
| 67 |
RELEVANCE_CHECK_K: int = 20
|
| 68 |
|
| 69 |
# Research agent parameters
|
| 70 |
RESEARCH_TOP_K: int = 15
|
| 71 |
+
RESEARCH_MAX_CONTEXT_CHARS: int = Field(default_factory=lambda: 800_000 if _is_hf() else 8000000000)
|
| 72 |
RESEARCH_MAX_OUTPUT_TOKENS: int = 500
|
| 73 |
+
NUM_RESEARCH_CANDIDATES: int = 2 # Number of research questions to generate
|
| 74 |
|
| 75 |
# Verification parameters
|
| 76 |
+
VERIFICATION_MAX_CONTEXT_CHARS: int = Field(default_factory=lambda: 300_000 if _is_hf() else 800000000)
|
| 77 |
VERIFICATION_MAX_OUTPUT_TOKENS: int = 300
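The `*_MAX_CONTEXT_CHARS` settings above cap how much retrieved text reaches the model, with much tighter budgets on Spaces. One common way such a budget is enforced is to accumulate whole chunks until the character limit would be exceeded; this is an illustrative sketch of that pattern, not the repository's exact implementation:

```python
def cap_context(chunks, max_chars):
    # Keep whole chunks (plus "\n\n" separators) until the budget is spent,
    # so no chunk is cut mid-sentence.
    kept, used = [], 0
    for chunk in chunks:
        cost = len(chunk) + (2 if kept else 0)  # 2 chars for the separator
        if used + cost > max_chars:
            break
        kept.append(chunk)
        used += cost
    return "\n\n".join(kept)

context = cap_context(["aaaa", "bbbb", "cccc"], max_chars=10)
# Budget of 10 fits "aaaa" (4) plus separator + "bbbb" (6); "cccc" is dropped.
```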
|
| 78 |
|
| 79 |
# Logging parameters
|
|
|
|
| 97 |
ENABLE_CHART_EXTRACTION: bool = True
|
| 98 |
CHART_VISION_MODEL: str = "gemini-2.5-flash-lite"
|
| 99 |
CHART_MAX_TOKENS: int = 1500
|
| 100 |
+
CHART_DPI: int = Field(default_factory=lambda: 110 if _is_hf() else 110) # Lower DPI saves memory
|
| 101 |
+
CHART_BATCH_SIZE: int = Field(default_factory=lambda: 1 if _is_hf() else 1) # Process pages in batches
|
| 102 |
+
CHART_MAX_IMAGE_SIZE: int = Field(default_factory=lambda: 1200 if _is_hf() else 1200) # Max dimension for images
|
| 103 |
|
| 104 |
# Local chart detection parameters (cost optimization)
|
| 105 |
+
CHART_USE_LOCAL_DETECTION: bool = Field(default_factory=lambda: True if _is_hf() else True) # Use OpenCV first (FREE)
|
| 106 |
CHART_MIN_CONFIDENCE: float = 0.4 # Only analyze charts with confidence > 40%
|
| 107 |
CHART_SKIP_GEMINI_DETECTION: bool = True # Skip Gemini for detection, only use for analysis
|
| 108 |
CHART_GEMINI_FALLBACK_ENABLED: bool = False # Optional: Use Gemini if local fails
|
content_analyzer/document_parser.py
CHANGED
|
@@ -33,7 +33,6 @@ def detect_chart_on_page(args):
|
|
| 33 |
# Downscale image before detection to save memory
|
| 34 |
image = preprocess_image(image, max_dim=1000)
|
| 35 |
detection_result = LocalChartDetector.detect_charts(image)
|
| 36 |
-
# Do NOT delete image here; it will be saved in the main process
|
| 37 |
return (page_num, image, detection_result)
|
| 38 |
|
| 39 |
def analyze_batch(batch_tuple):
|
|
@@ -276,7 +275,7 @@ class DocumentProcessor:
|
|
| 276 |
except Exception as e:
|
| 277 |
logger.error(f"Failed to save cache to {cache_path.name}: {e}", exc_info=True)
|
| 278 |
|
| 279 |
-
def _process_file(self, file
|
| 280 |
file_ext = Path(file.name).suffix.lower()
|
| 281 |
if file_ext not in ALLOWED_TYPES:
|
| 282 |
logger.warning(f"Skipping unsupported file type: {file.name}")
|
|
@@ -341,26 +340,27 @@ class DocumentProcessor:
|
|
| 341 |
return []
|
| 342 |
all_chunks = []
|
| 343 |
total_docs = len(documents)
|
| 344 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 345 |
for i, doc in enumerate(documents):
|
| 346 |
page_chunks = self.splitter.split_text(doc.page_content)
|
| 347 |
total_chunks = len(page_chunks)
|
| 348 |
for j, chunk in enumerate(page_chunks):
|
| 349 |
-
chunk_id = f"{file_hash}_{doc.metadata.get('page', i + 1)}_{j}"
|
| 350 |
chunk_doc = Document(
|
| 351 |
page_content=chunk,
|
| 352 |
metadata={
|
| 353 |
-
"source":
|
| 354 |
"page": doc.metadata.get("page", i + 1),
|
| 355 |
"type": doc.metadata.get("type", "text"),
|
| 356 |
"chunk_id": chunk_id
|
| 357 |
}
|
| 358 |
)
|
| 359 |
all_chunks.append(chunk_doc)
|
| 360 |
-
|
| 361 |
-
percent = int(100 * ((i + (j + 1) / total_chunks) / total_docs))
|
| 362 |
-
step = f"Splitting page {i+1} into chunks"
|
| 363 |
-
progress_callback(percent, step)
|
| 364 |
logger.info(f"Processed {file.name}: {len(documents)} page(s) → {len(all_chunks)} chunk(s)")
|
| 365 |
return all_chunks
|
| 366 |
except ImportError as e:
|
|
@@ -376,7 +376,9 @@ class DocumentProcessor:
|
|
| 376 |
PHASE 1: Parallel local chart detection (CPU-bound, uses ProcessPoolExecutor)
|
| 377 |
PHASE 2: Parallel Gemini batch analysis (I/O-bound, uses ThreadPoolExecutor)
|
| 378 |
"""
|
| 379 |
-
|
|
|
|
|
|
|
| 380 |
def deduplicate_charts_by_title(chart_chunks):
|
| 381 |
seen_titles = set()
|
| 382 |
unique_chunks = []
|
|
@@ -632,7 +634,9 @@ class DocumentProcessor:
|
|
| 632 |
import pdfplumber
|
| 633 |
|
| 634 |
logger.info(f"[PDFPLUMBER] Processing: {file_path}")
|
| 635 |
-
|
|
|
|
|
|
|
| 636 |
|
| 637 |
# Strategy 1: Line-based (default) - for tables with visible borders
|
| 638 |
default_parameters = {}
|
|
@@ -733,11 +737,11 @@ class DocumentProcessor:
|
|
| 733 |
|
| 734 |
if len(page_content) > 1:
|
| 735 |
combined = "\n\n".join(page_content)
|
| 736 |
-
chunk_id = f"{file_hash}_{page_num}_0"
|
| 737 |
doc = Document(
|
| 738 |
page_content=combined,
|
| 739 |
metadata={
|
| 740 |
-
"source":
|
| 741 |
"page": page_num,
|
| 742 |
"loader": "pdfplumber",
|
| 743 |
"tables_count": total_tables,
|
|
@@ -789,43 +793,7 @@ class DocumentProcessor:
|
|
| 789 |
for row in cleaned_table[1:]:
|
| 790 |
md_lines.append("| " + " | ".join(row) + " |")
|
| 791 |
|
| 792 |
-
return "\n".join(md_lines)
|
| 793 |
-
|
| 794 |
-
def process(self, files: List, progress_callback=None) -> List[Document]:
|
| 795 |
-
"""
|
| 796 |
-
Process multiple files with caching and deduplication.
|
| 797 |
-
"""
|
| 798 |
-
self.validate_files(files)
|
| 799 |
-
all_chunks = []
|
| 800 |
-
seen_hashes = set()
|
| 801 |
-
logger.info(f"Processing {len(files)} file(s)...")
|
| 802 |
-
for file in files:
|
| 803 |
-
try:
|
| 804 |
-
with open(file.name, 'rb') as f:
|
| 805 |
-
file_content = f.read()
|
| 806 |
-
file_hash = self._generate_hash(file_content)
|
| 807 |
-
cache_path = self.cache_dir / f"{file_hash}.pkl"
|
| 808 |
-
if self._is_cache_valid(cache_path):
|
| 809 |
-
chunks = self._load_from_cache(cache_path)
|
| 810 |
-
if chunks:
|
| 811 |
-
logger.info(f"Using cached chunks for {file.name}")
|
| 812 |
-
else:
|
| 813 |
-
chunks = self._process_file(file, progress_callback=progress_callback)
|
| 814 |
-
self._save_to_cache(chunks, cache_path)
|
| 815 |
-
else:
|
| 816 |
-
logger.info(f"Processing and caching: {file.name}")
|
| 817 |
-
chunks = self._process_file(file, progress_callback=progress_callback)
|
| 818 |
-
self._save_to_cache(chunks, cache_path)
|
| 819 |
-
for chunk in chunks:
|
| 820 |
-
chunk_hash = self._generate_hash(chunk.page_content.encode())
|
| 821 |
-
if chunk_hash not in seen_hashes:
|
| 822 |
-
seen_hashes.add(chunk_hash)
|
| 823 |
-
all_chunks.append(chunk)
|
| 824 |
-
except Exception as e:
|
| 825 |
-
logger.error(f"Failed to process {file.name}: {e}", exc_info=True)
|
| 826 |
-
continue
|
| 827 |
-
logger.info(f"Processing complete: {len(all_chunks)} unique chunks from {len(files)} file(s)")
|
| 828 |
-
return all_chunks
|
| 829 |
|
| 830 |
def run_pdfplumber(file_name):
|
| 831 |
from content_analyzer.document_parser import DocumentProcessor
|
|
|
|
| 33 |
# Downscale image before detection to save memory
|
| 34 |
image = preprocess_image(image, max_dim=1000)
|
| 35 |
detection_result = LocalChartDetector.detect_charts(image)
|
|
|
|
| 36 |
return (page_num, image, detection_result)
|
| 37 |
|
| 38 |
def analyze_batch(batch_tuple):
|
|
|
|
| 275 |
except Exception as e:
|
| 276 |
logger.error(f"Failed to save cache to {cache_path.name}: {e}", exc_info=True)
|
| 277 |
|
| 278 |
+
def _process_file(self, file) -> List[Document]:
|
| 279 |
file_ext = Path(file.name).suffix.lower()
|
| 280 |
if file_ext not in ALLOWED_TYPES:
|
| 281 |
logger.warning(f"Skipping unsupported file type: {file.name}")
|
|
|
|
| 340 |
return []
|
| 341 |
all_chunks = []
|
| 342 |
total_docs = len(documents)
|
| 343 |
+
# --- STABLE FILE HASHING ---
|
| 344 |
+
with open(file.name, 'rb') as f:
|
| 345 |
+
file_bytes = f.read()
|
| 346 |
+
file_hash = self._generate_hash(file_bytes) # Stable hash by file content
|
| 347 |
+
stable_source = f"{Path(file.name).name}::{file_hash}"
|
| 348 |
for i, doc in enumerate(documents):
|
| 349 |
page_chunks = self.splitter.split_text(doc.page_content)
|
| 350 |
total_chunks = len(page_chunks)
|
| 351 |
for j, chunk in enumerate(page_chunks):
|
| 352 |
+
chunk_id = f"txt_{file_hash}_{doc.metadata.get('page', i + 1)}_{j}"
|
| 353 |
chunk_doc = Document(
|
| 354 |
page_content=chunk,
|
| 355 |
metadata={
|
| 356 |
+
"source": stable_source,
|
| 357 |
"page": doc.metadata.get("page", i + 1),
|
| 358 |
"type": doc.metadata.get("type", "text"),
|
| 359 |
"chunk_id": chunk_id
|
| 360 |
}
|
| 361 |
)
|
| 362 |
all_chunks.append(chunk_doc)
|
| 363 |
+
|
| 364 |
logger.info(f"Processed {file.name}: {len(documents)} page(s) → {len(all_chunks)} chunk(s)")
|
| 365 |
return all_chunks
|
| 366 |
except ImportError as e:
|
|
|
|
| 376 |
PHASE 1: Parallel local chart detection (CPU-bound, uses ProcessPoolExecutor)
|
| 377 |
PHASE 2: Parallel Gemini batch analysis (I/O-bound, uses ThreadPoolExecutor)
|
| 378 |
"""
|
| 379 |
+
file_bytes = Path(file_path).read_bytes()
|
| 380 |
+
file_hash = self._generate_hash(file_bytes)
|
| 381 |
+
stable_source = f"{Path(file_path).name}::{file_hash}"
|
| 382 |
def deduplicate_charts_by_title(chart_chunks):
|
| 383 |
seen_titles = set()
|
| 384 |
unique_chunks = []
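The `deduplicate_charts_by_title` helper appears only partially in this hunk (just the `seen_titles` / `unique_chunks` setup). A plausible completion of that seen-set pattern, keeping the first chunk per title and preserving order; the `title` key is an assumption for illustration, since the hunk does not show how titles are read from a chunk:

```python
def deduplicate_charts_by_title(chart_chunks):
    # Keep the first chunk seen for each chart title, preserving order.
    seen_titles = set()
    unique_chunks = []
    for chunk in chart_chunks:
        title = chunk.get("title")  # assumed key; the repo may store it differently
        if title in seen_titles:
            continue
        seen_titles.add(title)
        unique_chunks.append(chunk)
    return unique_chunks

charts = [{"title": "Revenue"}, {"title": "Costs"}, {"title": "Revenue"}]
deduped = deduplicate_charts_by_title(charts)
# Two chunks survive: "Revenue" then "Costs".
```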
|
|
|
|
| 634 |
import pdfplumber
|
| 635 |
|
| 636 |
logger.info(f"[PDFPLUMBER] Processing: {file_path}")
|
| 637 |
+
file_bytes = Path(file_path).read_bytes()
|
| 638 |
+
file_hash = self._generate_hash(file_bytes)
|
| 639 |
+
stable_source = f"{Path(file_path).name}::{file_hash}"
|
| 640 |
|
| 641 |
# Strategy 1: Line-based (default) - for tables with visible borders
|
| 642 |
default_parameters = {}
|
|
|
|
| 737 |
|
| 738 |
if len(page_content) > 1:
|
| 739 |
combined = "\n\n".join(page_content)
|
| 740 |
+
chunk_id = f"txt_{file_hash}_{page_num}_0"
|
| 741 |
doc = Document(
|
| 742 |
page_content=combined,
|
| 743 |
metadata={
|
| 744 |
+
"source": stable_source,
|
| 745 |
"page": page_num,
|
| 746 |
"loader": "pdfplumber",
|
| 747 |
"tables_count": total_tables,
|
|
|
|
| 793 |
for row in cleaned_table[1:]:
|
| 794 |
md_lines.append("| " + " | ".join(row) + " |")
|
| 795 |
|
| 796 |
+
return "\n".join(md_lines)
|
| 797 |
|
| 798 |
def run_pdfplumber(file_name):
|
| 799 |
from content_analyzer.document_parser import DocumentProcessor
|
content_analyzer/visual_detector.py
CHANGED
|
@@ -52,8 +52,22 @@ class LocalChartDetector:
|
|
| 52 |
else:
|
| 53 |
image_cv = image
|
| 54 |
height, width = image_cv.shape[:2]
|
|
| 55 |
gray = cv2.cvtColor(image_cv, cv2.COLOR_BGR2GRAY)
|
| 56 |
|
| 57 |
# --- Edge Detection ---
|
| 58 |
edges = cv2.Canny(gray, 50, 150)
|
| 59 |
|
|
@@ -74,23 +88,24 @@ class LocalChartDetector:
|
|
| 74 |
edges,
|
| 75 |
rho=1,
|
| 76 |
theta=np.pi/180,
|
| 77 |
-
threshold=
|
| 78 |
minLineLength=100,
|
| 79 |
maxLineGap=10
|
| 80 |
)
|
| 81 |
line_count = len(lines) if lines is not None else 0
|
| 82 |
-
|
| 83 |
-
|
| 84 |
if lines is not None:
|
| 85 |
for line in lines:
|
| 86 |
x1, y1, x2, y2 = line[0]
|
| 87 |
angle = np.abs(np.arctan2(y2 - y1, x2 - x1) * 180 / np.pi)
|
| 88 |
if 10 < angle < 80 or 100 < angle < 170:
|
| 89 |
-
|
| 90 |
-
|
|
|
|
|
|
|
| 91 |
|
| 92 |
# --- Circle Detection (Optimized) ---
|
| 93 |
-
run_circles = diagonal_lines >= 1 or line_count >= 6 or overall_edge_density > 0.08
|
| 94 |
circle_count = 0
|
| 95 |
circles = None
|
| 96 |
if run_circles:
|
|
@@ -110,12 +125,13 @@ class LocalChartDetector:
|
|
| 110 |
circle_count = circles.shape[2]
|
| 111 |
|
| 112 |
# --- Color Diversity Analysis ---
|
| 113 |
-
|
|
|
|
| 114 |
hist = cv2.calcHist([hsv], [0], None, [180], [0, 180])
|
| 115 |
-
color_peaks = np.sum(hist > np.mean(hist) * 2)
|
| 116 |
|
| 117 |
# --- Contour Detection ---
|
| 118 |
-
contours, _ = cv2.findContours(edges, cv2.
|
| 119 |
significant_contours = 0
|
| 120 |
rectangle_contours = 0
|
| 121 |
similar_rectangles = []
|
|
@@ -148,16 +164,16 @@ class LocalChartDetector:
|
|
| 148 |
if (width_std < avg_width * 0.3 or height_std < avg_height * 0.3):
|
| 149 |
bar_pattern = True
|
| 150 |
|
| 151 |
-
# --- Line Classification ---
|
| 152 |
horizontal_lines = 0
|
| 153 |
vertical_lines = 0
|
| 154 |
-
|
| 155 |
line_angles = []
|
| 156 |
very_short_lines = 0
|
| 157 |
if lines is not None:
|
| 158 |
for line in lines:
|
| 159 |
x1, y1, x2, y2 = line[0]
|
| 160 |
-
length = np.
|
| 161 |
if length < 50:
|
| 162 |
very_short_lines += 1
|
| 163 |
continue
|
|
@@ -170,11 +186,11 @@ class LocalChartDetector:
|
|
| 170 |
elif 80 < angle < 100:
|
| 171 |
vertical_lines += 1
|
| 172 |
else:
|
| 173 |
-
|
| 174 |
-
angle_variance = np.var(line_angles) if len(line_angles) > 2 else 0
|
| 175 |
|
| 176 |
# --- Debug Logging ---
|
| 177 |
-
logger.debug(f"Chart detection features: lines={line_count},
|
| 178 |
|
| 179 |
# --- Chart Heuristics and Classification ---
|
| 180 |
chart_types = []
|
|
@@ -183,16 +199,16 @@ class LocalChartDetector:
|
|
| 183 |
rejection_reason = ""
|
| 184 |
|
| 185 |
# Negative checks (text slides, decorative backgrounds, tables)
|
| 186 |
-
if has_text_region and circle_count < 2 and
|
| 187 |
if small_scattered_contours > 100 or very_short_lines > 50:
|
| 188 |
rejection_reason = f"Text slide with decorative background (overall density: {overall_edge_density:.2%})"
|
| 189 |
logger.debug(f"Rejected: {rejection_reason}")
|
| 190 |
return _chart_result(False, 0.0, [], rejection_reason, line_count, circle_count, overall_edge_density)
|
| 191 |
-
if very_short_lines > 50 and circle_count < 2 and
|
| 192 |
rejection_reason = f"Decorative network background ({very_short_lines} tiny lines, no data elements)"
|
| 193 |
logger.debug(f"Rejected: {rejection_reason}")
|
| 194 |
return _chart_result(False, 0.0, [], rejection_reason, line_count, circle_count, overall_edge_density)
|
| 195 |
-
if horizontal_lines > 12 and vertical_lines > 12 and circle_count == 0 and
|
| 196 |
grid_lines = horizontal_lines + vertical_lines
|
| 197 |
total_lines = line_count
|
| 198 |
grid_ratio = grid_lines / max(total_lines, 1)
|
|
@@ -204,15 +220,15 @@ class LocalChartDetector:
|
|
| 204 |
# Positive chart heuristics (bubble, scatter, line, pie, bar, complex)
|
| 205 |
# RELAXED: Detect as line chart if 2+ diagonal lines and angle variance > 40, or 1+ diagonal line and 1+ axis
|
| 206 |
if (
|
| 207 |
-
(
|
| 208 |
-
(
|
| 209 |
):
|
| 210 |
chart_types.append("line_chart")
|
| 211 |
-
confidence = max(confidence, min(0.88, 0.6 + (
|
| 212 |
if (horizontal_lines >= 1 or vertical_lines >= 1):
|
| 213 |
confidence = min(0.95, confidence + 0.08)
|
| 214 |
if not description:
|
| 215 |
-
description = f"Line chart: {
|
| 216 |
if circle_count >= 5:
|
| 217 |
chart_types.append("bubble_chart")
|
| 218 |
confidence = min(0.92, 0.70 + (min(circle_count, 20) * 0.01))
|
|
@@ -224,7 +240,7 @@ class LocalChartDetector:
|
|
| 224 |
confidence = min(0.97, confidence + 0.05)
|
| 225 |
chart_types.append("zone_diagram")
|
| 226 |
description += f", {large_contours} colored regions"
|
| 227 |
-
elif circle_count >= 3 and
|
| 228 |
chart_types.append("scatter_plot")
|
| 229 |
confidence = max(confidence, 0.75)
|
| 230 |
description = f"Scatter plot: {circle_count} data points"
|
|
@@ -245,7 +261,7 @@ class LocalChartDetector:
|
|
| 245 |
if not description:
|
| 246 |
description = "Complex visualization with zones and data points"
|
| 247 |
has_moderate_axes = (1 <= horizontal_lines <= 6 or 1 <= vertical_lines <= 6)
|
| 248 |
-
has_real_data = (circle_count >= 3 or
|
| 249 |
if has_moderate_axes and has_real_data and confidence > 0.3:
|
| 250 |
confidence = min(0.90, confidence + 0.10)
|
| 251 |
if not description:
|
|
@@ -253,8 +269,8 @@ class LocalChartDetector:
|
|
| 253 |
|
| 254 |
# Final chart determination
|
| 255 |
strong_indicator = (
|
| 256 |
-
(
|
| 257 |
-
(
|
| 258 |
circle_count >= 5 or
|
| 259 |
(circle_count >= 3 and large_contours >= 2) or
|
| 260 |
bar_pattern or
|
|
@@ -267,7 +283,7 @@ class LocalChartDetector:
|
|
| 267 |
)
|
| 268 |
total_time = time.time() - start_time
|
| 269 |
if has_chart:
|
| 270 |
-
logger.info(f"?? OpenCV detection: {total_time*1000:.0f}ms (lines:{line_count},
|
| 271 |
else:
|
| 272 |
logger.debug(f"?? OpenCV detection: {total_time*1000:.0f}ms (rejected)")
|
| 273 |
return {
|
|
@@ -277,7 +293,8 @@ class LocalChartDetector:
|
|
| 277 |
'description': description or "Potential chart detected",
|
| 278 |
'features': {
|
| 279 |
'lines': line_count,
|
| 280 |
-
'
|
|
|
|
| 281 |
'circles': circle_count,
|
| 282 |
'contours': significant_contours,
|
| 283 |
'rectangles': rectangle_contours,
|
|
|
|
| 52 |
else:
|
| 53 |
image_cv = image
|
| 54 |
height, width = image_cv.shape[:2]
|
| 55 |
+
|
| 56 |
+
# Always downscale for detection (even if caller forgot)
|
| 57 |
+
MAX_DETECT_DIM = 900
|
| 58 |
+
if max(height, width) > MAX_DETECT_DIM:
|
| 59 |
+
scale = MAX_DETECT_DIM / max(height, width)
|
| 60 |
+
image_cv = cv2.resize(image_cv, (int(width * scale), int(height * scale)), interpolation=cv2.INTER_AREA)
|
| 61 |
+
height, width = image_cv.shape[:2]
|
| 62 |
+
|
| 63 |
gray = cv2.cvtColor(image_cv, cv2.COLOR_BGR2GRAY)
|
| 64 |
|
| 65 |
+
# Optional: reduce OpenCV internal thread usage (helps in HF containers)
|
| 66 |
+
try:
|
| 67 |
+
cv2.setNumThreads(1)
|
| 68 |
+
except Exception:
|
| 69 |
+
pass
|
| 70 |
+
|
| 71 |
# --- Edge Detection ---
|
| 72 |
edges = cv2.Canny(gray, 50, 150)
|
| 73 |
|
|
|
|
| 88 |
edges,
|
| 89 |
rho=1,
|
| 90 |
theta=np.pi/180,
|
| 91 |
+
threshold=120, # slightly higher reduces line explosion
|
| 92 |
minLineLength=100,
|
| 93 |
maxLineGap=10
|
| 94 |
)
|
| 95 |
line_count = len(lines) if lines is not None else 0
|
| 96 |
+
diag_lines_raw = 0
|
| 97 |
+
raw_angles = []
|
| 98 |
if lines is not None:
|
| 99 |
for line in lines:
|
| 100 |
x1, y1, x2, y2 = line[0]
|
| 101 |
angle = np.abs(np.arctan2(y2 - y1, x2 - x1) * 180 / np.pi)
|
| 102 |
if 10 < angle < 80 or 100 < angle < 170:
|
| 103 |
+
diag_lines_raw += 1
|
| 104 |
+
raw_angles.append(angle)
|
| 105 |
+
|
| 106 |
+
run_circles = diag_lines_raw >= 1 or line_count >= 6
|
| 107 |
|
| 108 |
# --- Circle Detection (Optimized) ---
|
|
|
|
| 109 |
circle_count = 0
|
| 110 |
circles = None
|
| 111 |
if run_circles:
|
|
|
|
| 125 |
circle_count = circles.shape[2]
|
| 126 |
|
| 127 |
# --- Color Diversity Analysis ---
|
| 128 |
+
small_for_hist = cv2.resize(image_cv, (256, 256), interpolation=cv2.INTER_AREA)
|
| 129 |
+
hsv = cv2.cvtColor(small_for_hist, cv2.COLOR_BGR2HSV)
|
| 130 |
hist = cv2.calcHist([hsv], [0], None, [180], [0, 180])
|
| 131 |
+
color_peaks = int(np.sum(hist > (np.mean(hist) * 2)))
|
| 132 |
|
| 133 |
# --- Contour Detection ---
|
| 134 |
+
contours, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
| 135 |
significant_contours = 0
|
| 136 |
rectangle_contours = 0
|
| 137 |
similar_rectangles = []
|
|
|
|
| 164 |
if (width_std < avg_width * 0.3 or height_std < avg_height * 0.3):
|
| 165 |
bar_pattern = True
|
| 166 |
|
| 167 |
+
# --- Line Classification (filtered) ---
|
| 168 |
horizontal_lines = 0
|
| 169 |
vertical_lines = 0
|
| 170 |
+
diag_lines_filtered = 0
|
| 171 |
line_angles = []
|
| 172 |
very_short_lines = 0
|
| 173 |
if lines is not None:
|
| 174 |
for line in lines:
|
| 175 |
x1, y1, x2, y2 = line[0]
|
| 176 |
+
length = np.hypot(x2 - x1, y2 - y1)
|
| 177 |
if length < 50:
|
| 178 |
very_short_lines += 1
|
| 179 |
continue
|
|
|
|
| 186 |
elif 80 < angle < 100:
|
| 187 |
vertical_lines += 1
|
| 188 |
else:
|
| 189 |
+
diag_lines_filtered += 1
|
| 190 |
+
angle_variance = float(np.var(line_angles)) if len(line_angles) > 2 else 0.0
|
| 191 |
|
| 192 |
# --- Debug Logging ---
|
| 193 |
+
logger.debug(f"Chart detection features: lines={line_count}, diag_lines_raw={diag_lines_raw}, diag_lines_filtered={diag_lines_filtered}, circles={circle_count}, horizontal_lines={horizontal_lines}, vertical_lines={vertical_lines}, color_peaks={color_peaks}, angle_variance={angle_variance}")
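The line-classification loop above buckets Hough segments by angle: near 0°/180° counts as horizontal, near 90° as vertical, everything else (after the short-segment filter) as diagonal. The same geometry in a dependency-free sketch using only `math` (no OpenCV/NumPy; the exact horizontal thresholds are assumed to mirror the 10°/170° bounds used elsewhere in the detector):

```python
import math

def classify_segment(x1, y1, x2, y2):
    # Absolute angle in degrees (0..180), matching the detector's convention.
    angle = abs(math.degrees(math.atan2(y2 - y1, x2 - x1)))
    if angle < 10 or angle > 170:
        return "horizontal"
    if 80 < angle < 100:
        return "vertical"
    return "diagonal"

print(classify_segment(0, 0, 100, 0))    # horizontal
print(classify_segment(0, 0, 0, 100))    # vertical
print(classify_segment(0, 0, 100, 80))   # diagonal (about 39 degrees)
```

Counting the diagonal bucket separately (as `diag_lines_filtered` does) is what lets the heuristics distinguish line charts and scatter plots from tables, whose lines are almost entirely horizontal and vertical.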
|
| 194 |
|
| 195 |
# --- Chart Heuristics and Classification ---
|
| 196 |
chart_types = []
|
|
|
|
| 199 |
rejection_reason = ""
|
| 200 |
|
| 201 |
# Negative checks (text slides, decorative backgrounds, tables)
|
| 202 |
+
if has_text_region and circle_count < 2 and diag_lines_filtered < 2 and not bar_pattern:
|
| 203 |
if small_scattered_contours > 100 or very_short_lines > 50:
|
| 204 |
rejection_reason = f"Text slide with decorative background (overall density: {overall_edge_density:.2%})"
|
| 205 |
logger.debug(f"Rejected: {rejection_reason}")
|
| 206 |
return _chart_result(False, 0.0, [], rejection_reason, line_count, circle_count, overall_edge_density)
|
| 207 |
+
if very_short_lines > 50 and circle_count < 2 and diag_lines_filtered < 3 and line_count < 10:
|
| 208 |
rejection_reason = f"Decorative network background ({very_short_lines} tiny lines, no data elements)"
|
| 209 |
logger.debug(f"Rejected: {rejection_reason}")
|
| 210 |
return _chart_result(False, 0.0, [], rejection_reason, line_count, circle_count, overall_edge_density)
|
| 211 |
+
if horizontal_lines > 12 and vertical_lines > 12 and circle_count == 0 and diag_lines_filtered < 2:
|
| 212 |
grid_lines = horizontal_lines + vertical_lines
|
| 213 |
total_lines = line_count
|
| 214 |
grid_ratio = grid_lines / max(total_lines, 1)
|
|
|
|
| 220 |
# Positive chart heuristics (bubble, scatter, line, pie, bar, complex)
|
| 221 |
# RELAXED: Detect as line chart if 2+ diagonal lines and angle variance > 40, or 1+ diagonal line and 1+ axis
|
| 222 |
if (
|
| 223 |
+
(diag_lines_filtered >= 2 and angle_variance > 40) or
|
| 224 |
+
(diag_lines_filtered >= 1 and (horizontal_lines >= 1 or vertical_lines >= 1))
|
| 225 |
):
|
| 226 |
chart_types.append("line_chart")
|
| 227 |
+
confidence = max(confidence, min(0.88, 0.6 + (diag_lines_filtered / 40)))
|
| 228 |
if (horizontal_lines >= 1 or vertical_lines >= 1):
|
| 229 |
confidence = min(0.95, confidence + 0.08)
|
| 230 |
if not description:
|
| 231 |
+
description = f"Line chart: {diag_lines_filtered} diagonal lines, axes: {horizontal_lines+vertical_lines}, variance: {angle_variance:.0f}"
|
| 232 |
if circle_count >= 5:
|
| 233 |
chart_types.append("bubble_chart")
|
| 234 |
confidence = min(0.92, 0.70 + (min(circle_count, 20) * 0.01))
|
|
|
|
| 240 |
confidence = min(0.97, confidence + 0.05)
|
| 241 |
chart_types.append("zone_diagram")
|
| 242 |
description += f", {large_contours} colored regions"
|
| 243 |
+
elif circle_count >= 3 and diag_lines_filtered > 2:
|
| 244 |
chart_types.append("scatter_plot")
|
| 245 |
confidence = max(confidence, 0.75)
|
| 246 |
description = f"Scatter plot: {circle_count} data points"
|
|
|
|
| 261 |
if not description:
|
| 262 |
description = "Complex visualization with zones and data points"
|
| 263 |
has_moderate_axes = (1 <= horizontal_lines <= 6 or 1 <= vertical_lines <= 6)
|
| 264 |
+
has_real_data = (circle_count >= 3 or diag_lines_filtered >= 2 or bar_pattern)
|
| 265 |
if has_moderate_axes and has_real_data and confidence > 0.3:
|
| 266 |
confidence = min(0.90, confidence + 0.10)
|
| 267 |
if not description:
|
|
|
|
| 269 |
|
| 270 |
# Final chart determination
|
| 271 |
strong_indicator = (
|
| 272 |
+
(diag_lines_filtered >= 2 and angle_variance > 40) or
|
| 273 |
+
(diag_lines_filtered >= 1 and (horizontal_lines >= 1 or vertical_lines >= 1)) or
|
| 274 |
circle_count >= 5 or
|
| 275 |
(circle_count >= 3 and large_contours >= 2) or
|
| 276 |
bar_pattern or
|
|
|
|
| 283 |
)
|
| 284 |
total_time = time.time() - start_time
|
| 285 |
if has_chart:
|
| 286 |
+
logger.info(f"OpenCV detection: {total_time*1000:.0f}ms (lines:{line_count}, diag_lines_filtered:{diag_lines_filtered}, circles:{circle_count}, axes:{horizontal_lines+vertical_lines}, angle_variance:{angle_variance})")
|
| 287 |
else:
|
| 288 |
logger.debug(f"OpenCV detection: {total_time*1000:.0f}ms (rejected)")
|
| 289 |
return {
|
|
|
|
| 293 |
'description': description or "Potential chart detected",
|
| 294 |
'features': {
|
| 295 |
'lines': line_count,
|
| 296 |
+
'diagonal_lines_raw': diag_lines_raw,
|
| 297 |
+
'diagonal_lines_filtered': diag_lines_filtered,
|
| 298 |
'circles': circle_count,
|
| 299 |
'contours': significant_contours,
|
| 300 |
'rectangles': rectangle_contours,
|
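The line-classification heuristic in the hunk above (discard segments shorter than 50 px, bucket the rest by angle, then take the variance of the surviving angles) can be sketched standalone. This is a minimal sketch, not the module's code: the vertical band (80° < angle < 100°) and the length cutoff come from the diff, while the horizontal band used below (angle < 10° or > 170°) is an assumption, since those lines of the hunk are elided.

```python
import numpy as np

def classify_lines(segments, min_length=50):
    """Bucket Hough-style (x1, y1, x2, y2) segments into horizontal,
    vertical, diagonal, and very-short counts, plus angle variance.
    Thresholds mirror the diff; the horizontal band is an assumption."""
    horizontal = vertical = diagonal = very_short = 0
    angles = []
    for x1, y1, x2, y2 in segments:
        length = np.hypot(x2 - x1, y2 - y1)
        if length < min_length:
            very_short += 1          # decorative speck, ignored downstream
            continue
        # Angle in [0, 180): 0 = horizontal, 90 = vertical
        angle = abs(np.degrees(np.arctan2(y2 - y1, x2 - x1))) % 180
        angles.append(angle)
        if angle < 10 or angle > 170:    # assumed horizontal band
            horizontal += 1
        elif 80 < angle < 100:           # vertical band from the diff
            vertical += 1
        else:
            diagonal += 1
    variance = float(np.var(angles)) if len(angles) > 2 else 0.0
    return horizontal, vertical, diagonal, very_short, variance

h, v, d, s, var = classify_lines([
    (0, 0, 100, 0),   # horizontal axis
    (0, 0, 0, 100),   # vertical axis
    (0, 0, 80, 60),   # diagonal data line (~37 degrees)
    (0, 0, 10, 5),    # decorative speck (< 50 px)
])
```

The variance guard (`len(angles) > 2`) matters because a single surviving line would otherwise yield a misleading variance of zero and still pass the "2+ diagonals with variance > 40" check.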
core/diagnostics.py
DELETED
|
@@ -1,125 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Health check utilities for DocChat.
|
| 3 |
-
|
| 4 |
-
This module provides diagnostics check functions that can be used
|
| 5 |
-
to verify the application is running correctly.
|
| 6 |
-
"""
|
| 7 |
-
import logging
|
| 8 |
-
from typing import Dict, Any
|
| 9 |
-
from datetime import datetime
|
| 10 |
-
|
| 11 |
-
logger = logging.getLogger(__name__)
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
def check_diagnostics() -> Dict[str, Any]:
|
| 15 |
-
"""
|
| 16 |
-
Perform a comprehensive diagnostics check of the application.
|
| 17 |
-
|
| 18 |
-
Returns:
|
| 19 |
-
Dict with diagnostics status and component information
|
| 20 |
-
"""
|
| 21 |
-
diagnostics_status = {
|
| 22 |
-
"status": "healthy",
|
| 23 |
-
"timestamp": datetime.utcnow().isoformat(),
|
| 24 |
-
"components": {}
|
| 25 |
-
}
|
| 26 |
-
|
| 27 |
-
# Check parameters
|
| 28 |
-
try:
|
| 29 |
-
from configuration.parameters import parameters
|
| 30 |
-
diagnostics_status["components"]["parameters"] = {
|
| 31 |
-
"status": "ok",
|
| 32 |
-
"chroma_db_path": parameters.CHROMA_DB_PATH,
|
| 33 |
-
"log_level": parameters.LOG_LEVEL
|
| 34 |
-
}
|
| 35 |
-
except Exception as e:
|
| 36 |
-
diagnostics_status["components"]["parameters"] = {
|
| 37 |
-
"status": "error",
|
| 38 |
-
"error": str(e)
|
| 39 |
-
}
|
| 40 |
-
diagnostics_status["status"] = "unhealthy"
|
| 41 |
-
|
| 42 |
-
# Check ChromaDB directory
|
| 43 |
-
try:
|
| 44 |
-
from pathlib import Path
|
| 45 |
-
chroma_path = Path(parameters.CHROMA_DB_PATH)
|
| 46 |
-
diagnostics_status["components"]["chroma_db"] = {
|
| 47 |
-
"status": "ok",
|
| 48 |
-
"path_exists": chroma_path.exists(),
|
| 49 |
-
"is_writable": chroma_path.exists() and chroma_path.is_dir()
|
| 50 |
-
}
|
| 51 |
-
except Exception as e:
|
| 52 |
-
diagnostics_status["components"]["chroma_db"] = {
|
| 53 |
-
"status": "error",
|
| 54 |
-
"error": str(e)
|
| 55 |
-
}
|
| 56 |
-
|
| 57 |
-
# Check cache directory
|
| 58 |
-
try:
|
| 59 |
-
cache_path = Path(parameters.CACHE_DIR)
|
| 60 |
-
diagnostics_status["components"]["cache"] = {
|
| 61 |
-
"status": "ok",
|
| 62 |
-
"path_exists": cache_path.exists(),
|
| 63 |
-
"is_writable": cache_path.exists() and cache_path.is_dir()
|
| 64 |
-
}
|
| 65 |
-
except Exception as e:
|
| 66 |
-
diagnostics_status["components"]["cache"] = {
|
| 67 |
-
"status": "error",
|
| 68 |
-
"error": str(e)
|
| 69 |
-
}
|
| 70 |
-
|
| 71 |
-
# Check if required packages are importable
|
| 72 |
-
required_packages = [
|
| 73 |
-
"langchain",
|
| 74 |
-
"langchain_google_genai",
|
| 75 |
-
"chromadb",
|
| 76 |
-
"gradio"
|
| 77 |
-
]
|
| 78 |
-
|
| 79 |
-
packages_status = {}
|
| 80 |
-
for package in required_packages:
|
| 81 |
-
try:
|
| 82 |
-
__import__(package)
|
| 83 |
-
packages_status[package] = "ok"
|
| 84 |
-
except ImportError as e:
|
| 85 |
-
packages_status[package] = f"missing: {e}"
|
| 86 |
-
diagnostics_status["status"] = "degraded"
|
| 87 |
-
|
| 88 |
-
diagnostics_status["components"]["packages"] = packages_status
|
| 89 |
-
|
| 90 |
-
return diagnostics_status
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
def check_api_key() -> Dict[str, Any]:
|
| 94 |
-
"""
|
| 95 |
-
Check if the Google API key is configured and valid format.
|
| 96 |
-
|
| 97 |
-
Returns:
|
| 98 |
-
Dict with API key status (does not expose the key)
|
| 99 |
-
"""
|
| 100 |
-
try:
|
| 101 |
-
from configuration.parameters import parameters
|
| 102 |
-
api_key = parameters.GOOGLE_API_KEY
|
| 103 |
-
|
| 104 |
-
if not api_key:
|
| 105 |
-
return {"status": "missing", "message": "GOOGLE_API_KEY not set"}
|
| 106 |
-
|
| 107 |
-
if len(api_key) < 20:
|
| 108 |
-
return {"status": "invalid", "message": "API key appears too short"}
|
| 109 |
-
|
| 110 |
-
# Mask the key for logging (show first 4 and last 4 chars)
|
| 111 |
-
masked = f"{api_key[:4]}...{api_key[-4:]}"
|
| 112 |
-
|
| 113 |
-
return {
|
| 114 |
-
"status": "configured",
|
| 115 |
-
"masked_key": masked,
|
| 116 |
-
"length": len(api_key)
|
| 117 |
-
}
|
| 118 |
-
except Exception as e:
|
| 119 |
-
return {"status": "error", "message": str(e)}
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
if __name__ == "__main__":
|
| 123 |
-
# Run diagnostics check when executed directly
|
| 124 |
-
import json
|
| 125 |
-
print(json.dumps(check_diagnostics(), indent=2))
|
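The package-importability loop in the deleted diagnostics module can be exercised in isolation. A small sketch, with `importlib.import_module` swapped in for the bare `__import__` used above (the behaviour for top-level packages is equivalent, and `ModuleNotFoundError` is a subclass of `ImportError`):

```python
import importlib

def check_packages(packages):
    """Return per-package import status, mirroring the deleted
    diagnostics loop: 'ok' on success, 'missing: <err>' on failure."""
    status = {}
    for name in packages:
        try:
            importlib.import_module(name)
            status[name] = "ok"
        except ImportError as e:
            status[name] = f"missing: {e}"
    return status

# 'json' is stdlib and always importable; the second name should not exist.
result = check_packages(["json", "definitely_not_a_real_package"])
```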
core/lifecycle.py
DELETED
|
@@ -1,160 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Signal handling and graceful shutdown utilities.
|
| 3 |
-
|
| 4 |
-
This module provides graceful shutdown handling for the DocChat application,
|
| 5 |
-
ensuring resources are properly cleaned up when the application is terminated.
|
| 6 |
-
"""
|
| 7 |
-
import signal
|
| 8 |
-
import sys
|
| 9 |
-
import logging
|
| 10 |
-
import atexit
|
| 11 |
-
from typing import Callable, List, Optional
|
| 12 |
-
from pathlib import Path
|
| 13 |
-
|
| 14 |
-
logger = logging.getLogger(__name__)
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
class ShutdownHandler:
|
| 18 |
-
"""
|
| 19 |
-
Manages graceful shutdown of the application.
|
| 20 |
-
|
| 21 |
-
Registers cleanup callbacks that are executed when the application
|
| 22 |
-
receives a termination signal (SIGINT, SIGTERM) or exits normally.
|
| 23 |
-
"""
|
| 24 |
-
|
| 25 |
-
_instance: Optional['ShutdownHandler'] = None
|
| 26 |
-
|
| 27 |
-
def __new__(cls) -> 'ShutdownHandler':
|
| 28 |
-
"""Singleton pattern to ensure only one handler exists."""
|
| 29 |
-
if cls._instance is None:
|
| 30 |
-
cls._instance = super().__new__(cls)
|
| 31 |
-
cls._instance._initialized = False
|
| 32 |
-
return cls._instance
|
| 33 |
-
|
| 34 |
-
def __init__(self) -> None:
|
| 35 |
-
"""Initialize the lifecycle handler."""
|
| 36 |
-
if self._initialized:
|
| 37 |
-
return
|
| 38 |
-
|
| 39 |
-
self._cleanup_callbacks: List[Callable] = []
|
| 40 |
-
self._lifecycle_in_progress: bool = False
|
| 41 |
-
self._initialized = True
|
| 42 |
-
|
| 43 |
-
# Register signal handlers
|
| 44 |
-
signal.signal(signal.SIGINT, self._signal_handler)
|
| 45 |
-
signal.signal(signal.SIGTERM, self._signal_handler)
|
| 46 |
-
|
| 47 |
-
# Register atexit handler for normal exits
|
| 48 |
-
atexit.register(self._atexit_handler)
|
| 49 |
-
|
| 50 |
-
logger.info("[SHUTDOWN] ShutdownHandler initialized")
|
| 51 |
-
|
| 52 |
-
def register_cleanup(self, callback: Callable, name: str = "") -> None:
|
| 53 |
-
"""
|
| 54 |
-
Register a cleanup callback to be called on shutdown.
|
| 55 |
-
|
| 56 |
-
Args:
|
| 57 |
-
callback: Function to call during shutdown
|
| 58 |
-
name: Optional name for logging purposes
|
| 59 |
-
"""
|
| 60 |
-
self._cleanup_callbacks.append((callback, name))
|
| 61 |
-
logger.debug(f"[SHUTDOWN] Registered cleanup callback: {name or callback.__name__}")
|
| 62 |
-
|
| 63 |
-
def _signal_handler(self, signum: int, frame) -> None:
|
| 64 |
-
"""
|
| 65 |
-
Handle termination signals.
|
| 66 |
-
|
| 67 |
-
Args:
|
| 68 |
-
signum: Signal number
|
| 69 |
-
frame: Current stack frame
|
| 70 |
-
"""
|
| 71 |
-
signal_name = signal.Signals(signum).name
|
| 72 |
-
logger.info(f"[SHUTDOWN] Received {signal_name}, initiating graceful shutdown...")
|
| 73 |
-
|
| 74 |
-
self._execute_cleanup()
|
| 75 |
-
sys.exit(0)
|
| 76 |
-
|
| 77 |
-
def _atexit_handler(self) -> None:
|
| 78 |
-
"""Handle normal application exit."""
|
| 79 |
-
if not self._lifecycle_in_progress:
|
| 80 |
-
logger.info("[SHUTDOWN] Application exiting normally, running cleanup...")
|
| 81 |
-
self._execute_cleanup()
|
| 82 |
-
|
| 83 |
-
def _execute_cleanup(self) -> None:
|
| 84 |
-
"""Execute all registered cleanup callbacks."""
|
| 85 |
-
if self._lifecycle_in_progress:
|
| 86 |
-
return
|
| 87 |
-
|
| 88 |
-
self._lifecycle_in_progress = True
|
| 89 |
-
logger.info(f"[SHUTDOWN] Executing {len(self._cleanup_callbacks)} cleanup callbacks...")
|
| 90 |
-
|
| 91 |
-
for callback, name in reversed(self._cleanup_callbacks):
|
| 92 |
-
try:
|
| 93 |
-
callback_name = name or callback.__name__
|
| 94 |
-
logger.debug(f"[SHUTDOWN] Running cleanup: {callback_name}")
|
| 95 |
-
callback()
|
| 96 |
-
logger.debug(f"[SHUTDOWN] Cleanup completed: {callback_name}")
|
| 97 |
-
except Exception as e:
|
| 98 |
-
logger.error(f"[SHUTDOWN] Cleanup failed: {e}", exc_info=True)
|
| 99 |
-
|
| 100 |
-
logger.info("[SHUTDOWN] All cleanup callbacks executed")
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
def cleanup_chroma_db() -> None:
|
| 104 |
-
"""Clean up ChromaDB connections."""
|
| 105 |
-
try:
|
| 106 |
-
# ChromaDB cleanup if needed
|
| 107 |
-
logger.info("[CLEANUP] Cleaning up ChromaDB...")
|
| 108 |
-
# ChromaDB uses SQLite which handles cleanup automatically
|
| 109 |
-
logger.info("[CLEANUP] ChromaDB cleanup complete")
|
| 110 |
-
except Exception as e:
|
| 111 |
-
logger.error(f"[CLEANUP] ChromaDB cleanup failed: {e}")
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
def cleanup_temp_files() -> None:
|
| 115 |
-
"""Clean up temporary files created during processing."""
|
| 116 |
-
try:
|
| 117 |
-
import tempfile
|
| 118 |
-
import shutil
|
| 119 |
-
|
| 120 |
-
# Clean up any temp directories we created
|
| 121 |
-
temp_base = Path(tempfile.gettempdir())
|
| 122 |
-
|
| 123 |
-
# Only clean up directories that match our pattern
|
| 124 |
-
# Be conservative to avoid deleting user data
|
| 125 |
-
logger.info("[CLEANUP] Temporary file cleanup complete")
|
| 126 |
-
except Exception as e:
|
| 127 |
-
logger.error(f"[CLEANUP] Temp file cleanup failed: {e}")
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
def cleanup_logging() -> None:
|
| 131 |
-
"""Flush and close all log handlers."""
|
| 132 |
-
try:
|
| 133 |
-
logger.info("[CLEANUP] Flushing log handlers...")
|
| 134 |
-
|
| 135 |
-
# Get root logger and flush all handlers
|
| 136 |
-
root_logger = logging.getLogger()
|
| 137 |
-
for handler in root_logger.handlers:
|
| 138 |
-
handler.flush()
|
| 139 |
-
|
| 140 |
-
logger.info("[CLEANUP] Log handlers flushed")
|
| 141 |
-
except Exception as e:
|
| 142 |
-
# Can't log this since logging might be broken
|
| 143 |
-
print(f"[CLEANUP] Log handler cleanup failed: {e}", file=sys.stderr)
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
def initialize_lifecycle_handler() -> ShutdownHandler:
|
| 147 |
-
"""
|
| 148 |
-
Initialize the lifecycle handler with default cleanup callbacks.
|
| 149 |
-
|
| 150 |
-
Returns:
|
| 151 |
-
The initialized ShutdownHandler instance
|
| 152 |
-
"""
|
| 153 |
-
handler = ShutdownHandler()
|
| 154 |
-
|
| 155 |
-
# Register default cleanup callbacks (order matters - reverse execution)
|
| 156 |
-
handler.register_cleanup(cleanup_logging, "Logging cleanup")
|
| 157 |
-
handler.register_cleanup(cleanup_temp_files, "Temp files cleanup")
|
| 158 |
-
handler.register_cleanup(cleanup_chroma_db, "ChromaDB cleanup")
|
| 159 |
-
|
| 160 |
-
return handler
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
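The core behaviour of the deleted ShutdownHandler (callbacks registered in order, executed in reverse, each isolated by try/except, and idempotent overall) can be shown without the signal/atexit wiring. A reduced sketch, not the module itself:

```python
class MiniShutdown:
    """Reduced sketch of the deleted ShutdownHandler: cleanup callbacks
    run in reverse registration order, failures don't block the rest,
    and a second call is a no-op (like _execute_cleanup's guard)."""

    def __init__(self):
        self._callbacks = []
        self._done = False

    def register_cleanup(self, callback, name=""):
        self._callbacks.append((callback, name or callback.__name__))

    def execute(self):
        if self._done:           # idempotency guard
            return []
        self._done = True
        ran = []
        for callback, name in reversed(self._callbacks):
            try:
                callback()
                ran.append(name)
            except Exception:
                pass             # a failing callback must not block the rest
        return ran

order = []
handler = MiniShutdown()
handler.register_cleanup(lambda: order.append("logging"), "logging")
handler.register_cleanup(lambda: order.append("temp"), "temp")
handler.register_cleanup(lambda: order.append("chroma"), "chroma")
ran = handler.execute()
```

Reverse execution matters here: logging cleanup is registered first precisely so it runs last, after the other callbacks have finished emitting their log lines.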
intelligence/accuracy_verifier.py
CHANGED
|
@@ -150,19 +150,38 @@ Provide your verification analysis."""
|
|
| 150 |
feedback_parts.append(f"Additional Details: {verification.additional_details}")
|
| 151 |
return " | ".join(feedback_parts) if feedback_parts else None
|
| 152 |
|
| 153 |
-
def should_retry_research(self, verification: VerificationResult) -> bool:
|
| 154 |
"""Determine if research should be retried."""
|
|
|
|
| 155 |
if verification.supported == "NO" or verification.relevant == "NO":
|
| 156 |
return True
|
| 157 |
-
|
| 158 |
if verification.confidence == "LOW" and (
|
| 159 |
verification.unsupported_claims or verification.contradictions
|
| 160 |
):
|
| 161 |
return True
|
| 162 |
-
|
| 163 |
if verification.supported == "PARTIAL" and verification.contradictions:
|
| 164 |
return True
|
| 165 |
-
|
|
|
| 166 |
return False
|
| 167 |
|
| 168 |
def check(self, answer: str, documents: List[Document], question: Optional[str] = None) -> Dict:
|
|
@@ -219,7 +238,7 @@ Provide your verification analysis."""
|
|
| 219 |
"verification_report": verification_report,
|
| 220 |
"context_used": context,
|
| 221 |
"structured_result": verification_result.model_dump(),
|
| 222 |
-
"should_retry": self.should_retry_research(verification_result),
|
| 223 |
"feedback": feedback
|
| 224 |
}
|
| 225 |
|
|
@@ -333,18 +352,17 @@ Select the best answer by providing its index (0-based) and explain your reasoni
|
|
| 333 |
for line in response_text.split('\n'):
|
| 334 |
if ':' not in line:
|
| 335 |
continue
|
| 336 |
-
|
| 337 |
key, value = line.split(':', 1)
|
| 338 |
key = key.strip().lower().replace(' ', '_')
|
| 339 |
value = value.strip().upper()
|
| 340 |
|
| 341 |
-
if key == "
|
| 342 |
data["supported"] = "YES" if "YES" in value else ("PARTIAL" if "PARTIAL" in value else "NO")
|
| 343 |
-
elif key == "
|
| 344 |
data["confidence"] = "HIGH" if "HIGH" in value else ("MEDIUM" if "MEDIUM" in value else "LOW")
|
| 345 |
-
elif key == "
|
| 346 |
data["relevant"] = "YES" if "YES" in value else "NO"
|
| 347 |
-
elif key == "
|
| 348 |
if "COMPLETE" in value and "INCOMPLETE" not in value:
|
| 349 |
data["completeness"] = "COMPLETE"
|
| 350 |
elif "PARTIAL" in value:
|
|
|
|
| 150 |
feedback_parts.append(f"Additional Details: {verification.additional_details}")
|
| 151 |
return " | ".join(feedback_parts) if feedback_parts else None
|
| 152 |
|
| 153 |
+
def should_retry_research(self, verification: VerificationResult, verification_report: str = None, feedback: Optional[str] = None) -> bool:
|
| 154 |
"""Determine if research should be retried."""
|
| 155 |
+
# Use structured fields first
|
| 156 |
if verification.supported == "NO" or verification.relevant == "NO":
|
| 157 |
return True
|
|
|
|
| 158 |
if verification.confidence == "LOW" and (
|
| 159 |
verification.unsupported_claims or verification.contradictions
|
| 160 |
):
|
| 161 |
return True
|
|
|
|
| 162 |
if verification.supported == "PARTIAL" and verification.contradictions:
|
| 163 |
return True
|
| 164 |
+
# Also check verification_report string for extra signals (legacy/fallback)
|
| 165 |
+
if verification_report:
|
| 166 |
+
if "Supported: NO" in verification_report:
|
| 167 |
+
logger.warning("[Re-Research] Answer not supported; triggering re-research.")
|
| 168 |
+
return True
|
| 169 |
+
elif "Relevant: NO" in verification_report:
|
| 170 |
+
logger.warning("[Re-Research] Answer not relevant; triggering re-research.")
|
| 171 |
+
return True
|
| 172 |
+
elif "Confidence: LOW" in verification_report and "Supported: PARTIAL" in verification_report:
|
| 173 |
+
logger.warning("[Re-Research] Low confidence with partial support; triggering re-research.")
|
| 174 |
+
return True
|
| 175 |
+
elif "Completeness: INCOMPLETE" in verification_report:
|
| 176 |
+
logger.warning("[Re-Research] Answer is incomplete; triggering re-research.")
|
| 177 |
+
return True
|
| 178 |
+
elif "Completeness: PARTIAL" in verification_report:
|
| 179 |
+
logger.warning("[Re-Research] Answer is partially complete; triggering re-research.")
|
| 180 |
+
return True
|
| 181 |
+
# Check feedback for contradiction/unsupported
|
| 182 |
+
if feedback and ("contradiction" in feedback.lower() or "unsupported" in feedback.lower()):
|
| 183 |
+
logger.warning("[Re-Research] Feedback indicates contradiction/unsupported; triggering re-research.")
|
| 184 |
+
return True
|
| 185 |
return False
|
| 186 |
|
| 187 |
def check(self, answer: str, documents: List[Document], question: Optional[str] = None) -> Dict:
|
|
|
|
| 238 |
"verification_report": verification_report,
|
| 239 |
"context_used": context,
|
| 240 |
"structured_result": verification_result.model_dump(),
|
| 241 |
+
"should_retry": self.should_retry_research(verification_result, verification_report, feedback),
|
| 242 |
"feedback": feedback
|
| 243 |
}
|
| 244 |
|
|
|
|
| 352 |
for line in response_text.split('\n'):
|
| 353 |
if ':' not in line:
|
| 354 |
continue
|
|
|
|
| 355 |
key, value = line.split(':', 1)
|
| 356 |
key = key.strip().lower().replace(' ', '_')
|
| 357 |
value = value.strip().upper()
|
| 358 |
|
| 359 |
+
if key == "supported":
|
| 360 |
data["supported"] = "YES" if "YES" in value else ("PARTIAL" if "PARTIAL" in value else "NO")
|
| 361 |
+
elif key == "confidence":
|
| 362 |
data["confidence"] = "HIGH" if "HIGH" in value else ("MEDIUM" if "MEDIUM" in value else "LOW")
|
| 363 |
+
elif key == "relevant":
|
| 364 |
data["relevant"] = "YES" if "YES" in value else "NO"
|
| 365 |
+
elif key == "completeness":
|
| 366 |
if "COMPLETE" in value and "INCOMPLETE" not in value:
|
| 367 |
data["completeness"] = "COMPLETE"
|
| 368 |
elif "PARTIAL" in value:
|
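The line parser fixed in this hunk (the truncated `if key == "` comparisons restored to `"supported"`, `"confidence"`, `"relevant"`, `"completeness"`) can be run end to end. A sketch assembled from the visible lines; the final INCOMPLETE fallback for completeness is an assumption, since the `else` branch is cut off at the end of the hunk:

```python
def parse_verification(response_text):
    """Parse 'Key: VALUE' lines the way the fixed branches above do."""
    data = {}
    for line in response_text.split('\n'):
        if ':' not in line:
            continue
        key, value = line.split(':', 1)
        key = key.strip().lower().replace(' ', '_')
        value = value.strip().upper()
        if key == "supported":
            data["supported"] = "YES" if "YES" in value else ("PARTIAL" if "PARTIAL" in value else "NO")
        elif key == "confidence":
            data["confidence"] = "HIGH" if "HIGH" in value else ("MEDIUM" if "MEDIUM" in value else "LOW")
        elif key == "relevant":
            data["relevant"] = "YES" if "YES" in value else "NO"
        elif key == "completeness":
            # 'INCOMPLETE' contains 'COMPLETE', hence the double check
            if "COMPLETE" in value and "INCOMPLETE" not in value:
                data["completeness"] = "COMPLETE"
            elif "PARTIAL" in value:
                data["completeness"] = "PARTIAL"
            else:                     # assumed fallback (branch elided above)
                data["completeness"] = "INCOMPLETE"
    return data

parsed = parse_verification(
    "Supported: PARTIAL\nConfidence: LOW\nRelevant: YES\nCompleteness: incomplete"
)
```

Note the `"INCOMPLETE" not in value` guard: a naive substring test would classify `INCOMPLETE` as `COMPLETE`, which is exactly the kind of mislabel that would suppress a needed re-research pass.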
intelligence/orchestrator.py
CHANGED
|
@@ -44,15 +44,10 @@ class AgentState(TypedDict):
|
|
| 44 |
class AgentWorkflow:
|
| 45 |
"""
|
| 46 |
Orchestrates the multi-agent workflow for document Q&A.
|
| 47 |
-
|
| 48 |
-
Workflow:
|
| 49 |
-
1. Relevance Check - Determines if documents can answer the question
|
| 50 |
-
2. Research - Generates multiple answer candidates using document context
|
| 51 |
-
3. Verification - Selects the best answer from candidates
|
| 52 |
"""
|
| 53 |
|
| 54 |
-
MAX_RESEARCH_ATTEMPTS: int =
|
| 55 |
-
NUM_RESEARCH_CANDIDATES: int =
|
| 56 |
|
| 57 |
def __init__(self, num_candidates: int = None) -> None:
|
| 58 |
"""Initialize orchestrator with required agents."""
|
|
@@ -173,41 +168,40 @@ Question: {state['question']}
|
|
| 173 |
return {"draft_answer": combined, "verification_report": "Multi-question answer combined."}
|
| 174 |
|
| 175 |
def _check_relevance_step(self, state: AgentState) -> Dict[str, Any]:
|
| 176 |
-
"""Check if retrieved documents are relevant to the question."""
|
| 177 |
logger.debug("Checking context relevance...")
|
| 178 |
-
|
| 179 |
result = self.context_validator.context_validate_with_rewrite(
|
| 180 |
-
question=state["question"],
|
| 181 |
-
retriever=state["retriever"],
|
| 182 |
-
k=20
|
| 183 |
-
max_rewrites=
|
| 184 |
)
|
| 185 |
-
|
| 186 |
-
classification = result["classification"]
|
| 187 |
-
query_used = result["query_used"]
|
| 188 |
-
was_rewritten = result.get("was_rewritten", False)
|
| 189 |
-
|
| 190 |
-
logger.info(f"Relevance: {classification}")
|
| 191 |
-
if was_rewritten:
|
| 192 |
-
logger.debug(f"Query rewritten: {query_used[:60]}...")
|
| 193 |
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
|
|
|
|
|
|
| 200 |
return {
|
| 201 |
-
"is_relevant":
|
| 202 |
-
"query_used":
|
| 203 |
-
"
|
| 204 |
}
|
| 205 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 206 |
def _decide_after_relevance_check(self, state: AgentState) -> str:
|
| 207 |
"""Decide next step after relevance check."""
|
| 208 |
return "relevant" if state["is_relevant"] else "irrelevant"
|
| 209 |
|
| 210 |
-
def
|
| 211 |
"""
|
| 212 |
Execute the full Q&A pipeline.
|
| 213 |
|
|
@@ -221,15 +215,10 @@ Question: {state['question']}
|
|
| 221 |
try:
|
| 222 |
if self.compiled_orchestrator is None:
|
| 223 |
self.compiled_orchestrator = self.build_orchestrator()
|
| 224 |
-
|
| 225 |
-
logger.info(f"Starting pipeline: {question[:80]}...")
|
| 226 |
-
|
| 227 |
-
documents = retriever.invoke(question)
|
| 228 |
-
logger.info(f"Retrieved {len(documents)} documents")
|
| 229 |
|
| 230 |
initial_state: AgentState = {
|
| 231 |
"question": question,
|
| 232 |
-
"documents":
|
| 233 |
"draft_answer": "",
|
| 234 |
"verification_report": "",
|
| 235 |
"is_relevant": False,
|
|
@@ -243,67 +232,24 @@ Question: {state['question']}
|
|
| 243 |
"sub_queries": [],
|
| 244 |
"sub_answers": []
|
| 245 |
}
|
| 246 |
-
|
| 247 |
final_state = self.compiled_orchestrator.invoke(initial_state)
|
| 248 |
-
|
| 249 |
logger.info(f"Pipeline completed (attempts: {final_state.get('research_attempts', 1)})")
|
| 250 |
-
|
| 251 |
return {
|
| 252 |
"draft_answer": final_state["draft_answer"],
|
| 253 |
"verification_report": final_state["verification_report"]
|
| 254 |
}
|
| 255 |
-
|
| 256 |
except Exception as e:
|
| 257 |
logger.error(f"Pipeline failed: {e}", exc_info=True)
|
| 258 |
raise RuntimeError(f"Workflow execution failed: {e}") from e
|
| 259 |
|
| 260 |
-
def _research_step(self, state: AgentState) -> Dict[str, Any]:
|
| 261 |
-
"""Generate multiple answer candidates using the research agent."""
|
| 262 |
-
attempts = state.get("research_attempts", 0) + 1
|
| 263 |
-
feedback = state.get("feedback")
|
| 264 |
-
previous_answer = state.get("draft_answer") if feedback else None
|
| 265 |
-
# Consolidate contradictions and unsupported claims into feedback
|
| 266 |
-
contradictions = state.get("contradictions_for_research", [])
|
| 267 |
-
unsupported_claims = state.get("unsupported_claims_for_research", [])
|
| 268 |
-
feedback_for_research = state.get("feedback_for_research", feedback)
|
| 269 |
-
extra_feedback = ""
|
| 270 |
-
if contradictions:
|
| 271 |
-
extra_feedback += " Contradictions: " + "; ".join(contradictions) + "."
|
| 272 |
-
if unsupported_claims:
|
| 273 |
-
extra_feedback += " Unsupported Claims: " + "; ".join(unsupported_claims) + "."
|
| 274 |
-
# If feedback_for_research is present, append extra_feedback; otherwise, use extra_feedback only
|
| 275 |
-
if feedback_for_research:
|
| 276 |
-
feedback_for_research = feedback_for_research + extra_feedback
|
| 277 |
-
else:
|
| 278 |
-
feedback_for_research = extra_feedback.strip()
|
| 279 |
-
logger.info(f"Research step (attempt {attempts}/{self.MAX_RESEARCH_ATTEMPTS})")
|
| 280 |
-
logger.info(f"Generating {self.NUM_RESEARCH_CANDIDATES} candidate answers...")
|
| 281 |
-
candidate_answers = []
|
| 282 |
-
for i in range(self.NUM_RESEARCH_CANDIDATES):
|
| 283 |
-
logger.info(f"Generating candidate {i + 1}/{self.NUM_RESEARCH_CANDIDATES}")
|
| 284 |
-
result = self.researcher.generate(
|
| 285 |
-
question=state["question"],
|
| 286 |
-
documents=state["documents"],
|
| 287 |
-
feedback=feedback_for_research,
|
| 288 |
-
previous_answer=previous_answer
|
| 289 |
-
)
|
| 290 |
-
candidate_answers.append(result["draft_answer"])
|
| 291 |
-
logger.info(f"Generated {len(candidate_answers)} candidate answers")
|
| 292 |
-
return {
|
| 293 |
-
"candidate_answers": candidate_answers,
|
| 294 |
-
"research_attempts": attempts,
|
| 295 |
-
"feedback": None
|
| 296 |
-
}
|
| 297 |
-
|
| 298 |
def _verification_step(self, state: AgentState) -> Dict[str, Any]:
|
| 299 |
"""Select the best answer from candidates and verify it."""
|
| 300 |
logger.debug("Selecting best answer from candidates...")
|
| 301 |
|
| 302 |
-
candidate_answers = state.get("candidate_answers", [])
|
| 303 |
-
|
| 304 |
-
if not candidate_answers:
|
| 305 |
-
logger.warning("No candidate answers found, using draft_answer")
|
| 306 |
-
candidate_answers = [state.get("draft_answer", "")]
|
| 307 |
|
| 308 |
# Select the best answer from candidates
|
| 309 |
selection_result = self.verifier.select_best_answer(
|
|
@@ -331,58 +277,45 @@ Question: {state['question']}
|
|
| 331 |
f"**Selection Confidence:** {selection_result.get('confidence', 'N/A')}\n" + \
|
| 332 |
f"**Selection Reasoning:** {selection_reasoning}\n\n" + \
|
| 333 |
verification_report
|
| 334 |
-
|
|
|
|
|
|
|
| 335 |
return {
|
| 336 |
"draft_answer": best_answer,
|
| 337 |
"verification_report": verification_report,
|
| 338 |
-
"
|
| 339 |
-
"selection_reasoning": selection_reasoning
|
|
|
|
| 340 |
}
|
| 341 |
|
| 342 |
def _decide_next_step(self, state: AgentState) -> str:
|
| 343 |
"""Decide whether to re-research or end the workflow."""
|
| 344 |
-
verification_report = state["verification_report"]
|
| 345 |
research_attempts = state.get("research_attempts", 1)
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
# Extract contradictions and unsupported claims for feedback
|
| 349 |
-
contradictions = []
|
| 350 |
-
unsupported_claims = []
|
| 351 |
-
import re
|
| 352 |
-
for line in verification_report.splitlines():
|
| 353 |
-
if line.startswith("**Contradictions:"):
|
| 354 |
-
contradictions = [c.strip() for c in line.split(":", 1)[-1].split(",") if c.strip() and c.strip().lower() != "none"]
|
| 355 |
-
if line.startswith("**Unsupported Claims:"):
|
| 356 |
-
unsupported_claims = [u.strip() for u in line.split(":", 1)[-1].split(",") if u.strip() and u.strip().lower() != "none"]
|
| 357 |
-
if "Supported: NO" in verification_report:
|
| 358 |
-
needs_re_research = True
|
| 359 |
-
logger.warning("[Re-Research] Answer not supported; triggering re-research.")
|
| 360 |
-
elif "Relevant: NO" in verification_report:
|
| 361 |
-
needs_re_research = True
|
| 362 |
-
logger.warning("[Re-Research] Answer not relevant; triggering re-research.")
|
| 363 |
-
elif "Confidence: LOW" in verification_report and "Supported: PARTIAL" in verification_report:
|
| 364 |
-
needs_re_research = True
|
| 365 |
-
logger.warning("[Re-Research] Low confidence with partial support; triggering re-research.")
|
| 366 |
-
elif "Completeness: INCOMPLETE" in verification_report:
|
| 367 |
-
needs_re_research = True
|
| 368 |
-
logger.warning("[Re-Research] Answer is incomplete; triggering re-research.")
|
| 369 |
-
elif "Completeness: PARTIAL" in verification_report:
|
| 370 |
-
needs_re_research = True
|
| 371 |
-
logger.warning("[Re-Research] Answer is partially complete; triggering re-research.")
|
| 372 |
-
if feedback and not needs_re_research:
|
| 373 |
-
if "contradiction" in feedback.lower() or "unsupported" in feedback.lower():
|
| 374 |
-
needs_re_research = True
|
| 375 |
-
logger.warning("[Re-Research] Feedback indicates contradiction/unsupported; triggering re-research.")
|
| 376 |
-
# Store extra feedback for research node
|
| 377 |
-
state["contradictions_for_research"] = contradictions
|
| 378 |
-
state["unsupported_claims_for_research"] = unsupported_claims
|
| 379 |
-
state["feedback_for_research"] = feedback
|
| 380 |
-
if needs_re_research and research_attempts < self.MAX_RESEARCH_ATTEMPTS:
|
| 381 |
-
logger.info(f"[Re-Research] Re-researching (attempt {research_attempts + 1})")
|
| 382 |
return "re_research"
|
| 383 |
-
|
| 384 |
-
|
| 385 |
-
|
| 386 |
-
|
| 387 |
-
|
| 388 |
-
|
|
|
|
|
|
|
|
|
| 44 |
class AgentWorkflow:
|
| 45 |
"""
|
| 46 |
Orchestrates the multi-agent workflow for document Q&A.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
"""
|
| 48 |
|
| 49 |
+
MAX_RESEARCH_ATTEMPTS: int = parameters.MAX_RESEARCH_ATTEMPTS
|
| 50 |
+
NUM_RESEARCH_CANDIDATES: int = parameters.NUM_RESEARCH_CANDIDATES
|
| 51 |
|
| 52 |
def __init__(self, num_candidates: int = None) -> None:
|
| 53 |
"""Initialize orchestrator with required agents."""
|
|
|
|
| 168 |
return {"draft_answer": combined, "verification_report": "Multi-question answer combined."}
|
| 169 |
|
| 170 |
def _check_relevance_step(self, state: AgentState) -> Dict[str, Any]:
|
|
|
|
| 171 |
logger.debug("Checking context relevance...")
|
| 172 |
+
|
| 173 |
result = self.context_validator.context_validate_with_rewrite(
|
| 174 |
+
question=state["question"],
|
| 175 |
+
retriever=state["retriever"],
|
| 176 |
+
k=parameters.RELEVANCE_CHECK_K, # use config instead of hardcoding 20
|
| 177 |
+
max_rewrites=parameters.MAX_QUERY_REWRITES,
|
| 178 |
)
|
|
|
| 179 |
|
| 180 |
+
classification = result.get("classification", "NO_MATCH")
|
| 181 |
+
query_used = result.get("query_used", state["question"])
|
| 182 |
+
|
| 183 |
+
logger.info(f"Relevance: {classification} (query_used={query_used[:80]})")
|
| 184 |
+
|
| 185 |
+
if classification in ("CAN_ANSWER", "PARTIAL"):
|
| 186 |
+
# Always retrieve docs for the query we're actually going to answer
|
| 187 |
+
documents = state["retriever"].invoke(query_used)
|
| 188 |
return {
|
| 189 |
+
"is_relevant": True,
|
| 190 |
+
"query_used": query_used,
|
| 191 |
+
"documents": documents
|
| 192 |
}
|
| 193 |
|
| 194 |
+
return {
|
| 195 |
+
"is_relevant": False,
|
| 196 |
+
"query_used": query_used,
|
| 197 |
+
"draft_answer": "This question isn't related to the uploaded documents. Please ask another question.",
|
| 198 |
+
}
|
| 199 |
+
|
| 200 |
def _decide_after_relevance_check(self, state: AgentState) -> str:
|
| 201 |
"""Decide next step after relevance check."""
|
| 202 |
return "relevant" if state["is_relevant"] else "irrelevant"
|
| 203 |
|
| 204 |
+
def run_workflow(self, question: str, retriever: BaseRetriever) -> Dict[str, str]:
|
| 205 |
"""
|
| 206 |
Execute the full Q&A pipeline.
|
| 207 |
|
|
|
|
| 215 |
try:
|
| 216 |
if self.compiled_orchestrator is None:
|
| 217 |
self.compiled_orchestrator = self.build_orchestrator()
|
|
|
| 218 |
|
| 219 |
initial_state: AgentState = {
|
| 220 |
"question": question,
|
| 221 |
+
"documents": [], # Let _check_relevance_step fill this
|
| 222 |
"draft_answer": "",
|
| 223 |
"verification_report": "",
|
| 224 |
"is_relevant": False,
|
|
|
|
| 232 |
"sub_queries": [],
|
| 233 |
"sub_answers": []
|
| 234 |
}
|
| 235 |
+
|
| 236 |
final_state = self.compiled_orchestrator.invoke(initial_state)
|
| 237 |
+
|
| 238 |
logger.info(f"Pipeline completed (attempts: {final_state.get('research_attempts', 1)})")
|
| 239 |
+
|
| 240 |
return {
|
| 241 |
"draft_answer": final_state["draft_answer"],
|
| 242 |
"verification_report": final_state["verification_report"]
|
| 243 |
}
|
|
|
|
| 244 |
except Exception as e:
|
| 245 |
logger.error(f"Pipeline failed: {e}", exc_info=True)
|
| 246 |
raise RuntimeError(f"Workflow execution failed: {e}") from e
|
| 247 |
|
|
|
| 248 |
def _verification_step(self, state: AgentState) -> Dict[str, Any]:
|
| 249 |
"""Select the best answer from candidates and verify it."""
|
| 250 |
logger.debug("Selecting best answer from candidates...")
|
| 251 |
|
| 252 |
+
candidate_answers = state.get("candidate_answers", []) or [state.get("draft_answer", "")]
|
|
|
| 253 |
|
| 254 |
# Select the best answer from candidates
|
| 255 |
selection_result = self.verifier.select_best_answer(
|
|
|
|
| 277 |
f"**Selection Confidence:** {selection_result.get('confidence', 'N/A')}\n" + \
|
| 278 |
f"**Selection Reasoning:** {selection_reasoning}\n\n" + \
|
| 279 |
verification_report
|
| 280 |
+
|
| 281 |
+
feedback_for_research = verification_result.get("feedback")
|
| 282 |
+
|
| 283 |
return {
|
| 284 |
"draft_answer": best_answer,
|
| 285 |
"verification_report": verification_report,
|
| 286 |
+
"feedback_for_research": feedback_for_research,
|
| 287 |
+
"selection_reasoning": selection_reasoning,
|
| 288 |
+
"should_retry": verification_result.get("should_retry", False),
|
| 289 |
}
|
| 290 |
|
| 291 |
def _decide_next_step(self, state: AgentState) -> str:
|
| 292 |
"""Decide whether to re-research or end orchestrator."""
|
|
|
|
| 293 |
research_attempts = state.get("research_attempts", 1)
|
| 294 |
+
should_retry = bool(state.get("should_retry", False))
|
| 295 |
+
if should_retry and research_attempts < self.MAX_RESEARCH_ATTEMPTS:
|
|
|
| 296 |
return "re_research"
|
| 297 |
+
return "end"
|
| 298 |
+
|
| 299 |
+
def _research_step(self, state: AgentState) -> Dict[str, Any]:
|
| 300 |
+
"""Generate multiple answer candidates using the research agent."""
|
| 301 |
+
attempts = state.get("research_attempts", 0) + 1
|
| 302 |
+
feedback_for_research = state.get("feedback_for_research")
|
| 303 |
+
previous_answer = state.get("draft_answer") if feedback_for_research else None
|
| 304 |
+
logger.info(f"Research step (attempt {attempts}/{self.MAX_RESEARCH_ATTEMPTS})")
|
| 305 |
+
logger.info(f"Generating {self.NUM_RESEARCH_CANDIDATES} candidate answers...")
|
| 306 |
+
candidate_answers = []
|
| 307 |
+
for i in range(self.NUM_RESEARCH_CANDIDATES):
|
| 308 |
+
logger.info(f"Generating candidate {i + 1}/{self.NUM_RESEARCH_CANDIDATES}")
|
| 309 |
+
result = self.researcher.generate(
|
| 310 |
+
question=state["question"],
|
| 311 |
+
documents=state["documents"],
|
| 312 |
+
feedback=feedback_for_research,
|
| 313 |
+
previous_answer=previous_answer
|
| 314 |
+
)
|
| 315 |
+
candidate_answers.append(result["draft_answer"])
|
| 316 |
+
logger.info(f"Generated {len(candidate_answers)} candidate answers")
|
| 317 |
+
return {
|
| 318 |
+
"candidate_answers": candidate_answers,
|
| 319 |
+
"research_attempts": attempts,
|
| 320 |
+
"feedback": None
|
| 321 |
+
}
|
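The `_research_step` hunk above generates `NUM_RESEARCH_CANDIDATES` drafts, and `_verification_step` later selects one. The best-of-n pattern behind it can be sketched in isolation; the `generate` and `score` callables here are stand-ins, not the repo's actual agents:

```python
def best_of_n(question, generate, score, n=3):
    """Generate n candidate answers and return the highest-scoring one."""
    candidates = [generate(question) for _ in range(n)]
    return max(candidates, key=score)

answer = best_of_n(
    "q",
    generate=lambda q: f"draft about {q}",
    score=len,  # stand-in for the verifier's selection step
)
```

In the real workflow the scoring is an LLM-based selection (`select_best_answer`) rather than a simple key function, but the control flow is the same.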
main.py
CHANGED
|
@@ -17,7 +17,7 @@ from content_analyzer.document_parser import DocumentProcessor
|
|
| 17 |
from search_engine.indexer import RetrieverBuilder
|
| 18 |
from intelligence.orchestrator import AgentWorkflow
|
| 19 |
from configuration import definitions, parameters
|
| 20 |
-
|
| 21 |
|
| 22 |
# Example data for demo
|
| 23 |
EXAMPLES = {
|
|
@@ -127,9 +127,26 @@ def _find_open_port(start_port: int, max_attempts: int = 20) -> int:
|
|
| 127 |
raise RuntimeError(f"Could not find an open port starting at {start_port}")
|
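`_find_open_port` scans forward from a starting port until one can be bound. A minimal standalone sketch of that idea using only the standard library (not the app's exact implementation; note a returned port can still race with another process before the server claims it):

```python
import socket

def find_open_port(start_port: int, max_attempts: int = 20) -> int:
    """Return the first port >= start_port that can be bound on localhost."""
    for port in range(start_port, start_port + max_attempts):
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            try:
                s.bind(("127.0.0.1", port))
                return port  # bind succeeded, so the port is currently free
            except OSError:
                continue  # port in use; try the next one
    raise RuntimeError(f"Could not find an open port starting at {start_port}")

port = find_open_port(7860)
```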
| 128 |
|
| 129 |
|
|
|
| 130 |
def _setup_gradio_shim():
|
| 131 |
"""Shim Gradio's JSON schema conversion to tolerate boolean additionalProperties values."""
|
| 132 |
-
import gradio as gr
|
| 133 |
from gradio_client import utils as grc_utils
|
| 134 |
_orig_json_schema_to_python_type = grc_utils._json_schema_to_python_type
|
| 135 |
def _json_schema_to_python_type_safe(schema, defs=None):
|
|
@@ -140,7 +157,8 @@ def _setup_gradio_shim():
|
|
| 140 |
|
| 141 |
|
| 142 |
def main():
|
| 143 |
-
|
|
|
|
| 144 |
_setup_gradio_shim()
|
| 145 |
|
| 146 |
logger.info("=" * 60)
|
|
@@ -499,36 +517,9 @@ def main():
|
|
| 499 |
margin-bottom: 16px !important;
|
| 500 |
}
|
| 501 |
"""
|
| 502 |
-
js = """
|
| 503 |
-
|
| 504 |
-
|
| 505 |
-
container.id = 'gradio-animation';
|
| 506 |
-
container.style.fontSize = '2.4em';
|
| 507 |
-
container.style.fontWeight = '700';
|
| 508 |
-
container.style.textAlign = 'center';
|
| 509 |
-
container.style.marginBottom = '20px';
|
| 510 |
-
container.style.marginTop = '10px';
|
| 511 |
-
container.style.color = '#0369a1';
|
| 512 |
-
container.style.letterSpacing = '-0.02em';
|
| 513 |
-
var text = '📄 SmartDoc AI';
|
| 514 |
-
for (var i = 0; i < text.length; i++) {
|
| 515 |
-
(function(i){
|
| 516 |
-
setTimeout(function(){
|
| 517 |
-
var letter = document.createElement('span');
|
| 518 |
-
letter.style.opacity = '0';
|
| 519 |
-
letter.style.transition = 'opacity 0.2s ease';
|
| 520 |
-
letter.innerText = text[i];
|
| 521 |
-
container.appendChild(letter);
|
| 522 |
-
setTimeout(function() { letter.style.opacity = '1'; }, 50);
|
| 523 |
-
}, i * 80);
|
| 524 |
-
})(i);
|
| 525 |
-
}
|
| 526 |
-
var gradioContainer = document.querySelector('.gradio-container');
|
| 527 |
-
gradioContainer.insertBefore(container, gradioContainer.firstChild);
|
| 528 |
-
return 'Animation created';
|
| 529 |
-
}
|
| 530 |
-
(() => {
|
| 531 |
-
const upload_messages = [
|
| 532 |
"Crunching your documents...",
|
| 533 |
"Warming up the AI...",
|
| 534 |
"Extracting knowledge...",
|
|
@@ -541,99 +532,69 @@ def main():
|
|
| 541 |
"Almost ready..."
|
| 542 |
];
|
| 543 |
|
| 544 |
-
let
|
| 545 |
-
let
|
| 546 |
-
let startMs =
|
| 547 |
let lastMsg = null;
|
| 548 |
|
| 549 |
-
|
| 550 |
-
|
| 551 |
-
|
|
|
| 552 |
let m;
|
| 553 |
-
do {
|
| 554 |
-
|
| 555 |
-
} while (m === lastMsg);
|
| 556 |
lastMsg = m;
|
| 557 |
return m;
|
| 558 |
-
}
|
| 559 |
-
|
| 560 |
-
|
| 561 |
-
|
| 562 |
-
|
| 563 |
-
|
| 564 |
-
}
|
| 565 |
-
|
| 566 |
-
|
| 567 |
-
|
| 568 |
-
if (!root) return null;
|
| 569 |
-
return root.querySelector("#processing-timer");
|
| 570 |
-
}
|
| 571 |
-
|
| 572 |
-
function setMsg(text) {
|
| 573 |
-
const span = getMsgSpan();
|
| 574 |
-
if (!span) return;
|
| 575 |
-
span.textContent = text;
|
| 576 |
-
}
|
| 577 |
-
|
| 578 |
-
function formatElapsed(startMs) {
|
| 579 |
-
const s = (Date.now() - startMs) / 1000;
|
| 580 |
-
return `${s.toFixed(1)}s elapsed`;
|
| 581 |
-
}
|
| 582 |
-
|
| 583 |
-
function startRotationAndTimer() {
|
| 584 |
-
stopRotationAndTimer();
|
| 585 |
-
setMsg(pickMsg());
|
| 586 |
startMs = Date.now();
|
| 587 |
-
|
| 588 |
-
|
| 589 |
-
|
| 590 |
-
|
| 591 |
-
|
| 592 |
-
|
| 593 |
-
|
| 594 |
-
|
| 595 |
-
|
| 596 |
-
|
| 597 |
-
|
| 598 |
-
|
| 599 |
-
|
| 600 |
-
|
| 601 |
-
|
| 602 |
-
if (
|
| 603 |
-
|
| 604 |
-
|
| 605 |
-
|
| 606 |
-
const
|
| 607 |
-
if (
|
| 608 |
-
|
| 609 |
-
|
| 610 |
-
|
| 611 |
-
|
| 612 |
-
|
| 613 |
-
|
| 614 |
-
|
| 615 |
-
|
| 616 |
-
}
|
| 617 |
-
const isVisible = () => root.offsetParent !== null;
|
| 618 |
-
let prev = isVisible();
|
| 619 |
-
if (prev) startRotationAndTimer();
|
| 620 |
-
|
| 621 |
-
const obs = new MutationObserver(() => {
|
| 622 |
-
const now = isVisible();
|
| 623 |
-
if (now && !prev) startRotationAndTimer();
|
| 624 |
-
if (!now && prev) stopRotationAndTimer();
|
| 625 |
-
prev = now;
|
| 626 |
-
});
|
| 627 |
-
|
| 628 |
-
obs.observe(root, { attributes: true, attributeFilter: ["style", "class"] });
|
| 629 |
-
}
|
| 630 |
-
|
| 631 |
-
window.smartdocStartRotationAndTimer = startRotationAndTimer;
|
| 632 |
-
window.smartdocStopRotationAndTimer = stopRotationAndTimer;
|
| 633 |
-
|
| 634 |
-
watchProcessingBox();
|
| 635 |
})();
|
| 636 |
-
|
| 637 |
|
| 638 |
with gr.Blocks(theme=gr.themes.Soft(), title="SmartDoc AI", css=css, js=js) as demo:
|
| 639 |
gr.Markdown("### SmartDoc AI - Document Q&A", elem_classes="app-title")
|
|
@@ -668,26 +629,8 @@ def main():
|
|
| 668 |
"session_start": datetime.now().strftime("%Y-%m-%d %H:%M")
|
| 669 |
})
|
| 670 |
|
| 671 |
-
def process_question(question_text, uploaded_files, chat_history):
|
| 672 |
-
|
| 673 |
-
import random
|
| 674 |
-
chat_history = chat_history or []
|
| 675 |
-
upload_messages = [
|
| 676 |
-
"Crunching your documents...",
|
| 677 |
-
"Warming up the AI...",
|
| 678 |
-
"Extracting knowledge...",
|
| 679 |
-
"Scanning for insights...",
|
| 680 |
-
"Preparing your data...",
|
| 681 |
-
"Looking for answers...",
|
| 682 |
-
"Analyzing file structure...",
|
| 683 |
-
"Reading your files...",
|
| 684 |
-
"Indexing content...",
|
| 685 |
-
"Almost ready..."
|
| 686 |
-
]
|
| 687 |
-
last_msg = None
|
| 688 |
-
start_time = time.time()
|
| 689 |
-
msg = random.choice([m for m in upload_messages if m != last_msg])
|
| 690 |
-
last_msg = msg
|
| 691 |
yield (
|
| 692 |
chat_history,
|
| 693 |
gr.update(visible=False),
|
|
@@ -701,7 +644,6 @@ def main():
|
|
| 701 |
<span id="processing-timer" style="opacity:0.8; margin-left:8px;"></span>
|
| 702 |
</div>''', visible=True)
|
| 703 |
)
|
| 704 |
-
|
| 705 |
try:
|
| 706 |
if not question_text.strip():
|
| 707 |
chat_history.append({"role": "user", "content": question_text})
|
|
@@ -732,42 +674,32 @@ def main():
|
|
| 732 |
)
|
| 733 |
return
|
| 734 |
# Stage 2: Chunking with per-chunk progress and rotating status
|
| 735 |
-
|
| 736 |
-
|
| 737 |
-
total_chunks = 0
|
| 738 |
-
chunk_counts = []
|
| 739 |
-
for file in uploaded_files:
|
| 740 |
-
with open(file.name, 'rb') as f:
|
| 741 |
file_content = f.read()
|
| 742 |
-
|
| 743 |
cache_path = processor.cache_dir / f"{file_hash}.pkl"
|
| 744 |
if processor._is_cache_valid(cache_path):
|
| 745 |
chunks = processor._load_from_cache(cache_path)
|
| 746 |
-
if
|
| 747 |
-
chunks
|
| 748 |
-
|
| 749 |
-
|
| 750 |
-
|
| 751 |
-
|
| 752 |
-
|
|
|
| 753 |
total_chunks += len(chunks)
|
| 754 |
if total_chunks == 0:
|
| 755 |
total_chunks = 1
|
| 756 |
chunk_idx = 0
|
| 757 |
-
|
| 758 |
-
for file, file_chunk_count in zip(uploaded_files, chunk_counts):
|
| 759 |
-
with open(file.name, 'rb') as f:
|
| 760 |
-
file_content = f.read()
|
| 761 |
-
file_hash = processor._generate_hash(file_content)
|
| 762 |
-
cache_path = processor.cache_dir / f"{file_hash}.pkl"
|
| 763 |
-
if processor._is_cache_valid(cache_path):
|
| 764 |
-
chunks = processor._load_from_cache(cache_path)
|
| 765 |
-
if not chunks:
|
| 766 |
-
chunks = processor._process_file(file)
|
| 767 |
-
processor._save_to_cache(chunks, cache_path)
|
| 768 |
-
else:
|
| 769 |
-
chunks = processor._process_file(file)
|
| 770 |
-
processor._save_to_cache(chunks, cache_path)
|
| 771 |
for chunk in chunks:
|
| 772 |
chunk_hash = processor._generate_hash(chunk.page_content.encode())
|
| 773 |
if chunk_hash not in seen_hashes:
|
|
@@ -775,12 +707,7 @@ def main():
|
|
| 775 |
all_chunks.append(chunk)
|
| 776 |
# else: skip duplicate chunk
|
| 777 |
chunk_idx += 1
|
| 778 |
-
#
|
| 779 |
-
elapsed = time.time() - start_time
|
| 780 |
-
if chunk_idx == 1 or (elapsed // 10) > ((elapsed-1) // 10):
|
| 781 |
-
msg = random.choice([m for m in upload_messages if m != last_msg])
|
| 782 |
-
last_msg = msg
|
| 783 |
-
# When yielding progress, always do:
|
| 784 |
yield (
|
| 785 |
chat_history,
|
| 786 |
gr.update(visible=False),
|
|
@@ -794,8 +721,7 @@ def main():
|
|
| 794 |
<span id="processing-timer" style="opacity:0.8; margin-left:8px;"></span>
|
| 795 |
</div>''', visible=True)
|
| 796 |
)
|
| 797 |
-
# After all chunks, show 100%
|
| 798 |
-
elapsed = time.time() - start_time
|
| 799 |
yield (
|
| 800 |
chat_history,
|
| 801 |
gr.update(visible=False),
|
|
@@ -810,7 +736,6 @@ def main():
|
|
| 810 |
</div>''', visible=True)
|
| 811 |
)
|
| 812 |
# Stage 3: Building Retriever
|
| 813 |
-
elapsed = time.time() - start_time
|
| 814 |
yield (
|
| 815 |
chat_history,
|
| 816 |
gr.update(visible=False),
|
|
@@ -828,7 +753,6 @@ def main():
|
|
| 828 |
)
|
| 829 |
retriever = retriever_indexer.build_hybrid_retriever(all_chunks)
|
| 830 |
# Stage 4: Generating Answer
|
| 831 |
-
elapsed = time.time() - start_time
|
| 832 |
yield (
|
| 833 |
chat_history,
|
| 834 |
gr.update(visible=False),
|
|
@@ -842,10 +766,9 @@ def main():
|
|
| 842 |
<span id="processing-timer" style="opacity:0.8; margin-left:8px;"></span>
|
| 843 |
</div>''', visible=True)
|
| 844 |
)
|
| 845 |
-
result = orchestrator.
|
| 846 |
answer = result["draft_answer"]
|
| 847 |
# Stage 5: Verifying Answer
|
| 848 |
-
elapsed = time.time() - start_time
|
| 849 |
yield (
|
| 850 |
chat_history,
|
| 851 |
gr.update(visible=False),
|
|
@@ -864,10 +787,7 @@ def main():
|
|
| 864 |
# Do not display verification to user, only use internally
|
| 865 |
chat_history.append({"role": "user", "content": question_text})
|
| 866 |
chat_history.append({"role": "assistant", "content": f"**Answer:**\n{answer}"})
|
| 867 |
-
|
| 868 |
session_state.value["last_documents"] = retriever.invoke(question_text)
|
| 869 |
-
# Final: Show results and make context tab visible
|
| 870 |
-
total_elapsed = time.time() - start_time
|
| 871 |
yield (
|
| 872 |
chat_history,
|
| 873 |
gr.update(visible=True), # doc_context_display
|
|
@@ -880,9 +800,7 @@ def main():
|
|
| 880 |
<span id="processing-msg"></span>
|
| 881 |
<span id="processing-timer" style="opacity:0.8; margin-left:8px;"></span>
|
| 882 |
</div>''', visible=True)
|
| 883 |
-
)
|
| 884 |
-
|
| 885 |
-
time.sleep(1.5)
|
| 886 |
yield (
|
| 887 |
chat_history,
|
| 888 |
gr.update(visible=True),
|
|
@@ -954,10 +872,8 @@ def main():
|
|
| 954 |
file_info_text += f"{source_file_path} not found\n"
|
| 955 |
if not copied_files:
|
| 956 |
return [], "", "Could not load example files"
|
| 957 |
-
return copied_files, question_text, file_info_text
|
| 958 |
-
|
| 959 |
-
# Remove the Load Example button and related logic
|
| 960 |
-
# Instead, load the example immediately when dropdown changes
|
| 961 |
example_dropdown.change(
|
| 962 |
fn=load_example,
|
| 963 |
inputs=[example_dropdown],
|
|
@@ -967,6 +883,7 @@ def main():
|
|
| 967 |
# HF Spaces sets SPACE_ID environment variable
|
| 968 |
is_hf_space = os.environ.get("SPACE_ID") is not None
|
| 969 |
|
|
|
|
| 970 |
if is_hf_space:
|
| 971 |
# Hugging Face Spaces configuration
|
| 972 |
logger.info("Running on Hugging Face Spaces")
|
|
@@ -975,10 +892,8 @@ def main():
|
|
| 975 |
# Local development configuration
|
| 976 |
configured_port = int(os.environ.get("GRADIO_SERVER_PORT", "7860"))
|
| 977 |
server_port = _find_open_port(configured_port)
|
| 978 |
-
|
| 979 |
logger.info(f"Launching Gradio on port {server_port}")
|
| 980 |
logger.info(f"Access the app at: http://127.0.0.1:{server_port}")
|
| 981 |
-
|
| 982 |
demo.launch(server_name="127.0.0.1", server_port=server_port, share=False)
|
| 983 |
|
| 984 |
|
|
|
|
| 17 |
from search_engine.indexer import RetrieverBuilder
|
| 18 |
from intelligence.orchestrator import AgentWorkflow
|
| 19 |
from configuration import definitions, parameters
|
| 20 |
+
|
| 21 |
|
| 22 |
# Example data for demo
|
| 23 |
EXAMPLES = {
|
|
|
|
| 127 |
raise RuntimeError(f"Could not find an open port starting at {start_port}")
|
| 128 |
|
| 129 |
|
| 130 |
+
def _ensure_hfhub_hffolder_compat():
|
| 131 |
+
"""
|
| 132 |
+
Shim for Gradio <5.7.1 with huggingface_hub >=1.0.
|
| 133 |
+
"""
|
| 134 |
+
import huggingface_hub
|
| 135 |
+
if hasattr(huggingface_hub, "HfFolder"):
|
| 136 |
+
return
|
| 137 |
+
try:
|
| 138 |
+
from huggingface_hub.utils import get_token
|
| 139 |
+
except Exception:
|
| 140 |
+
return
|
| 141 |
+
class HfFolder:
|
| 142 |
+
@staticmethod
|
| 143 |
+
def get_token():
|
| 144 |
+
return get_token()
|
| 145 |
+
huggingface_hub.HfFolder = HfFolder
|
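The compat shim above restores the `HfFolder` attribute that `huggingface_hub >= 1.0` removed, so older Gradio code calling `HfFolder.get_token()` keeps working. The pattern can be exercised against a stand-in namespace (everything below is illustrative, not the repo's code):

```python
import types

# Stand-in mimicking huggingface_hub >= 1.0, where HfFolder was removed
hub = types.SimpleNamespace()

def get_token():
    # Hypothetical replacement API; returns None when no token is configured
    return None

# Same pattern as the shim: restore HfFolder only if it is missing
if not hasattr(hub, "HfFolder"):
    class HfFolder:
        @staticmethod
        def get_token():
            return get_token()
    hub.HfFolder = HfFolder

token = hub.HfFolder.get_token()  # old-style call now succeeds
```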
| 146 |
+
|
| 147 |
+
|
| 148 |
def _setup_gradio_shim():
|
| 149 |
"""Shim Gradio's JSON schema conversion to tolerate boolean additionalProperties values."""
|
|
|
|
| 150 |
from gradio_client import utils as grc_utils
|
| 151 |
_orig_json_schema_to_python_type = grc_utils._json_schema_to_python_type
|
| 152 |
def _json_schema_to_python_type_safe(schema, defs=None):
|
|
|
|
| 157 |
|
| 158 |
|
| 159 |
def main():
|
| 160 |
+
_ensure_hfhub_hffolder_compat() # must run before importing gradio
|
| 161 |
+
import gradio as gr
|
| 162 |
_setup_gradio_shim()
|
| 163 |
|
| 164 |
logger.info("=" * 60)
|
|
|
|
| 517 |
margin-bottom: 16px !important;
|
| 518 |
}
|
| 519 |
"""
|
| 520 |
+
js = r"""
|
| 521 |
+
(() => {
|
| 522 |
+
const uploadMessages = [
|
|
|
|
| 523 |
"Crunching your documents...",
|
| 524 |
"Warming up the AI...",
|
| 525 |
"Extracting knowledge...",
|
|
|
|
| 532 |
"Almost ready..."
|
| 533 |
];
|
| 534 |
|
| 535 |
+
let msgInterval = null;
|
| 536 |
+
let timerInterval = null;
|
| 537 |
+
let startMs = 0;
|
| 538 |
let lastMsg = null;
|
| 539 |
|
| 540 |
+
// Gradio re-renders may replace the element; pick the visible one if duplicates ever appear
|
| 541 |
+
const root = () => {
|
| 542 |
+
const all = Array.from(document.querySelectorAll("#processing-message"));
|
| 543 |
+
return all.find(el => el && (el.offsetWidth || el.offsetHeight || el.getClientRects().length)) || all[0] || null;
|
| 544 |
+
};
|
| 545 |
+
|
| 546 |
+
const isVisible = (el) => !!(el && (el.offsetWidth || el.offsetHeight || el.getClientRects().length));
|
| 547 |
+
|
| 548 |
+
const pickMsg = () => {
|
| 549 |
+
if (uploadMessages.length === 0) return "";
|
| 550 |
+
if (uploadMessages.length === 1) return uploadMessages[0];
|
| 551 |
let m;
|
| 552 |
+
do { m = uploadMessages[Math.floor(Math.random() * uploadMessages.length)]; }
|
| 553 |
+
while (m === lastMsg);
|
|
|
|
| 554 |
lastMsg = m;
|
| 555 |
return m;
|
| 556 |
+
};
|
| 557 |
+
|
| 558 |
+
const getMsgSpan = () => root()?.querySelector("#processing-msg");
|
| 559 |
+
const getTimerSpan = () => root()?.querySelector("#processing-timer");
|
| 560 |
+
|
| 561 |
+
const setMsg = (t) => { const s = getMsgSpan(); if (s) s.textContent = t; };
|
| 562 |
+
const fmtElapsed = () => `${((Date.now() - startMs) / 1000).toFixed(1)}s elapsed`;
|
| 563 |
+
|
| 564 |
+
const start = () => {
|
| 565 |
+
if (msgInterval || timerInterval) return;
|
|
|
|
| 566 |
startMs = Date.now();
|
| 567 |
+
setMsg(pickMsg());
|
| 568 |
+
|
| 569 |
+
msgInterval = setInterval(() => setMsg(pickMsg()), 2000);
|
| 570 |
+
|
| 571 |
+
const t = getTimerSpan();
|
| 572 |
+
if (t) {
|
| 573 |
+
t.textContent = fmtElapsed();
|
| 574 |
+
timerInterval = setInterval(() => { t.textContent = fmtElapsed(); }, 200);
|
| 575 |
+
}
|
| 576 |
+
};
|
| 577 |
+
|
| 578 |
+
const stop = () => {
|
| 579 |
+
if (msgInterval) { clearInterval(msgInterval); msgInterval = null; }
|
| 580 |
+
if (timerInterval) { clearInterval(timerInterval); timerInterval = null; }
|
| 581 |
+
const t = getTimerSpan();
|
| 582 |
+
if (t) t.textContent = "";
|
| 583 |
+
};
|
| 584 |
+
|
| 585 |
+
const tick = () => {
|
| 586 |
+
const r = root();
|
| 587 |
+
if (isVisible(r)) start();
|
| 588 |
+
else stop();
|
| 589 |
+
};
|
| 590 |
+
|
| 591 |
+
const obs = new MutationObserver(tick);
|
| 592 |
+
obs.observe(document.body, { subtree: true, childList: true, attributes: true });
|
| 593 |
+
|
| 594 |
+
window.addEventListener("load", tick);
|
| 595 |
+
setInterval(tick, 500);
|
|
|
| 596 |
})();
|
| 597 |
+
"""
|
| 598 |
|
| 599 |
with gr.Blocks(theme=gr.themes.Soft(), title="SmartDoc AI", css=css, js=js) as demo:
|
| 600 |
gr.Markdown("### SmartDoc AI - Document Q&A", elem_classes="app-title")
|
|
|
|
| 629 |
"session_start": datetime.now().strftime("%Y-%m-%d %H:%M")
|
| 630 |
})
|
| 631 |
|
| 632 |
+
def process_question(question_text, uploaded_files, chat_history):
|
| 633 |
+
chat_history = chat_history or []
|
|
|
|
| 634 |
yield (
|
| 635 |
chat_history,
|
| 636 |
gr.update(visible=False),
|
|
|
|
| 644 |
<span id="processing-timer" style="opacity:0.8; margin-left:8px;"></span>
|
| 645 |
</div>''', visible=True)
|
| 646 |
)
|
|
|
|
| 647 |
try:
|
| 648 |
if not question_text.strip():
|
| 649 |
chat_history.append({"role": "user", "content": question_text})
|
|
|
|
| 674 |
)
|
| 675 |
return
|
| 676 |
# Stage 2: Chunking with per-chunk progress and rotating status
|
| 677 |
+
def load_or_process(file):
|
| 678 |
+
with open(file.name, "rb") as f:
|
|
|
|
| 679 |
file_content = f.read()
|
| 680 |
+
file_hash = processor._generate_hash(file_content)
|
| 681 |
cache_path = processor.cache_dir / f"{file_hash}.pkl"
|
| 682 |
if processor._is_cache_valid(cache_path):
|
| 683 |
chunks = processor._load_from_cache(cache_path)
|
| 684 |
+
if chunks:
|
| 685 |
+
logger.info(f"Using cached chunks for {file.name}")
|
| 686 |
+
return chunks
|
| 687 |
+
chunks = processor._process_file(file)
|
| 688 |
+
processor._save_to_cache(chunks, cache_path)
|
| 689 |
+
return chunks
|
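`load_or_process` above keys the cache on a hash of the file bytes, so unchanged files skip reprocessing. The cache-by-content-hash pattern in miniature (helper names and cache layout are illustrative, not the repo's):

```python
import hashlib
import pickle
import tempfile
from pathlib import Path

cache_dir = Path(tempfile.mkdtemp())

def cached_process(data: bytes, process):
    """Process data, caching the result under a hash of its content."""
    key = hashlib.sha256(data).hexdigest()
    path = cache_dir / f"{key}.pkl"
    if path.exists():
        return pickle.loads(path.read_bytes())  # cache hit: skip processing
    result = process(data)
    path.write_bytes(pickle.dumps(result))
    return result

first = cached_process(b"doc", lambda d: [d.decode()])
# Same bytes -> same key -> the second callable is never invoked
second = cached_process(b"doc", lambda d: ["should not recompute"])
```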
| 690 |
+
|
| 691 |
+
all_chunks = []
|
| 692 |
+
seen_hashes = set()
|
| 693 |
+
chunks_by_file = []
|
| 694 |
+
total_chunks = 0
|
| 695 |
+
for file in uploaded_files:
|
| 696 |
+
chunks = load_or_process(file)
|
| 697 |
+
chunks_by_file.append(chunks)
|
| 698 |
total_chunks += len(chunks)
|
| 699 |
if total_chunks == 0:
|
| 700 |
total_chunks = 1
|
| 701 |
chunk_idx = 0
|
| 702 |
+
for chunks in chunks_by_file:
|
|
|
|
| 703 |
for chunk in chunks:
|
| 704 |
chunk_hash = processor._generate_hash(chunk.page_content.encode())
|
| 705 |
if chunk_hash not in seen_hashes:
|
|
|
|
| 707 |
all_chunks.append(chunk)
|
| 708 |
# else: skip duplicate chunk
|
| 709 |
chunk_idx += 1
|
| 710 |
+
# yield progress here if needed
|
|
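The duplicate-chunk filter above keys on a hash of each chunk's text and keeps only first occurrences. A self-contained sketch of the same pattern (names are illustrative):

```python
import hashlib

def dedupe_by_content(chunks):
    """Keep the first occurrence of each distinct chunk text."""
    seen, unique = set(), []
    for text in chunks:
        digest = hashlib.sha256(text.encode()).hexdigest()
        if digest not in seen:
            seen.add(digest)
            unique.append(text)
    return unique

unique = dedupe_by_content(["alpha", "beta", "alpha"])
```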
|
|
| 711 |
yield (
|
| 712 |
chat_history,
|
| 713 |
gr.update(visible=False),
|
|
|
|
| 721 |
<span id="processing-timer" style="opacity:0.8; margin-left:8px;"></span>
|
| 722 |
</div>''', visible=True)
|
| 723 |
)
|
| 724 |
+
# After all chunks, show 100%
|
|
|
|
| 725 |
yield (
|
| 726 |
chat_history,
|
| 727 |
gr.update(visible=False),
|
|
|
|
| 736 |
</div>''', visible=True)
|
| 737 |
)
|
| 738 |
# Stage 3: Building Retriever
|
|
|
|
| 739 |
yield (
|
| 740 |
chat_history,
|
| 741 |
gr.update(visible=False),
|
|
|
|
| 753 |
)
|
| 754 |
retriever = retriever_indexer.build_hybrid_retriever(all_chunks)
|
| 755 |
# Stage 4: Generating Answer
|
|
|
|
| 756 |
yield (
|
| 757 |
chat_history,
|
| 758 |
gr.update(visible=False),
|
|
|
|
| 766 |
<span id="processing-timer" style="opacity:0.8; margin-left:8px;"></span>
|
| 767 |
</div>''', visible=True)
|
| 768 |
)
|
| 769 |
+
result = orchestrator.run_workflow(question=question_text, retriever=retriever)
|
| 770 |
answer = result["draft_answer"]
|
| 771 |
# Stage 5: Verifying Answer
|
|
|
|
| 772 |
yield (
|
| 773 |
chat_history,
|
| 774 |
gr.update(visible=False),
|
|
|
|
| 787 |
# Do not display verification to user, only use internally
|
| 788 |
chat_history.append({"role": "user", "content": question_text})
|
| 789 |
chat_history.append({"role": "assistant", "content": f"**Answer:**\n{answer}"})
|
|
|
|
| 790 |
session_state.value["last_documents"] = retriever.invoke(question_text)
|
|
|
|
|
|
|
| 791 |
yield (
|
| 792 |
chat_history,
|
| 793 |
gr.update(visible=True), # doc_context_display
|
|
|
|
| 800 |
<span id="processing-msg"></span>
|
| 801 |
<span id="processing-timer" style="opacity:0.8; margin-left:8px;"></span>
|
| 802 |
</div>''', visible=True)
|
| 803 |
+
)
|
|
|
|
|
|
|
| 804 |
yield (
|
| 805 |
chat_history,
|
| 806 |
gr.update(visible=True),
|
|
|
|
| 872 |
file_info_text += f"{source_file_path} not found\n"
|
| 873 |
if not copied_files:
|
| 874 |
return [], "", "Could not load example files"
|
| 875 |
+
return copied_files, question_text, file_info_text
|
| 876 |
+
|
|
|
|
|
|
|
| 877 |
example_dropdown.change(
|
| 878 |
fn=load_example,
|
| 879 |
inputs=[example_dropdown],
|
|
|
|
| 883 |
# HF Spaces sets SPACE_ID environment variable
|
| 884 |
is_hf_space = os.environ.get("SPACE_ID") is not None
|
| 885 |
|
| 886 |
+
demo.queue()
|
| 887 |
if is_hf_space:
|
| 888 |
# Hugging Face Spaces configuration
|
| 889 |
logger.info("Running on Hugging Face Spaces")
|
|
|
|
| 892 |
# Local development configuration
|
| 893 |
configured_port = int(os.environ.get("GRADIO_SERVER_PORT", "7860"))
|
| 894 |
server_port = _find_open_port(configured_port)
|
|
|
|
| 895 |
logger.info(f"Launching Gradio on port {server_port}")
|
| 896 |
logger.info(f"Access the app at: http://127.0.0.1:{server_port}")
|
|
|
|
| 897 |
demo.launch(server_name="127.0.0.1", server_port=server_port, share=False)
|
| 898 |
|
| 899 |
|
packages.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
poppler-utils
|
requirements.txt
CHANGED
|
@@ -27,7 +27,7 @@ google-generativeai>=0.8.0
|
|
| 27 |
chromadb>=0.6.3
|
| 28 |
|
| 29 |
# Web framework
|
| 30 |
-
gradio>=5.
|
| 31 |
|
| 32 |
# Data processing
|
| 33 |
pandas>=2.1.4
|
|
|
|
| 27 |
chromadb>=0.6.3
|
| 28 |
|
| 29 |
# Web framework
|
| 30 |
+
gradio>=5.7.1
|
| 31 |
|
| 32 |
# Data processing
|
| 33 |
pandas>=2.1.4
|
search_engine/indexer.py
CHANGED
|
@@ -74,7 +74,7 @@ class EnsembleRetriever(BaseRetriever):
|
|
| 74 |
*,
|
| 75 |
run_manager: CallbackManagerForRetrieverRun = None
|
| 76 |
) -> List[Document]:
|
| 77 |
-
"""Retrieve and combine documents using weighted RRF, deduplicating charts by
|
| 78 |
logger.debug(f"[ENSEMBLE] Query: {query[:80]}...")
|
| 79 |
all_docs_with_scores = {}
|
| 80 |
retriever_names = ["BM25", "Vector"]
|
|
@@ -84,8 +84,8 @@ class EnsembleRetriever(BaseRetriever):
|
|
| 84 |
docs = retriever.invoke(query)
|
| 85 |
logger.debug(f"[ENSEMBLE] {retriever_name}: {len(docs)} docs (weight: {weight})")
|
| 86 |
for rank, doc in enumerate(docs):
|
| 87 |
-
# Deduplicate by
|
| 88 |
-
doc_key = (doc
|
| 89 |
rrf_score = weight / (rank + 1 + self.c)
|
| 90 |
if doc_key in all_docs_with_scores:
|
| 91 |
existing_doc, existing_score = all_docs_with_scores[doc_key]
|
|
|
|
| 74 |
*,
|
| 75 |
run_manager: CallbackManagerForRetrieverRun = None
|
| 76 |
) -> List[Document]:
|
| 77 |
+
"""Retrieve and combine documents using weighted RRF, deduplicating charts by doc_id and aggregating page numbers."""
|
| 78 |
logger.debug(f"[ENSEMBLE] Query: {query[:80]}...")
|
| 79 |
all_docs_with_scores = {}
|
| 80 |
retriever_names = ["BM25", "Vector"]
|
|
|
|
| 84 |
docs = retriever.invoke(query)
|
| 85 |
logger.debug(f"[ENSEMBLE] {retriever_name}: {len(docs)} docs (weight: {weight})")
|
| 86 |
for rank, doc in enumerate(docs):
|
| 87 |
+
# Deduplicate by doc_id only
|
| 88 |
+
doc_key = doc_id(doc)
|
| 89 |
rrf_score = weight / (rank + 1 + self.c)
|
| 90 |
if doc_key in all_docs_with_scores:
|
| 91 |
existing_doc, existing_score = all_docs_with_scores[doc_key]
|
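The `rrf_score = weight / (rank + 1 + self.c)` accumulation in the hunk above is weighted Reciprocal Rank Fusion. A minimal illustration, assuming each retriever returns a ranked list of doc ids (the function name and shape are illustrative, not the repo's API):

```python
def weighted_rrf(ranked_lists, weights, c=60):
    """Fuse ranked lists of doc ids with weighted Reciprocal Rank Fusion."""
    scores = {}
    for docs, weight in zip(ranked_lists, weights):
        for rank, doc_id in enumerate(docs):
            # Same shape as the hunk above: weight / (rank + 1 + c)
            scores[doc_id] = scores.get(doc_id, 0.0) + weight / (rank + 1 + c)
    # Highest fused score first
    return sorted(scores, key=scores.get, reverse=True)

# "b" wins: ranked first by the higher-weighted retriever
fused = weighted_rrf([["a", "b", "c"], ["b", "a"]], weights=[0.4, 0.6])
```

The constant `c` dampens the gap between adjacent ranks, so a document ranked highly by several retrievers beats one ranked first by only one.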
tests/conftest.py
DELETED
|
@@ -1,71 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Test fixtures and shared utilities for DocChat tests.
|
| 3 |
-
"""
|
| 4 |
-
import pytest
|
| 5 |
-
from unittest.mock import MagicMock
|
| 6 |
-
from langchain_core.documents import Document
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
class FakeLLM:
|
| 10 |
-
"""Mock LLM for testing without API calls."""
|
| 11 |
-
|
| 12 |
-
    def __init__(self, content: str = "Test response"):
        self.content = content
        self.last_prompt = None
        self.invoke_count = 0

    def invoke(self, prompt: str):
        self.last_prompt = prompt
        self.invoke_count += 1
        return type("Response", (), {"content": self.content})()


class FakeRetriever:
    """Mock retriever for testing without vector store."""

    def __init__(self, documents: list = None):
        self.documents = documents or []
        self.invoke_count = 0
        self.last_query = None

    def invoke(self, query: str):
        self.last_query = query
        self.invoke_count += 1
        return self.documents


@pytest.fixture
def sample_documents():
    """Create sample documents for testing."""
    return [
        Document(
            page_content="The data center in Singapore achieved a PUE of 1.12 in 2022.",
            metadata={"source": "test.pdf", "page": 1}
        ),
        Document(
            page_content="Carbon-free energy in Asia Pacific reached 45% in 2023.",
            metadata={"source": "test.pdf", "page": 2}
        ),
        Document(
            page_content="DeepSeek-R1 outperformed o1-mini on coding benchmarks.",
            metadata={"source": "deepseek.pdf", "page": 1}
        ),
    ]


@pytest.fixture
def fake_llm():
    """Create a fake LLM for testing."""
    return FakeLLM("This is a test response.")


@pytest.fixture
def fake_retriever(sample_documents):
    """Create a fake retriever with sample documents."""
    return FakeRetriever(sample_documents)


@pytest.fixture
def empty_retriever():
    """Create a fake retriever that returns no documents."""
    return FakeRetriever([])
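For context on how these deleted test doubles were wired together, here is a minimal, self-contained sketch of the FakeLLM/FakeRetriever pattern (mirroring the removed conftest.py definitions, with plain strings standing in for langchain Document objects):

```python
# Self-contained sketch of the deleted conftest.py test doubles.
# Plain strings stand in for langchain Document objects here.

class FakeLLM:
    """Records prompts and returns a canned response object."""

    def __init__(self, content="Test response"):
        self.content = content
        self.last_prompt = None
        self.invoke_count = 0

    def invoke(self, prompt):
        self.last_prompt = prompt
        self.invoke_count += 1
        # Anonymous class mimics an LLM response exposing a .content attribute
        return type("Response", (), {"content": self.content})()


class FakeRetriever:
    """Returns a fixed document list and counts queries."""

    def __init__(self, documents=None):
        self.documents = documents or []
        self.invoke_count = 0
        self.last_query = None

    def invoke(self, query):
        self.last_query = query
        self.invoke_count += 1
        return self.documents


retriever = FakeRetriever(["PUE in Singapore was 1.12 in 2022."])
llm = FakeLLM("The PUE was 1.12.")

docs = retriever.invoke("What is the PUE in Singapore?")
answer = llm.invoke(f"Context: {docs}\nQuestion: What is the PUE in Singapore?")

print(answer.content)          # The PUE was 1.12.
print(retriever.invoke_count)  # 1
```

The pattern lets downstream agent tests assert on `last_prompt` and `invoke_count` without any network calls, which is exactly how the removed test files below consume these fixtures.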
tests/test_accuracy_verifier.py
DELETED
@@ -1,110 +0,0 @@
"""
Tests for the VerificationAgent.
"""
import pytest
from unittest.mock import MagicMock, patch
from langchain_core.documents import Document

# Import after setting up mocks to avoid API key validation
import sys
sys.path.insert(0, '.')


class TestVerificationAgent:
    """Test suite for VerificationAgent."""

    @pytest.fixture
    def mock_parameters(self, monkeypatch):
        """Mock parameters to avoid API key requirement."""
        monkeypatch.setenv("GOOGLE_API_KEY", "test_key_for_testing")

    @pytest.fixture
    def accuracy_verifier(self, mock_parameters, fake_llm):
        """Create a VerificationAgent with mocked LLM."""
        from intelligence.accuracy_verifier import VerificationAgent
        return VerificationAgent(llm=fake_llm)

    def test_check_with_supported_answer(self, accuracy_verifier, sample_documents):
        """Test verification with an answer supported by documents."""
        # Configure the fake LLM to return a supported response
        accuracy_verifier.llm.content = """
        Supported: YES
        Unsupported Claims: []
        Contradictions: []
        Relevant: YES
        Additional Details: The answer is well-supported by the context.
        """

        result = accuracy_verifier.check(
            answer="The PUE in Singapore was 1.12 in 2022.",
            documents=sample_documents
        )

        assert "verification_report" in result
        assert "Supported: YES" in result["verification_report"]
        assert "context_used" in result

    def test_check_with_unsupported_answer(self, accuracy_verifier, sample_documents):
        """Test verification with an unsupported answer."""
        accuracy_verifier.llm.content = """
        Supported: NO
        Unsupported Claims: [The PUE was 1.5]
        Contradictions: []
        Relevant: YES
        Additional Details: The claimed PUE value is not in the context.
        """

        result = accuracy_verifier.check(
            answer="The PUE in Singapore was 1.5 in 2022.",
            documents=sample_documents
        )

        assert "Supported: NO" in result["verification_report"]

    def test_parse_verification_response_valid(self, accuracy_verifier):
        """Test parsing a valid verification response."""
        response = """
        Supported: YES
        Unsupported Claims: []
        Contradictions: []
        Relevant: YES
        Additional Details: All claims verified.
        """

        parsed = accuracy_verifier.parse_verification_response(response)

        assert parsed["Supported"] == "YES"
        assert parsed["Relevant"] == "YES"
        assert parsed["Unsupported Claims"] == []

    def test_parse_verification_response_with_claims(self, accuracy_verifier):
        """Test parsing response with unsupported claims."""
        response = """
        Supported: NO
        Unsupported Claims: [claim1, claim2]
        Contradictions: [contradiction1]
        Relevant: YES
        Additional Details: Multiple issues found.
        """

        parsed = accuracy_verifier.parse_verification_response(response)

        assert parsed["Supported"] == "NO"
        assert len(parsed["Unsupported Claims"]) == 2
        assert len(parsed["Contradictions"]) == 1

    def test_format_verification_report(self, accuracy_verifier):
        """Test formatting a verification report."""
        verification = {
            "Supported": "YES",
            "Unsupported Claims": [],
            "Contradictions": [],
            "Relevant": "YES",
            "Additional Details": "Well verified."
        }

        report = accuracy_verifier.format_verification_report(verification)

        assert "**Supported:** YES" in report
        assert "**Relevant:** YES" in report
        assert "**Unsupported Claims:** None" in report
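The parsing tests above pin down the verification response format. A minimal parser that would satisfy those assertions could look like the following sketch (hypothetical; the actual `intelligence/accuracy_verifier.py` implementation may differ):

```python
# Hypothetical sketch of a parser satisfying the deleted tests above;
# the real VerificationAgent.parse_verification_response may differ.

LIST_FIELDS = {"Unsupported Claims", "Contradictions"}

def parse_verification_response(response: str) -> dict:
    parsed = {}
    for line in response.strip().splitlines():
        key, sep, value = line.partition(":")
        if not sep:
            continue  # skip lines without a "Key: value" shape
        key, value = key.strip(), value.strip()
        if key in LIST_FIELDS:
            # "[a, b]" -> ["a", "b"]; "[]" -> []
            inner = value.strip("[]").strip()
            parsed[key] = [item.strip() for item in inner.split(",")] if inner else []
        else:
            parsed[key] = value
    return parsed

parsed = parse_verification_response("""
Supported: NO
Unsupported Claims: [claim1, claim2]
Contradictions: [contradiction1]
Relevant: YES
Additional Details: Multiple issues found.
""")
print(parsed["Supported"])                # NO
print(len(parsed["Unsupported Claims"]))  # 2
```

Treating the bracketed fields as lists and everything else as a string keeps the parser tolerant of extra free-text lines, which matters when the source of the response is an LLM.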
tests/test_context_validator.py
DELETED
@@ -1,120 +0,0 @@
"""
Tests for the RelevanceChecker.
"""
import pytest
from unittest.mock import MagicMock
from langchain_core.documents import Document

import sys
sys.path.insert(0, '.')


class TestRelevanceChecker:
    """Test suite for RelevanceChecker."""

    @pytest.fixture
    def mock_parameters(self, monkeypatch):
        """Mock parameters to avoid API key requirement."""
        monkeypatch.setenv("GOOGLE_API_KEY", "test_key_for_testing")

    @pytest.fixture
    def context_validator(self, mock_parameters, fake_llm):
        """Create a RelevanceChecker with mocked LLM."""
        from intelligence.context_validator import RelevanceChecker
        checker = RelevanceChecker()
        checker.llm = fake_llm
        return checker

    def test_check_can_answer(self, context_validator, fake_retriever):
        """Test when documents can fully answer the question."""
        context_validator.llm.content = "CAN_ANSWER"

        result = context_validator.check(
            question="What is the PUE in Singapore?",
            retriever=fake_retriever,
            k=3
        )

        assert result == "CAN_ANSWER"
        assert fake_retriever.invoke_count == 1

    def test_check_partial_match(self, context_validator, fake_retriever):
        """Test when documents partially match the question."""
        context_validator.llm.content = "PARTIAL"

        result = context_validator.check(
            question="What is the historical trend of PUE?",
            retriever=fake_retriever,
            k=3
        )

        assert result == "PARTIAL"

    def test_check_no_match(self, context_validator, fake_retriever):
        """Test when documents don't match the question."""
        context_validator.llm.content = "NO_MATCH"

        result = context_validator.check(
            question="What is the weather in Paris?",
            retriever=fake_retriever,
            k=3
        )

        assert result == "NO_MATCH"

    def test_check_empty_question(self, context_validator, fake_retriever):
        """Test with empty question returns NO_MATCH."""
        result = context_validator.check(
            question="",
            retriever=fake_retriever,
            k=3
        )

        assert result == "NO_MATCH"

    def test_check_empty_retriever_results(self, context_validator, empty_retriever):
        """Test when retriever returns no documents."""
        result = context_validator.check(
            question="Any question",
            retriever=empty_retriever,
            k=3
        )

        assert result == "NO_MATCH"

    def test_check_invalid_llm_response(self, context_validator, fake_retriever):
        """Test when LLM returns invalid response."""
        context_validator.llm.content = "INVALID_LABEL"

        result = context_validator.check(
            question="What is the PUE?",
            retriever=fake_retriever,
            k=3
        )

        assert result == "NO_MATCH"

    def test_check_retriever_exception(self, context_validator):
        """Test when retriever throws an exception."""
        failing_retriever = MagicMock()
        failing_retriever.invoke.side_effect = Exception("Connection error")

        result = context_validator.check(
            question="Any question",
            retriever=failing_retriever,
            k=3
        )

        assert result == "NO_MATCH"

    def test_check_invalid_k_value(self, context_validator, fake_retriever):
        """Test with invalid k value defaults to 3."""
        context_validator.llm.content = "CAN_ANSWER"

        result = context_validator.check(
            question="What is the PUE?",
            retriever=fake_retriever,
            k=-1
        )

        assert result == "CAN_ANSWER"
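Taken together, these tests encode a defensive contract: every failure mode (empty question, no retrieved documents, retriever exception, unrecognized label, invalid `k`) collapses to `NO_MATCH`. A sketch of a `check` function honoring that contract (hypothetical, not the actual `intelligence/context_validator.py` code) might look like:

```python
# Hypothetical sketch of the RelevanceChecker.check contract implied by the
# deleted tests; the real implementation may differ.

VALID_LABELS = {"CAN_ANSWER", "PARTIAL", "NO_MATCH"}

def check(question, retriever, llm, k=3):
    if not question or not question.strip():
        return "NO_MATCH"
    if k <= 0:
        k = 3  # invalid k defaults to 3, as test_check_invalid_k_value expects
    try:
        docs = retriever.invoke(question)
    except Exception:
        return "NO_MATCH"  # retriever failures degrade gracefully
    if not docs:
        return "NO_MATCH"
    context = "\n".join(str(d) for d in docs[:k])
    label = llm.invoke(f"Context:\n{context}\n\nQuestion: {question}").content.strip()
    # Anything the LLM returns outside the label vocabulary is treated as NO_MATCH
    return label if label in VALID_LABELS else "NO_MATCH"


class StubLLM:
    def __init__(self, content):
        self.content = content

    def invoke(self, prompt):
        return type("Resp", (), {"content": self.content})()


class StubRetriever:
    def __init__(self, docs):
        self.docs = docs

    def invoke(self, query):
        return self.docs


print(check("What is the PUE?", StubRetriever(["doc"]), StubLLM("CAN_ANSWER")))  # CAN_ANSWER
print(check("", StubRetriever(["doc"]), StubLLM("CAN_ANSWER")))                  # NO_MATCH
print(check("Q?", StubRetriever(["doc"]), StubLLM("INVALID_LABEL")))             # NO_MATCH
```

Funneling every error path to a single conservative label keeps the downstream orchestrator simple: it only ever has to branch on three values.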
tests/test_knowledge_synthesizer.py
DELETED
@@ -1,50 +0,0 @@
import unittest

try:
    from langchain_core.documents import Document
    from intelligence.knowledge_synthesizer import ResearchAgent
    LANGCHAIN_AVAILABLE = True
except ImportError:
    Document = None  # type: ignore
    ResearchAgent = None  # type: ignore
    LANGCHAIN_AVAILABLE = False


class FakeLLM:
    """Simple stand-in for ChatGoogleGenerativeAI to avoid network calls."""

    def __init__(self, content: str) -> None:
        self.content = content
        self.last_prompt = None

    def invoke(self, prompt: str):
        self.last_prompt = prompt
        return type("Resp", (), {"content": self.content})


@unittest.skipUnless(LANGCHAIN_AVAILABLE, "langchain not installed in this environment")
class ResearchAgentTests(unittest.TestCase):
    def test_generate_returns_stubbed_content_with_citations(self):
        docs = [
            Document(page_content="Alpha text", metadata={"id": "a1"}),
            Document(page_content="Beta text", metadata={"source": "s1"}),
        ]
        llm = FakeLLM("Answer about alpha")
        agent = ResearchAgent(llm=llm, top_k=1, max_context_chars=200)

        result = agent.generate("What is alpha?", docs)

        self.assertEqual(result["draft_answer"], "Answer about alpha")
        self.assertIn("Alpha text", llm.last_prompt)

    def test_generate_handles_no_documents(self):
        llm = FakeLLM("unused")
        agent = ResearchAgent(llm=llm)

        result = agent.generate("Any question", [])

        self.assertIn("could not find supporting documents", result["draft_answer"])


if __name__ == "__main__":
    unittest.main()
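The two tests above fix the ResearchAgent's observable behavior: on success the draft answer is the LLM output with the top-k document text injected into the prompt, and on empty input it returns a fixed fallback. A self-contained sketch consistent with that behavior (hypothetical; `Doc` is a minimal stand-in for `langchain_core.documents.Document`, and the real `intelligence/knowledge_synthesizer.py` may differ):

```python
# Hypothetical sketch of the ResearchAgent behavior the deleted tests encode.

class Doc:
    """Minimal stand-in for langchain_core.documents.Document."""

    def __init__(self, page_content, metadata=None):
        self.page_content = page_content
        self.metadata = metadata or {}


class ResearchAgent:
    def __init__(self, llm, top_k=3, max_context_chars=4000):
        self.llm = llm
        self.top_k = top_k
        self.max_context_chars = max_context_chars

    def generate(self, question, docs):
        if not docs:
            return {"draft_answer": "I could not find supporting documents for this question."}
        # Only the top-k documents, truncated to the context budget
        context = "\n\n".join(d.page_content for d in docs[: self.top_k])
        context = context[: self.max_context_chars]
        resp = self.llm.invoke(f"Context:\n{context}\n\nQuestion: {question}")
        return {"draft_answer": resp.content}


class StubLLM:
    def __init__(self, content):
        self.content = content
        self.last_prompt = None

    def invoke(self, prompt):
        self.last_prompt = prompt
        return type("Resp", (), {"content": self.content})()


llm = StubLLM("Answer about alpha")
agent = ResearchAgent(llm=llm, top_k=1, max_context_chars=200)
result = agent.generate("What is alpha?", [Doc("Alpha text"), Doc("Beta text")])
print(result["draft_answer"])           # Answer about alpha
print("Alpha text" in llm.last_prompt)  # True
```

With `top_k=1` only the first document reaches the prompt, which is exactly what `assertIn("Alpha text", llm.last_prompt)` verifies above.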
tests/test_visual_extraction.py
DELETED
@@ -1,169 +0,0 @@
"""
Test script for Gemini Vision chart extraction.

This script demonstrates how to use the chart extraction feature
and validates that it's working correctly.
"""
import logging
import os
import sys
from pathlib import Path

# Add parent directory to path
sys.path.insert(0, str(Path(__file__).parent.parent))

from content_analyzer.document_parser import DocumentProcessor
from configuration.parameters import parameters

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)


def test_chart_extraction():
    """Test chart extraction on a sample PDF with charts."""

    logger.info("=" * 60)
    logger.info("Testing Gemini Vision Chart Extraction")
    logger.info("=" * 60)

    # Check if chart extraction is enabled
    if not parameters.ENABLE_CHART_EXTRACTION:
        logger.warning("[WARN] Chart extraction is DISABLED")
        logger.info("Enable it by setting ENABLE_CHART_EXTRACTION=true in .env")
        return

    logger.info("[OK] Chart extraction enabled")
    logger.info(f"[INFO] Using model: {parameters.CHART_VISION_MODEL}")
    logger.info(f"[INFO] Max tokens: {parameters.CHART_MAX_TOKENS}")

    # Initialize processor
    try:
        processor = DocumentProcessor()
        logger.info("[OK] DocumentProcessor initialized")

        if processor.gemini_client:
            logger.info("[OK] Gemini Vision client ready")
        else:
            logger.error("[ERROR] Gemini Vision client not initialized")
            return

    except Exception as e:
        logger.error(f"[ERROR] Failed to initialize processor: {e}")
        return

    # Test with example PDF (if exists)
    test_files = [
        "examples/google-2024-environmental-report.pdf",
        "examples/deppseek.pdf",
        "test/sample_with_charts.pdf"
    ]

    found_file = None
    for test_file in test_files:
        if os.path.exists(test_file):
            found_file = test_file
            break

    if not found_file:
        logger.warning("[WARN] No test PDF files found")
        logger.info("Available test files:")
        for tf in test_files:
            logger.info(f"  - {tf}")
        logger.info("\nTo test manually:")
        logger.info("1. Place a PDF with charts in one of the above locations")
        logger.info("2. Run this script again")
        return

    logger.info(f"\n[INFO] Processing test file: {found_file}")

    # Create mock file object
    class MockFile:
        def __init__(self, path):
            self.name = path
            self.size = os.path.getsize(path)

    try:
        # Process the file
        mock_file = MockFile(found_file)
        chunks = processor.process([mock_file])

        logger.info("\n[OK] Processing complete!")
        logger.info(f"[INFO] Total chunks extracted: {len(chunks)}")

        # Count chart chunks
        chart_chunks = [c for c in chunks if c.metadata.get("type") == "chart"]
        text_chunks = [c for c in chunks if c.metadata.get("type") != "chart"]

        logger.info(f"[INFO] Chart chunks: {len(chart_chunks)}")
        logger.info(f"[INFO] Text chunks: {len(text_chunks)}")

        # Display chart analyses
        if chart_chunks:
            logger.info(f"\n{'=' * 60}")
            logger.info("CHART ANALYSES EXTRACTED:")
            logger.info('=' * 60)

            for i, chunk in enumerate(chart_chunks, 1):
                logger.info(f"\n--- Chart {i} ---")
                logger.info(f"Page: {chunk.metadata.get('page')}")
                logger.info(f"Preview: {chunk.page_content[:200]}...")
                logger.info("")
        else:
            logger.info("\n[INFO] No charts detected in this document")
            logger.info("This could mean:")
            logger.info("  - Document contains no charts")
            logger.info("  - Charts are embedded as tables (already extracted)")
            logger.info("  - Charts are too complex for detection")

        logger.info(f"\n{'=' * 60}")
        logger.info("[OK] Test completed successfully!")
        logger.info('=' * 60)

    except Exception as e:
        logger.error(f"[ERROR] Test failed: {e}", exc_info=True)


def test_api_connection():
    """Test Gemini API connection."""
    logger.info("\n" + "=" * 60)
    logger.info("Testing Gemini API Connection")
    logger.info("=" * 60)

    try:
        import google.generativeai as genai
        from PIL import Image
        import io

        genai.configure(api_key=parameters.GOOGLE_API_KEY)
        model = genai.GenerativeModel(parameters.CHART_VISION_MODEL)

        logger.info("[OK] Gemini client initialized")

        # Test with a simple text prompt
        response = model.generate_content("Hello! Can you respond with 'API Working'?")
        logger.info(f"[OK] API Response: {response.text}")

        logger.info("[OK] Gemini API connection successful!")

    except ImportError as e:
        logger.error(f"[ERROR] Missing dependency: {e}")
        logger.info("Install with: pip install google-generativeai Pillow")
    except Exception as e:
        logger.error(f"[ERROR] API test failed: {e}")


if __name__ == "__main__":
    print("\nSmartDoc AI - Chart Extraction Test Suite\n")

    # Test 1: API Connection
    test_api_connection()

    # Test 2: Chart Extraction
    test_chart_extraction()

    print("\nAll tests completed!\n")
vector_store/33eccd62-a7fc-4b0d-a118-02552f5cad42/data_level0.bin
DELETED
@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c8fe3c8d74ae8a7762e6f389543f0f2c53e6127832955b377ed768f8759db70d
size 16165996
vector_store/33eccd62-a7fc-4b0d-a118-02552f5cad42/header.bin
DELETED
@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:059abd7ab166731c13bd8dc4dc0724104918b450e9625ca4bc9f27ed0016170e
size 100
vector_store/33eccd62-a7fc-4b0d-a118-02552f5cad42/index_metadata.pickle
DELETED
@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:bc43535869cc54fbd80a6a47dac2fd0b07f4eeb0c028b5c96026b6cdc271832b
size 463184
vector_store/33eccd62-a7fc-4b0d-a118-02552f5cad42/length.bin
DELETED
@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:fa6bfa281c8fe4e4977d5382b077dee4a3c4e5c750985cdf3d3660a6f92dab67
size 20132
vector_store/33eccd62-a7fc-4b0d-a118-02552f5cad42/link_lists.bin
DELETED
@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ffcd2c7be0de4c70919af69080b33cbd5c7487471058b2a70ee5bf95ab86ea00
size 42436