Hammad712 commited on
Commit
d2654d6
·
0 Parent(s):

Added ingestion code

Browse files
.dockerignore ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ __pycache__
2
+ venv/
3
+ .env
4
+ .git
5
+ .gitignore
6
+ *.pyc
7
+ *.pyo
8
+ *.pyd
9
+ .DS_Store
.env.example ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ # Example env vars. Do NOT commit real secrets.
2
+ VOYAGE_API_KEY=
3
+ QDRANT_API_KEY=
4
+ QDRANT_URL=
5
+ GOOGLE_API_KEY=
6
+ DRY_RUN=1
.gitignore ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --- Python ---
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # --- Virtual Environments ---
7
+ # Common names for virtual envs
8
+ venv/
9
+ env/
10
+ .env/
11
+ .venv/
12
+
13
+ # --- Environment Variables (CRITICAL) ---
14
+ # Never commit secrets or api keys
15
+ .env
16
+ .env.local
17
+ .env.development.local
18
+ .env.test.local
19
+ .env.production.local
20
+
21
+ # --- Distribution / Packaging ---
22
+ dist/
23
+ build/
24
+ *.egg-info/
25
+
26
+ # --- Testing & Coverage ---
27
+ .pytest_cache/
28
+ .coverage
29
+ htmlcov/
30
+ coverage.xml
31
+
32
+ # --- Jupyter Notebooks (if applicable) ---
33
+ .ipynb_checkpoints
34
+
35
+ # --- IDE / Editors ---
36
+ .vscode/
37
+ .idea/
38
+ *.swp
39
+
40
+ # --- Databases ---
41
+ # Ignore local SQLite databases so you don't overwrite prod data or commit binary blobs
42
+ *.sqlite3
43
+ *.db
44
+
45
+ # --- Logs ---
46
+ *.log
47
+ logs/
Dockerfile ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Use an official lightweight Python image
2
+ FROM python:3.10-slim
3
+
4
+ # Set environment variables
5
+ # PYTHONDONTWRITEBYTECODE: Prevents Python from writing pyc files to disc
6
+ # PYTHONUNBUFFERED: Ensures logs are flushed immediately
7
+ ENV PYTHONDONTWRITEBYTECODE=1 \
8
+ PYTHONUNBUFFERED=1 \
9
+ # Point Hugging Face cache to a writable directory
10
+ HF_HOME=/app/.cache/huggingface
11
+
12
+ # Set the working directory
13
+ WORKDIR /app
14
+
15
+ # Create a non-root user with a specific UID (1000) for security & HF compatibility
16
+ # and give them ownership of the /app directory
17
+ RUN useradd -m -u 1000 user && \
18
+ chown -R user:user /app
19
+
20
+ # Switch to the non-root user
21
+ USER user
22
+
23
+ # Set up the PATH to include the user's local bin (where pip installs tools)
24
+ ENV PATH="/home/user/.local/bin:$PATH"
25
+
26
+ # Copy the requirements file first to leverage Docker cache
27
+ COPY --chown=user:user requirements.txt .
28
+
29
+ # Install dependencies
30
+ RUN pip install --no-cache-dir --upgrade pip && \
31
+ pip install --no-cache-dir -r requirements.txt
32
+
33
+ # Copy the rest of the application code
34
+ COPY --chown=user:user . .
35
+
36
+ # Expose port 7860 (Required for Hugging Face Spaces)
37
+ EXPOSE 7860
38
+
39
+ # Command to run the application
40
+ # Note: Ensure your main file is named 'main.py' and the app instance is 'app'
41
+ CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "7860"]
README.md ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ PDF Extraction FastAPI
2
+ ======================
3
+
4
+ This repository provides a FastAPI wrapper around a PDF visual/text extraction pipeline.
5
+
6
+ Structure
7
+ - `app/` — application package
8
+ - `main.py` — FastAPI app and endpoints
9
+ - `services/pipeline_service.py` — pipeline logic (model calls + Qdrant ingestion)
10
+ - `services/qdrant_service.py` — markdown chunking + Qdrant ingest
11
+ - `utils.py` — helpers
12
+
13
+ Quickstart
14
+ 1. Install system dependency for `pdf2image` (Debian/Ubuntu):
15
+
16
+ ```bash
17
+ sudo apt-get update && sudo apt-get install -y poppler-utils
18
+ ```
19
+
20
+ 2. Create virtualenv and install Python deps:
21
+
22
+ ```bash
23
+ python3 -m venv .venv
24
+ source .venv/bin/activate
25
+ pip install -r requirements.txt
26
+ ```
27
+
28
+ 3. Run the app:
29
+
30
+ ```bash
31
+ uvicorn app.main:app --reload --host 0.0.0.0 --port 8000
32
+ ```
33
+
34
+ 4. Use the `/process/pdf/stream` endpoint to upload a PDF (Swagger UI at `/docs`); progress is streamed back as server-sent events.
35
+
36
+ Notes
37
+ - The pipeline performs real model calls and Qdrant ingestion when API keys are configured (see `.env.example`); without keys the clients are left uninitialised and calls fail fast at startup.
38
+ - To ingest into Qdrant, provide your credentials; chunking and ingestion live in `app/services/qdrant_service.py`.
app/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ """App package for PDF extraction service."""
2
+
3
+ __all__ = ["main", "pipeline", "qdrant_ingest", "schemas", "utils"]
app/core/config.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Dict
3
+
4
+ from dotenv import load_dotenv
5
+ load_dotenv()
6
+
7
+
8
def load_config() -> Dict[str, str]:
    """Snapshot the service configuration from environment variables.

    Missing secrets/URLs fall back to empty strings; the collection name
    and batch size fall back to their documented defaults, so callers can
    rely on every key being present in the returned mapping.
    """
    env = os.environ.get
    return {
        "GOOGLE_API_KEY": env("GOOGLE_API_KEY", ""),
        "VOYAGE_API_KEY": env("VOYAGE_API_KEY", ""),
        "QDRANT_URL": env("QDRANT_URL", ""),
        "QDRANT_API_KEY": env("QDRANT_API_KEY", ""),
        "QDRANT_COLLECTION": env("QDRANT_COLLECTION", "mercurygse"),
        "QDRANT_BATCH_SIZE": env("QDRANT_BATCH_SIZE", "256"),
    }
app/main.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI
2
+ from .routes import router as api_router
3
+ from .services import model_client
4
+ from .core import config as core_config
5
+ import os
6
+ import logging
7
+
8
+ logger = logging.getLogger("pdf_extraction")
9
+ if not logger.handlers:
10
+ # simple default handler
11
+ h = logging.StreamHandler()
12
+ fmt = logging.Formatter("%(asctime)s %(levelname)s %(name)s: %(message)s")
13
+ h.setFormatter(fmt)
14
+ logger.addHandler(h)
15
+ logger.setLevel(logging.INFO)
16
+
17
+ app = FastAPI(title="PDF Extraction Service")
18
+
19
+
20
@app.on_event("startup")
def startup_event():
    """Eagerly initialise the external service clients at application boot.

    Only the *presence* of each credential is logged — secret values are
    never written to the log stream.
    """
    cfg = core_config.load_config()
    logger.info("GOOGLE_API_KEY set: %s", bool(cfg.get("GOOGLE_API_KEY")))
    logger.info("VOYAGE_API_KEY set: %s", bool(cfg.get("VOYAGE_API_KEY")))
    logger.info("QDRANT_URL set: %s", bool(cfg.get("QDRANT_URL")))
    logger.info("QDRANT_API_KEY set: %s", bool(cfg.get("QDRANT_API_KEY")))

    gen_client = model_client.init_genai_client(cfg.get("GOOGLE_API_KEY"))
    if not gen_client:
        logger.warning("GenAI client not initialized - missing key or import failure")
    else:
        logger.info("GenAI client initialized successfully")

    embedder = model_client.init_embeddings(cfg.get("VOYAGE_API_KEY"))
    if not embedder:
        logger.warning("Embeddings client not initialized - missing key or import failure")
    else:
        logger.info("Embeddings client initialized successfully")

    vector_db = model_client.init_qdrant_client(cfg.get("QDRANT_URL"), cfg.get("QDRANT_API_KEY"))
    if not vector_db:
        logger.warning("Qdrant client not initialized - missing URL/API key or import failure")
    else:
        logger.info("Qdrant client initialized successfully")
47
+
48
+
49
+ app.include_router(api_router)
50
+
51
+
52
@app.get("/", tags=["root"])
def read_root():
    """Landing endpoint confirming the service is up."""
    greeting = {"message": "Welcome to the PDF Extraction Service"}
    return greeting
55
+
56
+
app/routes/__init__.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ from fastapi import APIRouter
2
+
3
+ router = APIRouter()
4
+
5
+ from . import process, health # noqa: E402,F401
6
+
7
+ router.include_router(process.router, prefix="/process", tags=["process"])
8
+ router.include_router(health.router, prefix="/health", tags=["health"])
app/routes/health.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import APIRouter
2
+
3
+ router = APIRouter()
4
+
5
+
6
@router.get("/live")
async def live():
    """Liveness probe: the process is running and serving requests."""
    payload = {"status": "ok"}
    return payload
9
+
10
+
11
@router.get("/ready")
async def ready():
    """Readiness probe: the service reports itself able to take traffic."""
    payload = {"status": "ready"}
    return payload
app/routes/ingest.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import APIRouter, HTTPException
2
+ import os
3
+ from typing import Optional
4
+ from ..services.qdrant_service import chunk_markdown_by_page, ingest_chunks_into_qdrant
5
+
6
+ router = APIRouter()
7
+
8
+
9
@router.post("/")
async def ingest(report_path: str, collection: Optional[str] = "mercurygse"):
    """Chunk an on-disk markdown report and push the chunks into Qdrant.

    NOTE(review): ``report_path`` is taken verbatim from the client, so any
    server-readable file path can be probed/ingested — confirm this endpoint
    is not exposed to untrusted callers.
    """
    if not os.path.exists(report_path):
        raise HTTPException(status_code=404, detail='Report not found')
    page_chunks = chunk_markdown_by_page(report_path)
    outcome = ingest_chunks_into_qdrant(page_chunks, collection_name=collection)
    return {"chunks": len(page_chunks), "ingest_result": outcome}
app/routes/process.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import APIRouter, UploadFile, File, HTTPException
2
+ from fastapi.responses import FileResponse, StreamingResponse
3
+ import os
4
+ import json
5
+ import uuid
6
+ import queue
7
+ import threading
8
+ from typing import Optional
9
+ from ..utils import save_upload_file_tmp
10
+ from ..services.pipeline_service import run_pipeline
11
+ import logging
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+ router = APIRouter()
16
@router.post("/pdf/stream")
async def process_pdf_stream(file: UploadFile = File(...), max_pages: Optional[int] = None):
    """Run the extraction pipeline on an uploaded PDF, streaming progress as SSE.

    The upload is saved to a temp file, the pipeline runs on a daemon thread,
    and progress events are forwarded to the client as `text/event-stream`
    lines (`data: {...}\\n\\n`). The temp file is removed when the stream ends.

    Raises:
        HTTPException(400): when the upload is missing a filename or is not a PDF.
    """
    # Guard against a missing filename as well as a wrong extension:
    # UploadFile.filename may be None, and None.lower() would raise.
    if not file.filename or not file.filename.lower().endswith('.pdf'):
        raise HTTPException(status_code=400, detail='Only PDF uploads are supported')
    tmp_path, filename = save_upload_file_tmp(file)

    q = queue.Queue()
    job_id = str(uuid.uuid4())
    logger.info("Received upload %s -> %s; job=%s", file.filename, tmp_path, job_id)

    def progress_hook(ev: dict):
        # Tag every pipeline event with the job id before queueing it.
        ev_out = {"job_id": job_id, **ev}
        q.put(ev_out)

    def worker():
        # Run the pipeline off the event loop; report completion/failure via the queue.
        try:
            run_pipeline(tmp_path, max_pages=max_pages, progress_hook=progress_hook, doc_id=job_id, original_filename=filename)
            q.put({"job_id": job_id, "event": "worker_done"})
        except Exception as e:
            q.put({"job_id": job_id, "event": "error", "error": str(e)})

    thread = threading.Thread(target=worker, daemon=True)
    thread.start()

    def event_generator():
        try:
            while True:
                try:
                    ev = q.get(timeout=0.5)
                except queue.Empty:
                    # FIX: was `except Exception`, which would also hide real
                    # errors; only a timed-out poll should fall through here.
                    if not thread.is_alive():
                        break
                    continue
                # SSE format
                s = f"data: {json.dumps(ev)}\n\n"
                yield s.encode('utf-8')
            # drain any remaining events
            while not q.empty():
                ev = q.get()
                s = f"data: {json.dumps(ev)}\n\n"
                yield s.encode('utf-8')
        finally:
            # Best-effort cleanup of the temp upload once the stream closes.
            try:
                if os.path.exists(tmp_path):
                    os.remove(tmp_path)
            except OSError:
                pass

    return StreamingResponse(event_generator(), media_type='text/event-stream')
65
+
66
+
67
@router.get("/report")
async def download_report(path: str):
    """Serve a previously generated markdown report from disk.

    NOTE(review): ``path`` comes straight from the client — confirm exposure
    before deploying to untrusted networks.
    """
    if os.path.exists(path):
        return FileResponse(path, media_type='text/markdown', filename=os.path.basename(path))
    raise HTTPException(status_code=404, detail='Report not found')
app/schemas/models.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel, Field
2
+ from typing import List, Optional
3
+
4
+
5
class RouterOutput(BaseModel):
    """Router-agent verdict for one page image (shape mirrors ROUTER_PROMPT)."""
    route: str  # "complex" | "simple"
    contains_visual: bool
    visual_types: List[str]  # e.g. "map", "chart", "diagram", "complex_table", ...
    reason: str = Field(..., min_length=8)  # short plain-English justification
    confidence: float = Field(..., ge=0.0, le=1.0)
11
+
12
+
13
class KeyComponent(BaseModel):
    """A single labelled element extracted from a diagram page."""
    name: str  # label text, or "[ILLEGIBLE]"
    description: str  # verbatim descriptor or short spatial hint
    extraction_confidence: Optional[float] = Field(None, ge=0.0, le=1.0)
17
+
18
+
19
class DiagramExtraction(BaseModel):
    """Structured output for a 'complex' page (shape mirrors COMPLEX_PROMPT)."""
    schema_id: str = Field("diagram_v1")  # schema version tag
    pdf_page: int  # 1-based page index; overwritten by the pipeline
    printed_page: Optional[str]  # page number printed on the page, if legible
    title: str
    category: str  # normalised against ALLOWED_COMPLEX_CATEGORIES downstream
    summary: str  # two-sentence factual summary
    key_components: List[KeyComponent] = Field(default_factory=list)
    relationships: str  # visible relationships, or "[NONE]"
    raw_text: str  # remaining visible text, verbatim
    extraction_confidence: float = Field(..., ge=0.0, le=1.0)
30
+
31
+
32
class SimpleExtraction(BaseModel):
    """Structured output for a text-dominant page (shape mirrors SIMPLE_PROMPT)."""
    schema_id: str = Field("simple_v1")  # schema version tag
    pdf_page: int  # 1-based page index; overwritten by the pipeline
    printed_page: Optional[str]  # page number printed on the page, if legible
    topic: str
    summary: str  # two-sentence summary strictly from visible text
    content_markdown: str  # full page transcription as Markdown
    important_dates_or_entities: List[str] = Field(default_factory=list)  # exact strings seen
    extraction_confidence: float = Field(..., ge=0.0, le=1.0)
app/services/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ """Service package initialiser.
2
+
3
+ Avoid importing submodules at package import time to prevent circular
4
+ imports (modules should import specific submodules directly where needed).
5
+ """
6
+
7
+ __all__ = ["model_client", "qdrant_service", "pipeline_service"]
app/services/model_client.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Model client and embedding factories.
3
+ Read API keys from environment variables.
4
+ """
5
+ import os
6
+ from typing import Optional
7
+ from ..core import config as core_config
8
+
9
+ genai_client = None
10
+ embeddings = None
11
+ qdrant_client = None
12
+
13
+
14
def init_genai_client(api_key: Optional[str] = None):
    """Create (or clear) the module-level Google GenAI client.

    Falls back to the configured GOOGLE_API_KEY when no key is passed.
    Returns the client, or None when the key is missing or the SDK
    cannot be imported.
    """
    global genai_client
    try:
        from google import genai
        key = api_key
        if key is None:
            key = core_config.load_config().get("GOOGLE_API_KEY")
        genai_client = genai.Client(api_key=key) if key else None
    except Exception:
        # Optional dependency / config failure: leave the client unset.
        genai_client = None
    return genai_client
25
+
26
+
27
def init_embeddings(voyage_api_key: Optional[str] = None):
    """Create the module-level VoyageAI embeddings client.

    Resolves the API key from the argument, falling back to the configured
    VOYAGE_API_KEY. Returns the embeddings object, or None when no key is
    available or the dependency cannot be imported.
    """
    global embeddings
    try:
        from langchain_voyageai import VoyageAIEmbeddings
        if voyage_api_key is None:
            cfg = core_config.load_config()
            voyage_api_key = cfg.get("VOYAGE_API_KEY")
        if voyage_api_key:
            # FIX: use direct assignment instead of os.environ.setdefault so an
            # explicitly supplied key always wins over a stale value already
            # present in the environment (setdefault silently kept the old one).
            os.environ["VOYAGE_API_KEY"] = voyage_api_key
            embeddings = VoyageAIEmbeddings(model="voyage-3-large")
            return embeddings
    except Exception:
        # Optional dependency / config failure: treat as "not configured".
        pass
    return None
41
+
42
+
43
def init_qdrant_client(url: Optional[str] = None, api_key: Optional[str] = None):
    """Create the module-level Qdrant client.

    Missing url/api_key values fall back to the loaded configuration.
    Returns the client on success, otherwise None (no URL, or the
    qdrant-client package is unavailable).
    """
    global qdrant_client
    try:
        from qdrant_client import QdrantClient
        if url is None or api_key is None:
            cfg = core_config.load_config()
            url = url if url is not None else cfg.get("QDRANT_URL")
            api_key = api_key if api_key is not None else cfg.get("QDRANT_API_KEY")
        if url:
            qdrant_client = QdrantClient(url=url, api_key=api_key, prefer_grpc=False)
            return qdrant_client
    except Exception:
        qdrant_client = None
    return None
59
+
60
+
61
class ModelClient:
    """Simple wrapper that exposes current clients as properties.

    The module keeps module-level references (genai_client, embeddings, qdrant_client)
    and this wrapper exposes them dynamically so other modules can import
    `model_client` and access attributes like `model_client.genai_client`.
    """

    @property
    def genai_client(self):
        # Re-read the module global on every access so late initialisation
        # (e.g. at FastAPI startup) is visible to earlier importers.
        return genai_client

    @property
    def embeddings(self):
        # Module global; None until init_embeddings() succeeds.
        return embeddings

    @property
    def qdrant_client(self):
        # Module global; None until init_qdrant_client() succeeds.
        return qdrant_client

    def init_all(self):
        """Initialise all three clients from environment-derived config."""
        init_genai_client()
        init_embeddings()
        init_qdrant_client()


# Shared singleton used throughout the app (e.g. `model_client.genai_client`).
model_client = ModelClient()
app/services/pipeline_service.py ADDED
@@ -0,0 +1,756 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Full pipeline service adapted from the user's original script.
3
+
4
+ This module expects API clients to be available via `app.services.model_client`.
5
+ Per project configuration, this pipeline performs real model calls and Qdrant ingestion.
6
+ """
7
+ import os
8
+ import time
9
+ import random
10
+ import re
11
+ import gc
12
+ import queue
13
+ import threading
14
+ from typing import List, Optional, Callable, Any, Dict, Tuple
15
+ from concurrent.futures import ThreadPoolExecutor, as_completed
16
+ from threading import BoundedSemaphore, Event, Lock
17
+
18
+ from pdf2image import convert_from_path, pdfinfo_from_path
19
+ from pydantic import BaseModel, Field
20
+ from tqdm import tqdm
21
+
22
+ from ..schemas import models as schema_models
23
+ from . import model_client
24
+ from google import genai as _genai_module
25
+ from google.genai import types as genai_types
26
+ from .. import utils as app_utils
27
+ from . import qdrant_service
28
+ import tempfile
29
+ import logging
30
+
31
+ # note: don't capture clients at import time; use the factory `model_client` to get live instances
32
+
33
+ # ---------------------------
34
+ # Configuration (tuned for faster processing)
35
+ # ---------------------------
36
+ ROUTER_WORKERS = int(os.environ.get("ROUTER_WORKERS", 16))
37
+ SIMPLE_WORKERS = int(os.environ.get("SIMPLE_WORKERS", 12))
38
+ COMPLEX_WORKERS = int(os.environ.get("COMPLEX_WORKERS", 6))
39
+
40
+ FLASH_CONCURRENCY = SIMPLE_WORKERS
41
+ PRO_CONCURRENCY = COMPLEX_WORKERS
42
+
43
+ FLASH_MIN_INTERVAL = float(os.environ.get("FLASH_MIN_INTERVAL", 0.05))
44
+ PRO_MIN_INTERVAL = float(os.environ.get("PRO_MIN_INTERVAL", 0.20))
45
+
46
+ RETRY_ATTEMPTS = int(os.environ.get("RETRY_ATTEMPTS", 3))
47
+ # Circuit breaker tuning (env override)
48
+ CIRCUIT_THRESHOLD = int(os.environ.get("CIRCUIT_THRESHOLD", 8))
49
+ CIRCUIT_WINDOW = float(os.environ.get("CIRCUIT_WINDOW", 60.0))
50
+
51
+ # logger for this module
52
+ logger = logging.getLogger("pdf_extraction.pipeline")
53
+ if not logger.handlers:
54
+ ch = logging.StreamHandler()
55
+ ch.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(name)s: %(message)s"))
56
+ logger.addHandler(ch)
57
+ logger.setLevel(logging.INFO)
58
+
59
+ # Token-bucket rate limiter settings (tunable via env to match Colab expectations)
60
+ FLASH_RATE = float(os.environ.get("FLASH_RATE", 4.0)) # calls per second for flash family
61
+ PRO_RATE = float(os.environ.get("PRO_RATE", 1.0)) # calls per second for pro family
62
+
63
+
64
class TokenBucket:
    """Simple thread-safe token bucket limiter.

    - rate: tokens added per second
    - capacity: maximum tokens (defaults to `rate`, i.e. ~1 second of burst)
    """
    def __init__(self, rate: float, capacity: Optional[float] = None):
        self.rate = float(rate)
        self.capacity = float(capacity or rate)
        self._tokens = self.capacity  # start full so an initial burst is allowed
        self._last = time.time()
        self._lock = threading.Lock()

    def _refill(self, now: float) -> None:
        # Credit tokens for the time elapsed since the last refill, capped
        # at capacity. Caller must hold self._lock.
        self._tokens = min(self.capacity, self._tokens + (now - self._last) * self.rate)
        self._last = now

    def wait(self):
        """Block until one token is available, then consume it.

        FIX over the original implementation: after sleeping we re-acquire
        the lock, refill, and re-check instead of blindly decrementing.
        The old code never credited the tokens accrued while sleeping and
        could let several concurrent waiters all "consume" the same token,
        clamping the bucket to zero instead of actually pacing callers.
        """
        while True:
            with self._lock:
                self._refill(time.time())
                if self._tokens >= 1.0:
                    self._tokens -= 1.0
                    return
                # Time until the next whole token accrues.
                needed = (1.0 - self._tokens) / self.rate
            # Sleep outside the lock so other threads can refill/consume.
            time.sleep(needed)
94
+
95
+
96
+ # instantiate global rate limiters
97
+ flash_rate_limiter = TokenBucket(FLASH_RATE, capacity=max(FLASH_RATE, 1.0))
98
+ pro_rate_limiter = TokenBucket(PRO_RATE, capacity=max(PRO_RATE, 1.0))
99
+
100
+ # Semaphores and locks
101
+ flash_sema = BoundedSemaphore(FLASH_CONCURRENCY)
102
+ pro_sema = BoundedSemaphore(PRO_CONCURRENCY)
103
+ flash_lock = Lock()
104
+ pro_lock = Lock()
105
+ _last_flash = 0.0
106
+ _last_pro = 0.0
107
+
108
+ # Simple in-memory circuit breaker: model_name -> (consecutive_failures, last_failure_time)
109
+ _circuit_breaker: Dict[str, Tuple[int, float]] = {}
110
+
111
def flash_wait():
    """Enforce the minimum spacing between consecutive flash-family calls.

    Sleeps (while holding the lock, which serialises callers) until at
    least FLASH_MIN_INTERVAL has elapsed since the previous call.
    """
    global _last_flash
    with flash_lock:
        remaining = FLASH_MIN_INTERVAL - (time.time() - _last_flash)
        if remaining > 0:
            time.sleep(remaining)
        _last_flash = time.time()
119
+
120
def pro_wait():
    """Enforce the minimum spacing between consecutive pro-family calls.

    Sleeps (while holding the lock, which serialises callers) until at
    least PRO_MIN_INTERVAL has elapsed since the previous call.
    """
    global _last_pro
    with pro_lock:
        remaining = PRO_MIN_INTERVAL - (time.time() - _last_pro)
        if remaining > 0:
            time.sleep(remaining)
        _last_pro = time.time()
128
+
129
+ # ---------------------------
130
+ # Category taxonomy (strict)
131
+ # ---------------------------
132
+ ALLOWED_COMPLEX_CATEGORIES = {
133
+ "Labeled Equipment Diagram",
134
+ "Exploded Parts Diagram",
135
+ "Technical Schematic",
136
+ "Flowchart",
137
+ "Process Diagram",
138
+ "Wiring Diagram",
139
+ "Choropleth Map",
140
+ "Geographic Reference Map",
141
+ "Infographic",
142
+ "Complex Table",
143
+ "Annotated Photograph",
144
+ "Safety Label Diagram",
145
+ }
146
+
147
+ # alias schema classes
148
+ RouterOutput = schema_models.RouterOutput
149
+ KeyComponent = schema_models.KeyComponent
150
+ DiagramExtraction = schema_models.DiagramExtraction
151
+ SimpleExtraction = schema_models.SimpleExtraction
152
+
153
+ # Prompts (kept as in user input)
154
+ ROUTER_PROMPT = r"""
155
+ SYSTEM:
156
+ You are the ROUTER AGENT that classifies ONE PAGE IMAGE into either 'complex' or 'simple' based only on visible visuals and visible text.
157
+ Do NOT guess. Use "[ILLEGIBLE]" for unreadable text.
158
+
159
+ OUTPUT EXACTLY one JSON object and NOTHING else with these fields:
160
+ {
161
+ "route": "complex" | "simple",
162
+ "contains_visual": true | false,
163
+ "visual_types": ["map","infographic","chart","diagram","complex_table","photo","logo","other"],
164
+ "reason": "<8-120 characters plain English>",
165
+ "confidence": 0.00
166
+ }
167
+
168
+ If confidence < 0.70 set "route" to "complex".
169
+ """
170
+
171
+ COMPLEX_PROMPT = r"""
172
+ SYSTEM:
173
+ You are a Technical Diagram & Visual Extraction Specialist.
174
+ You will be given ONE PAGE IMAGE (diagram, map, flowchart, complex table, infographic, or annotated photo).
175
+ Produce EXACTLY one JSON matching the schema below and NOTHING else.
176
+
177
+ Rules:
178
+ - Transcribe ALL visible labels, legend entries, axis ticks and captions verbatim. Use "[ILLEGIBLE]" for unreadable fragments.
179
+ - Do NOT invent values, units, or relationships not visible.
180
+ - Choose EXACTLY ONE category from the provided list (do not create new names).
181
+ - Provide 'printed_page' if a printed page number is visible on the page (e.g., 'PAGE 2' or '4'); otherwise use null or "[ILLEGIBLE]".
182
+ - Provide extraction_confidence 0.00–1.00 reflecting overall certainty.
183
+
184
+ ALLOWED CATEGORIES:
185
+ Labeled Equipment Diagram, Exploded Parts Diagram, Technical Schematic, Flowchart,
186
+ Process Diagram, Wiring Diagram, Choropleth Map, Geographic Reference Map, Infographic,
187
+ Complex Table, Annotated Photograph, Safety Label Diagram, Other
188
+
189
+ SCHEMA:
190
+ {
191
+ "schema_id":"diagram_v1",
192
+ "pdf_page": <integer - program will supply; model may also include printed_page string or [ILLEGIBLE]>,
193
+ "printed_page": "<string|null>",
194
+ "title": "<string>",
195
+ "category": "<one of the allowed categories>",
196
+ "summary": "<2-sentence factual summary>",
197
+ "key_components":[
198
+ {"name":"<label or [ILLEGIBLE]>","description":"<verbatim descriptor or short spatial hint>","extraction_confidence":0.00}
199
+ ],
200
+ "relationships":"<explicit relationships visible or '[NONE]'>",
201
+ "raw_text":"<all remaining visible text verbatim or [ILLEGIBLE]>",
202
+ "extraction_confidence": 0.00
203
+ }
204
+ """
205
+
206
+ SIMPLE_PROMPT = r"""
207
+ SYSTEM:
208
+ You are a Document Transcription Specialist. You will be given ONE PAGE IMAGE primarily containing readable text (paragraphs, headings, simple tables).
209
+ Produce EXACTLY one JSON matching the schema below and NOTHING else.
210
+
211
+ Rules:
212
+ - Transcribe text verbatim. Use "[ILLEGIBLE]" for unreadable fragments.
213
+ - Convert simple 1-row-per-record tables into Markdown tables.
214
+ - Provide 'printed_page' if visible; otherwise null or "[ILLEGIBLE]".
215
+ - Provide extraction_confidence 0.00–1.00.
216
+
217
+ SCHEMA:
218
+ {
219
+ "schema_id":"simple_v1",
220
+ "pdf_page": <integer - program will supply>,
221
+ "printed_page":"<string|null>",
222
+ "topic":"<string>",
223
+ "summary":"<2-sentence summary strictly from visible text>",
224
+ "content_markdown":"<full page transcribed into Markdown>",
225
+ "important_dates_or_entities":["<exact strings seen>"],
226
+ "extraction_confidence": 0.00
227
+ }
228
+ """
229
+
230
+
231
+ # Helpers: JSON substring extraction & pydantic-agnostic parse
232
def extract_json_substring(raw_text: str) -> str:
    """Return the outermost ``{...}`` span of *raw_text*, or the input unchanged.

    Model responses often wrap their JSON in prose or code fences; this trims
    the text to the first '{' through the last '}'. If no such span exists
    (including empty input), the original string is returned so the caller's
    JSON parser can produce a meaningful error.
    """
    if not raw_text:
        return raw_text
    # FIX/idiom: str.find/rfind return -1 instead of raising, so the original
    # try/except around .index() and the always-true `start >= 0` check go away.
    start = raw_text.find("{")
    end = raw_text.rfind("}")
    if start != -1 and end > start:
        return raw_text[start:end + 1]
    return raw_text
243
+
244
+
245
def parse_with_schema(schema_cls: Any, raw_json_str: str):
    """Parse *raw_json_str* with a pydantic model class, v2-first.

    Tries the pydantic-v2 entry point (`model_validate_json`) and falls back
    to the v1 API (`parse_raw`). A failure of the fallback propagates to the
    caller.
    """
    try:
        return schema_cls.model_validate_json(raw_json_str)
    except Exception:
        # pydantic v1 fallback; its exception (if any) is the one raised.
        return schema_cls.parse_raw(raw_json_str)
255
+
256
+
257
+ # Safe API call with backoff & rate shaping
258
def safe_generate_content(model_name: str, contents: list, config_obj: Any = None, is_flash: bool = False, is_pro: bool = False):
    """Make a model call with retries, spacing, semaphores and provider-aware backoff.

    This function attempts to parse provider RetryInfo / Retry-After hints from
    the exception (when available) and prefers that delay over the local
    exponential backoff. It still records failures for the circuit-breaker.

    Args:
        model_name: GenAI model to call; also the circuit-breaker key.
        contents: Prompt/content list forwarded to ``models.generate_content``.
        config_obj: Optional generation config forwarded as-is.
        is_flash: Apply flash-family pacing (flash_wait) and concurrency bound (flash_sema).
        is_pro: Apply pro-family pacing (pro_wait) and concurrency bound (pro_sema).

    Returns:
        The raw provider response object.

    Raises:
        RuntimeError: when the GenAI client is unconfigured, or the circuit is open.
        Exception: the last provider error once RETRY_ATTEMPTS is exhausted.
    """

    def _parse_retry_after(exc: Exception) -> Optional[float]:
        """Extract seconds from common Retry-After / RetryInfo patterns in exceptions."""
        # 1) Try to read common response-like attributes
        resp = getattr(exc, "response", None) or getattr(exc, "http_response", None)
        if resp is not None:
            headers = getattr(resp, "headers", None) or getattr(resp, "header", None)
            if headers and isinstance(headers, dict):
                ra = headers.get("Retry-After") or headers.get("retry-after")
                if ra:
                    try:
                        return float(ra)
                    except Exception:
                        pass
        # 2) Parse textual RetryInfo (e.g. "retryDelay": "9s") from str(exc)
        s = str(exc)
        m = re.search(r"retryDelay[\"']?\s*[:=]\s*[\"']?(\d+(?:\.\d+)?)s", s, flags=re.IGNORECASE)
        if m:
            try:
                return float(m.group(1))
            except Exception:
                pass
        m2 = re.search(r"Retry-After\s*[:=]?\s*(\d+(?:\.\d+)?)(?:s|\s|$)", s, flags=re.IGNORECASE)
        if m2:
            try:
                return float(m2.group(1))
            except Exception:
                pass
        return None

    # quick fail if client not configured to avoid poisoning circuit-breaker
    if model_client.genai_client is None:
        raise RuntimeError("GenAI client not configured. Ensure GOOGLE_API_KEY is set and the app was restarted.")

    base_delay = 0.5  # seed for the exponential backoff on rate-limit errors
    for attempt in range(1, RETRY_ATTEMPTS + 1):
        try:
            # simple circuit breaker per-model: refuse calls while the model has
            # accumulated too many recent consecutive failures.
            info = _circuit_breaker.get(model_name)
            if info:
                failures, last_time = info
                if failures >= CIRCUIT_THRESHOLD and (time.time() - last_time) < CIRCUIT_WINDOW:
                    logger.warning("Circuit open for %s (failures=%s, last=%s)", model_name, failures, last_time)
                    # NOTE(review): this raise happens *inside* the try, so it is
                    # caught by the handler below, counted as another failure and
                    # retried — confirm that is intended rather than failing fast.
                    raise RuntimeError(f"Circuit open for {model_name}")
            if is_flash:
                flash_wait()  # min-interval spacing between flash calls
                with flash_sema:  # bound flash concurrency
                    resp = model_client.genai_client.models.generate_content(model=model_name, contents=contents, config=config_obj)
            elif is_pro:
                pro_wait()  # min-interval spacing between pro calls
                with pro_sema:  # bound pro concurrency
                    resp = model_client.genai_client.models.generate_content(model=model_name, contents=contents, config=config_obj)
            else:
                # unshaped path: no pacing or concurrency bound applied
                resp = model_client.genai_client.models.generate_content(model=model_name, contents=contents, config=config_obj)
            # success -> reset circuit breaker for this model
            if model_name in _circuit_breaker:
                _circuit_breaker.pop(model_name, None)
                logger.info("Circuit breaker reset for %s after successful call", model_name)
            return resp
        except Exception as e:
            # record failure
            failures, last_time = _circuit_breaker.get(model_name, (0, 0.0))
            failures += 1
            _circuit_breaker[model_name] = (failures, time.time())
            logger.warning("Model %s failure recorded (count=%s): %s", model_name, failures, e)

            # Try to honor provider's Retry-After / RetryInfo if present
            retry_seconds = _parse_retry_after(e)
            s = str(e).lower()
            if any(k in s for k in ("429", "rate", "quota", "resource exhausted")):
                # compute backoff: prefer provider-specified retry, otherwise exponential
                if retry_seconds and retry_seconds > 0:
                    wait = max(retry_seconds, 0.5)
                else:
                    wait = base_delay * (2 ** (attempt - 1)) + random.uniform(0.05, 0.3)
                if attempt < RETRY_ATTEMPTS:
                    logger.warning("%s rate-limited. Attempt %s/%s - sleeping %.2fs (provider_retry=%s)", model_name, attempt, RETRY_ATTEMPTS, wait, retry_seconds)
                    time.sleep(wait)
                    continue
            # transient server/connection errors
            if attempt < RETRY_ATTEMPTS:
                wait = 0.3 + random.uniform(0, 0.5)
                logger.warning("%s transient error. Retry %s/%s after %.2fs... Error: %s", model_name, attempt, RETRY_ATTEMPTS, wait, e)
                time.sleep(wait)
                continue
            raise
351
+
352
+
353
+ # validate_and_retry wrapper
354
def validate_and_retry(call_fn: Callable[[], Any], schema_cls: Any, page_index: int, min_confidence: float = 0.60, max_attempts: int = 3) -> Tuple[dict, str]:
    """Call the model via *call_fn*, parse/validate its JSON reply, and retry.

    Retries (up to ``max_attempts``) when the reply fails to parse, reports a
    confidence below ``min_confidence``, or has a suspiciously short summary.

    Returns:
        (validated data dict, raw model text).

    Raises:
        RuntimeError: when every attempt fails to parse/validate.
    """
    last_raw = None
    last_error: Optional[Exception] = None
    for attempt in range(1, max_attempts + 1):
        resp = call_fn()
        raw = getattr(resp, "text", None) or str(resp)
        last_raw = raw
        candidate = extract_json_substring(raw)
        try:
            parsed_obj = parse_with_schema(schema_cls, candidate)
            # Support both pydantic v2 (model_dump) and v1 (dict).
            data = parsed_obj.model_dump() if hasattr(parsed_obj, "model_dump") else parsed_obj.dict()
            # The model cannot know the true PDF page index; always overwrite.
            data["pdf_page"] = page_index + 1
            conf = data.get("extraction_confidence") or data.get("confidence")
            if conf is None:
                return data, raw
            try:
                conf = float(conf)
            except (TypeError, ValueError):
                # FIX: was a bare `except:`, which also swallowed
                # KeyboardInterrupt/SystemExit.
                conf = 0.0
            if schema_cls is DiagramExtraction:
                # Clamp hallucinated categories to the strict taxonomy.
                if data.get("category", "") not in ALLOWED_COMPLEX_CATEGORIES:
                    data["category"] = "Other"
            if "printed_page" in data:
                # Normalise "no visible page number" to None.
                if not data["printed_page"] or data["printed_page"] == "[ILLEGIBLE]":
                    data["printed_page"] = None
            if conf < min_confidence:
                if attempt < max_attempts:
                    time.sleep(0.2 * attempt + random.uniform(0.02, 0.1))
                    continue
                # Final attempt: return the low-confidence result rather than fail.
                return data, raw
            if "summary" in data and isinstance(data["summary"], str):
                if len(data["summary"].strip()) < 20 and attempt < max_attempts:
                    time.sleep(0.15 + random.uniform(0, 0.1))
                    continue
            return data, raw
        except Exception as e:
            # FIX: use the module logger instead of print(), and remember the
            # exception — the original built the final error message from `e`
            # after the except block, where Python 3 has already unbound it.
            last_error = e
            logger.warning("parsing failed for page %s attempt %s: %s", page_index + 1, attempt, e)
            if attempt < max_attempts:
                time.sleep(0.3 * attempt + random.uniform(0.02, 0.2))
                continue
    raw_excerpt = (last_raw or "")[:1000]
    raise RuntimeError(f"Parsing/validation failed after {max_attempts} attempts for page {page_index+1}. Raw excerpt (first 1000 chars):\n{raw_excerpt}\nError: {last_error}")
397
+
398
+
399
+ # Markdown normalization
400
def normalize_markdown(md: str) -> str:
    """Normalize extracted markdown: promote ALL-CAPS lines to `##` headings and
    short trailing-colon lines to `###` headings; other lines pass through.

    Args:
        md: raw markdown/plain text from the extraction model.

    Returns:
        The normalized markdown (no trailing newline added).
    """
    def smart_title(text: str) -> str:
        # Title-case each word, but keep short all-caps tokens (acronyms like "PDF").
        return " ".join(w if (w.isupper() and len(w) <= 4) else w.capitalize() for w in text.split())

    # Heading heuristic: starts with an uppercase letter/digit, then 3+ chars of
    # caps/digits/punctuation, with at least 3 alphabetic characters overall.
    heading_re = re.compile(r'^[A-Z0-9][A-Z0-9 \-\/\(\)\.]{3,}$')
    normalized = []
    for line in md.splitlines():
        s = line.strip()
        if not s:
            normalized.append("")
        elif heading_re.match(s) and sum(1 for c in s if c.isalpha()) >= 3:
            normalized.append("## " + smart_title(s))
        elif s.endswith(":") and len(s) < 80:
            normalized.append("### " + s.rstrip(":"))
        else:
            # Preserve the original line (including its leading whitespace).
            normalized.append(line)
    return "\n".join(normalized)
427
+
428
+
429
+ # Worker functions (using genai client)
430
def get_image(pdf_path: str, page_index: int):
    """Render a single PDF page (0-based `page_index`) to a PIL image.

    Best-effort: returns None when rendering fails for any reason
    (corrupt page, missing poppler, out-of-range index).
    """
    try:
        pages = convert_from_path(
            pdf_path,
            first_page=page_index + 1,
            last_page=page_index + 1,
            fmt="jpeg",
        )
    except Exception:
        return None
    return pages[0] if pages else None
436
+
437
+
438
def router_worker(pdf_path: str, page_index: int) -> Dict:
    """Classify one page as 'simple' or 'complex' with the flash router model.

    Returns a dict {"page_index", "route", "raw"}. Any failure (image load,
    unparseable model output) falls back to route="complex" so ambiguous pages
    get the more thorough extraction path.
    """
    img = get_image(pdf_path, page_index)
    result = {"page_index": page_index, "route": "complex", "raw": None}
    if img is None:
        return result

    def call():
        cfg = {
            "response_mime_type": "application/json",
            "response_json_schema": RouterOutput.model_json_schema(),
            "temperature": 0.0
        }
        return safe_generate_content(model_name="gemini-2.0-flash", contents=[img, ROUTER_PROMPT], config_obj=cfg, is_flash=True)

    try:
        resp = call()
        raw = getattr(resp, "text", None) or str(resp)
        result["raw"] = raw
        try:
            parsed = parse_with_schema(RouterOutput, extract_json_substring(raw))
            out = parsed.model_dump() if hasattr(parsed, "model_dump") else parsed.dict()
        except Exception:
            # Unparseable router output: assume the worst case (complex page).
            out = {"route": "complex", "contains_visual": True, "visual_types": ["other"], "reason": "parse_failed", "confidence": 0.0}
        result["route"] = out.get("route", "complex")
        return result
    finally:
        # Release the page image promptly; PIL images can be large.
        try:
            img.close()
        except Exception:
            pass
        gc.collect()
465
+
466
+
467
def simple_worker(pdf_path: str, page_index: int) -> Dict:
    """Extract a text-only ("simple") page with the flash model.

    Returns {"page_index", "type": "SIMPLE", "data", "error", "raw"}; on failure
    `data` is None and `error` carries the message.
    """
    img = get_image(pdf_path, page_index)
    out = {"page_index": page_index, "type": "SIMPLE", "data": None, "error": None, "raw": None}
    if img is None:
        out["error"] = "image_load_failed"
        return out

    def call():
        cfg = genai_types.GenerateContentConfig(
            response_mime_type="application/json",
            response_json_schema=SimpleExtraction.model_json_schema(),
            # No "thinking" budget: simple pages are cheap, deterministic extraction.
            thinking_config=genai_types.ThinkingConfig(thinking_budget=0),
            media_resolution="media_resolution_medium",
            temperature=0.0
        )
        return safe_generate_content(model_name="gemini-2.5-flash-preview-09-2025", contents=[img, SIMPLE_PROMPT], config_obj=cfg, is_flash=True)

    try:
        data, raw = validate_and_retry(call, SimpleExtraction, page_index, min_confidence=0.6, max_attempts=RETRY_ATTEMPTS)
        data["pdf_page"] = page_index + 1
        if "printed_page" not in data or data.get("printed_page") in ("", "[ILLEGIBLE]"):
            data["printed_page"] = None
        if "content_markdown" in data and isinstance(data["content_markdown"], str):
            data["content_markdown"] = normalize_markdown(data["content_markdown"])
        out["data"] = data
        out["raw"] = raw
        return out
    except Exception as e:
        out["error"] = str(e)
        out["raw"] = None
        return out
    finally:
        # Release the page image promptly; PIL images can be large.
        try:
            img.close()
        except Exception:
            pass
        gc.collect()
500
+
501
+
502
def complex_worker(pdf_path: str, page_index: int) -> Dict:
    """Extract a diagram/complex page with the pro model (high media resolution).

    Returns {"page_index", "type": "COMPLEX", "data", "error", "raw"}; on failure
    `data` is None and `error` carries the message.
    """
    img = get_image(pdf_path, page_index)
    out = {"page_index": page_index, "type": "COMPLEX", "data": None, "error": None, "raw": None}
    if img is None:
        out["error"] = "image_load_failed"
        return out

    def call():
        cfg = genai_types.GenerateContentConfig(
            response_mime_type="application/json",
            response_json_schema=DiagramExtraction.model_json_schema(),
            # Low thinking + high media resolution: diagrams need pixels, not reasoning.
            thinking_config=genai_types.ThinkingConfig(thinking_level="low"),
            media_resolution="media_resolution_high",
            temperature=0.0
        )
        return safe_generate_content(model_name="gemini-3-pro-preview", contents=[img, COMPLEX_PROMPT], config_obj=cfg, is_pro=True)

    try:
        data, raw = validate_and_retry(call, DiagramExtraction, page_index, min_confidence=0.6, max_attempts=RETRY_ATTEMPTS)
        if data.get("category") not in ALLOWED_COMPLEX_CATEGORIES:
            data["category"] = "Other"
        data["pdf_page"] = page_index + 1
        if "printed_page" not in data or data.get("printed_page") in ("", "[ILLEGIBLE]"):
            data["printed_page"] = None
        out["data"] = data
        out["raw"] = raw
        return out
    except Exception as e:
        out["error"] = str(e)
        return out
    finally:
        # Release the page image promptly; PIL images can be large.
        try:
            img.close()
        except Exception:
            pass
        gc.collect()
534
+
535
+
536
# Producer / Consumer (streaming) shared state.
# router_producer (background thread) enqueues page indices here and
# consumer_processor (main thread) drains them into extraction workers.
simple_queue = queue.Queue()    # indices of pages routed as plain-text pages
complex_queue = queue.Queue()   # indices of pages routed as diagram/complex pages
router_finished = Event()       # set once every page has been routed
540
+
541
+
542
def router_producer(pdf_path: str, total_pages: int):
    """Route every page concurrently, feeding simple_queue / complex_queue.

    Always sets `router_finished` — even if routing raises — so the consumer
    loop can terminate; previously an exception from a future left the event
    unset and the consumer spun forever. A page whose router future fails is
    sent to the complex queue (fail-safe toward the thorough path).
    """
    print(" [Router] Scanning pages and routing...")
    try:
        with ThreadPoolExecutor(max_workers=ROUTER_WORKERS) as ex:
            futures = {ex.submit(router_worker, pdf_path, i): i for i in range(total_pages)}
            for fut in as_completed(futures):
                try:
                    res = fut.result()
                    idx = res["page_index"]
                    route = res.get("route", "complex")
                except Exception:
                    # Unroutable page: treat as complex rather than dropping it.
                    idx, route = futures[fut], "complex"
                if route == "complex":
                    complex_queue.put(idx)
                else:
                    simple_queue.put(idx)
        print(" [Router] Done.")
    finally:
        router_finished.set()
556
+
557
+
558
def consumer_processor(pdf_path: str, results: list):
    """Drain routed page indices and run the matching extraction worker.

    Polls both queues until the router signals completion and both queues are
    empty, then waits for every submitted extraction future and appends each
    result dict to `results` (failed futures become type="FAILED" entries).
    """
    print(" [Consumer] Starting workers...")

    def _drain(q, worker, ex, futures):
        # Submit every currently-queued page index; EAFP (queue.Empty) instead
        # of empty()/get_nowait(), which is racy in principle.
        while True:
            try:
                idx = q.get_nowait()
            except queue.Empty:
                return
            futures.append(ex.submit(worker, pdf_path, idx))

    with ThreadPoolExecutor(max_workers=SIMPLE_WORKERS + COMPLEX_WORKERS) as ex:
        futures = []
        while True:
            # Safe termination: the router sets the event only after its last put.
            if router_finished.is_set() and simple_queue.empty() and complex_queue.empty():
                break
            _drain(simple_queue, simple_worker, ex, futures)
            _drain(complex_queue, complex_worker, ex, futures)
            time.sleep(0.03)
        for fut in tqdm(as_completed(futures), total=len(futures), unit="page"):
            try:
                r = fut.result()
            except Exception as e:
                r = {"page_index": None, "type": "FAILED", "data": None, "error": str(e)}
            results.append(r)
    print(" [Consumer] All tasks finished.")
579
+
580
+
581
def save_results(results: List[dict], out_md: str = "final_report.md"):
    """Write a human-readable Markdown report for all page-extraction results.

    Args:
        results: per-page worker dicts ({"page_index", "type", "data", "error", ...}).
        out_md: destination path for the Markdown report.

    Returns:
        The absolute path of the written report.
    """
    # Drop entries with no page index (failed futures) and order by page.
    results_sorted = sorted([r for r in results if r.get("page_index") is not None], key=lambda x: x["page_index"])
    with open(out_md, "w", encoding="utf-8") as f:
        f.write("# Extraction Report\n\n")
        f.write(f"**Generated:** {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}\n\n")
        f.write("---\n\n")
        f.write("## Table of Contents\n\n")
        for r in results_sorted:
            p = r["page_index"] + 1
            typ = r.get("type", "UNKNOWN")
            title = ""
            if r.get("data") and isinstance(r["data"], dict):
                # Prefer the complex-page title, fall back to the simple-page topic.
                title = r["data"].get("title") or r["data"].get("topic") or ""
            f.write(f"- [{typ} — Page {p}]{' — ' + title if title else ''}\n")
        f.write("\n---\n\n")
        for r in results_sorted:
            p = r["page_index"] + 1
            typ = r.get("type", "UNKNOWN")
            f.write(f"## Page {p} — {typ}\n\n")
            f.write(f"- **PDF page index:** {r['page_index']+1}\n")
            if r.get("data"):
                data = r["data"]
                printed = data.get("printed_page")
                confidence = data.get("extraction_confidence", data.get("confidence", None))
                f.write(f"- **Printed page:** {printed if printed else 'N/A'}\n")
                f.write(f"- **Extraction confidence:** {confidence if confidence is not None else 'N/A'}\n\n")
                if typ == "COMPLEX" or data.get("schema_id") == "diagram_v1":
                    # Diagram/complex page layout.
                    f.write(f"### Title\n\n{data.get('title','(no title)')}\n\n")
                    f.write(f"### Category\n\n{data.get('category','Other')}\n\n")
                    f.write(f"### Summary\n\n{data.get('summary','(no summary)')}\n\n")
                    if data.get("key_components"):
                        f.write("### Key Components\n\n")
                        for comp in data.get("key_components", []):
                            name = comp.get("name", "(no name)")
                            desc = comp.get("description", "")
                            conf = comp.get("extraction_confidence", None)
                            f.write(f"- **{name}** — {desc}" + (f" (confidence: {conf})" if conf is not None else "") + "\n")
                        f.write("\n")
                    f.write("### Relationships / Notes\n\n")
                    f.write(f"{data.get('relationships','[NONE]')}\n\n")
                    if data.get("raw_text"):
                        f.write("### Raw Text (verbatim)\n\n")
                        # Quote every line of the raw text as a Markdown blockquote.
                        f.write("> " + "\n> ".join(str(data.get("raw_text","")).splitlines()) + "\n\n")
                    else:
                        f.write("### Raw Text (verbatim)\n\nN/A\n\n")
                elif typ == "SIMPLE" or data.get("schema_id") == "simple_v1":
                    # Plain-text page layout.
                    f.write(f"### Topic\n\n{data.get('topic','(no topic)')}\n\n")
                    f.write(f"### Summary\n\n{data.get('summary','(no summary)')}\n\n")
                    f.write("### Content\n\n")
                    content_md = data.get("content_markdown", "")
                    if content_md:
                        f.write(content_md + "\n\n")
                    else:
                        f.write("(no content)\n\n")
                    if data.get("important_dates_or_entities"):
                        f.write("### Important Dates / Entities\n\n")
                        for ent in data.get("important_dates_or_entities", []):
                            f.write(f"- {ent}\n")
                        f.write("\n")
                    else:
                        f.write("### Important Dates / Entities\n\nN/A\n\n")
                else:
                    # Unknown schema: dump scalar fields, then any markdown content.
                    f.write("### Extracted Fields\n\n")
                    for k, v in data.items():
                        if k in ("content_markdown", "raw_text"):
                            continue
                        f.write(f"- **{k}**: {v}\n")
                    f.write("\n")
                    if data.get("content_markdown"):
                        f.write("### Content\n\n")
                        f.write(data.get("content_markdown") + "\n\n")
            else:
                f.write("### Extraction failed or returned no data\n\n")
                f.write(f"**Error:** {r.get('error')}\n\n")
            f.write("\n---\n\n")
    print(f"Saved Markdown: {os.path.abspath(out_md)}")
    print("Note: raw model outputs are not saved to disk by design.")
    return os.path.abspath(out_md)
659
+
660
+
661
def run_pipeline(
    pdf_path: str,
    max_pages: Optional[int] = None,
    out_md: Optional[str] = None,
    progress_hook: Optional[Callable[[dict], None]] = None,
    doc_id: Optional[str] = None,
    original_filename: Optional[str] = None,
):
    """End-to-end pipeline: route + extract pages, write a Markdown report,
    then chunk and ingest it into Qdrant.

    Args:
        pdf_path: path to the source PDF (deleted after successful ingestion).
        max_pages: optional cap on the number of pages processed.
        out_md: report path; a temp file in app_utils.DATA_DIR when None.
        progress_hook: optional callback receiving {"event": ..., ...} dicts.
        doc_id: optional document UUID persisted in the metadata store.
        original_filename: original upload name persisted in the metadata store.

    Returns:
        {"report_path", "pages_processed", "results", "ingest"}.

    Raises:
        FileNotFoundError: if `pdf_path` does not exist.
        Exception: re-raises any chunking/ingestion failure after emitting an
            "error" progress event.
    """
    if not os.path.exists(pdf_path):
        raise FileNotFoundError(f"{pdf_path} not found")
    info = pdfinfo_from_path(pdf_path)
    total_pages = info.get("Pages", 0)
    pages_to_process = total_pages if max_pages is None else min(max_pages, total_pages)
    print(f"Processing {pages_to_process}/{total_pages} pages from {os.path.basename(pdf_path)}")
    results = []

    if progress_hook:
        progress_hook({"event": "started", "pages_total": pages_to_process, "pdf": os.path.basename(pdf_path)})

    # Producer routes pages in a background thread; consumer extracts in this one.
    t = threading.Thread(target=router_producer, args=(pdf_path, pages_to_process))
    t.start()
    try:
        consumer_processor(pdf_path, results)
    finally:
        # Always reap the router thread, even if the consumer raised.
        t.join()

    # Save the Markdown report to a temp file when no path was provided.
    if out_md is None:
        fd, tmp_md = tempfile.mkstemp(prefix="report_", suffix=".md", dir=app_utils.DATA_DIR)
        os.close(fd)
        out_md = tmp_md

    report_path = save_results(results, out_md=out_md)
    pages_processed = len([r for r in results if r.get("page_index") is not None])

    if progress_hook:
        progress_hook({"event": "report_saved", "report_path": report_path, "pages_processed": pages_processed})

    # Chunk the report and ingest into Qdrant.
    try:
        if progress_hook:
            progress_hook({"event": "chunking_started", "report_path": report_path})
        chunks = qdrant_service.chunk_markdown_by_page(report_path)
        if progress_hook:
            progress_hook({"event": "chunking_finished", "chunks": len(chunks)})

        if progress_hook:
            progress_hook({"event": "ingest_started", "collection": os.environ.get("QDRANT_COLLECTION", "manual_pages")})

        # Batch size from env (default 256); fall back silently on bad values.
        try:
            batch_size = int(os.environ.get("QDRANT_BATCH_SIZE", 256))
        except Exception:
            batch_size = 256
        ingest_res = qdrant_service.ingest_chunks_into_qdrant(
            chunks,
            collection_name=os.environ.get("QDRANT_COLLECTION", "manual_pages"),
            batch_size=batch_size,
            progress_hook=progress_hook,
        )

        if progress_hook:
            progress_hook({"event": "ingest_finished", "result": ingest_res})

        # On successful ingestion, persist metadata and remove temp artifacts.
        if isinstance(ingest_res, dict) and ingest_res.get("ingested"):
            try:
                if doc_id or original_filename:
                    entry = {"uuid": doc_id or "", "original_filename": original_filename or os.path.basename(pdf_path), "report": report_path, "created_at": time.time()}
                    app_utils.append_metadata_entry(entry)
            except Exception as e:
                print(f"Warning: failed to append metadata: {e}")
            # Best-effort cleanup of the uploaded PDF and the report file.
            try:
                if os.path.exists(pdf_path):
                    os.remove(pdf_path)
                if os.path.exists(report_path):
                    os.remove(report_path)
            except Exception as e:
                print(f"Warning: failed to remove temp files: {e}")

    except Exception as e:
        if progress_hook:
            progress_hook({"event": "error", "error": str(e)})
        raise

    if progress_hook:
        progress_hook({"event": "completed", "pages_processed": pages_processed, "ingest_result": ingest_res})

    return {"report_path": report_path, "pages_processed": pages_processed, "results": results, "ingest": ingest_res}
app/services/qdrant_service.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import time
3
+ import random
4
+ from typing import List, Dict, Optional, Callable
5
+ import os
6
+ from .model_client import init_qdrant_client, init_embeddings
7
+
8
# Chunking regex: matches the "## Page N — TYPE" headings emitted by
# save_results; (?m) makes ^/$ match per line so split() keeps the headers.
PAGE_SPLIT_RE = re.compile(r'(?m)^(##\s+Page\s+\d+.*)$')
10
+
11
+
12
def chunk_markdown_by_page(md_path: str) -> List[Dict]:
    """Split a report markdown file into one chunk per '## Page N — TYPE' section.

    The preamble before the first page heading (title, TOC) is emitted as a
    page-0 chunk. Each chunk dict has keys: id, page, page_type, text,
    char_length. Headers without a parseable page number previously all
    collapsed to the colliding id 'page_None'; they now get unique ids.
    """
    with open(md_path, 'r', encoding='utf-8') as f:
        md = f.read()
    parts = PAGE_SPLIT_RE.split(md)
    chunks: List[Dict] = []
    preamble = parts[0].strip()
    if preamble:
        chunks.append({
            'id': 'page_0', 'page': 0, 'page_type': None, 'text': preamble, 'char_length': len(preamble)
        })
    # After split() with a capturing group, parts alternate:
    # [preamble, header1, body1, header2, body2, ...]
    for i in range(1, len(parts), 2):
        header = parts[i].strip()
        body = parts[i + 1].strip() if i + 1 < len(parts) else ''
        m = re.search(r'Page\s+(\d+)', header)
        page_num = int(m.group(1)) if m else None
        upper = header.upper()
        if 'SIMPLE' in upper:
            page_type = 'SIMPLE'
        elif 'COMPLEX' in upper:
            page_type = 'COMPLEX'
        else:
            page_type = None
        full_text = f"{header}\n\n{body}".strip()
        chunks.append({
            # Unique id even when the page number is missing from the header.
            'id': f'page_{page_num}' if page_num is not None else f'page_unknown_{i // 2}',
            'page': page_num,
            'page_type': page_type,
            'text': full_text,
            'char_length': len(full_text),
        })
    return chunks
37
+
38
+
39
def ingest_chunks_into_qdrant(
    chunks: List[Dict],
    collection_name: str = 'manual_pages',
    batch_size: int = 256,
    progress_hook: Optional[Callable[[dict], None]] = None,
    retry_attempts: int = 3,
) -> Dict:
    """Ingest page chunks into Qdrant using Voyage embeddings via langchain.

    Upserts in batches of `batch_size` with retries on transient failures and
    calls `progress_hook` after each batch. Page-0 / unnumbered preamble chunks
    are skipped. Returns {'ingested': n, 'collection': name} on success or
    {'error': msg} on failure (no dry-run support).
    """
    qc = init_qdrant_client()
    emb = init_embeddings()
    if qc is None or emb is None:
        return {'error': 'qdrant-or-embeddings-missing'}

    try:
        # Lazy import heavy libraries so module import stays cheap.
        from langchain_qdrant import QdrantVectorStore
        from langchain_core.documents import Document

        # Probe the embedding dimension with a tiny query (needed to create the collection).
        try:
            vector_size = len(emb.embed_query('sample size'))
        except Exception:
            vector_size = None

        # Create the collection if it does not exist yet.
        try:
            existing = [c.name for c in qc.get_collections().collections]
        except Exception:
            existing = []
        if collection_name not in existing and vector_size is not None:
            from qdrant_client.models import VectorParams, Distance
            qc.create_collection(collection_name=collection_name, vectors_config=VectorParams(size=vector_size, distance=Distance.COSINE))

        # Build documents, skipping the page-0 preamble chunk.
        docs_all = [
            Document(
                page_content=c['text'],
                metadata={'chunk_id': c['id'], 'page': c['page'], 'page_type': c.get('page_type'), 'char_length': c.get('char_length')},
            )
            for c in chunks
            if c.get('page') not in (None, 0)
        ]

        total_docs = len(docs_all)
        if total_docs == 0:
            return {'ingested': 0, 'collection': collection_name}

        store = QdrantVectorStore(client=qc, collection_name=collection_name, embedding=emb)

        ingested = 0
        for i in range(0, total_docs, batch_size):
            batch_docs = docs_all[i:i + batch_size]
            # NOTE: add_documents embeds internally via `emb`; the original code
            # also pre-computed embeddings here and threw them away, which
            # doubled the embedding API cost for every batch. Do not pre-embed.
            last_err = None
            for attempt in range(1, retry_attempts + 1):
                try:
                    store.add_documents(batch_docs)
                    break
                except Exception as e:
                    last_err = e
                    if attempt < retry_attempts:
                        time.sleep(0.4 * attempt + random.uniform(0, 0.2))
                        continue
                    raise RuntimeError(f"Failed to ingest batch starting at {i}: {last_err}") from e

            ingested += len(batch_docs)
            if progress_hook:
                progress_hook({
                    'event': 'ingest_batch',
                    'batch_index': i // batch_size,
                    'batch_size': len(batch_docs),
                    'total_docs': total_docs,
                    'ingested_so_far': ingested,
                })

        return {'ingested': ingested, 'collection': collection_name}
    except Exception as e:
        # Callers treat any dict with an 'error' key as a failed ingestion.
        return {'error': str(e)}
app/utils.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import tempfile
3
+ from typing import Tuple
4
+ import json
5
+ import threading
6
+
7
# Lightweight JSON "database" for upload metadata, stored next to the package.
DATA_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "data"))
os.makedirs(DATA_DIR, exist_ok=True)  # ensure the data dir exists at import time
METADATA_PATH = os.path.join(DATA_DIR, "metadata.json")
_metadata_lock = threading.Lock()  # serializes read-modify-write of metadata.json
12
+
13
+
14
def save_upload_file_tmp(upload_file) -> Tuple[str, str]:
    """Save a FastAPI UploadFile to a temporary file; return (tmp_path, filename).

    Streams the upload in 1 MiB chunks instead of reading it fully into memory,
    so large PDFs do not inflate process RSS. The caller is responsible for
    deleting the temp file when done.
    """
    suffix = os.path.splitext(upload_file.filename)[1]
    fd, tmp_path = tempfile.mkstemp(suffix=suffix)
    with os.fdopen(fd, "wb") as out:
        while True:
            chunk = upload_file.file.read(1024 * 1024)
            if not chunk:
                break
            out.write(chunk)
    return tmp_path, upload_file.filename
22
+
23
+
24
def append_metadata_entry(entry: dict):
    """Append an entry to the metadata JSON file (a list of entries). Thread-safe.

    A missing, unreadable, or non-list file is treated as empty rather than
    crashing, so a single corrupt write cannot break subsequent appends.
    Note: the lock serializes threads in this process only, not other processes.
    """
    with _metadata_lock:
        data = []
        if os.path.exists(METADATA_PATH):
            try:
                with open(METADATA_PATH, "r", encoding="utf-8") as f:
                    data = json.load(f)
            except Exception:
                data = []
        # Guard against a corrupted file containing a non-list JSON value,
        # which would make data.append() raise AttributeError.
        if not isinstance(data, list):
            data = []
        data.append(entry)
        with open(METADATA_PATH, "w", encoding="utf-8") as f:
            json.dump(data, f, ensure_ascii=False, indent=2)
37
+
38
+
39
def read_metadata() -> list:
    """Return all metadata entries; [] when the file is missing, unreadable,
    or does not contain a JSON list."""
    with _metadata_lock:
        if not os.path.exists(METADATA_PATH):
            return []
        try:
            with open(METADATA_PATH, "r", encoding="utf-8") as f:
                data = json.load(f)
        except Exception:
            return []
        # Callers iterate the result, so never return a non-list JSON value.
        return data if isinstance(data, list) else []
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ fastapi
2
+ uvicorn[standard]
3
+ pydantic
4
+ pdf2image
5
+ qdrant-client
6
+ langchain-voyageai
7
+ google-genai
8
+ langchain
9
+ python-multipart
10
+ python-dotenv
11
+ langchain_qdrant
scripts/check_genai_key.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env python3
"""Quick script to validate Google GenAI API key and model access.

Run from repository root (inside venv) like:

    python3 scripts/check_genai_key.py

It will attempt to initialize the genai client using the same code in
`app.services.model_client` and make a lightweight request to validate the
key. The script prints a clear message for success / failure and includes
any provider error details.

Exit codes: 0 = key/model access verified, 1 = no API key configured,
2 = import or client-initialization failure, 3 = API calls failed.
"""
import importlib
import traceback
import sys

print("Checking GenAI API key and model access...")

# Import via importlib so a broken package import is reported cleanly.
try:
    mc = importlib.import_module("app.services.model_client")
except Exception as e:
    print("Failed to import app.services.model_client:", e)
    traceback.print_exc()
    sys.exit(2)

# Try to initialize client (reads env var GOOGLE_API_KEY)
client = None
try:
    c = mc.init_genai_client()
    # the init_genai_client returns the client and also assigns mc.genai_client
    client = c or getattr(mc, "genai_client", None)
except Exception as e:
    print("init_genai_client raised exception:", e)
    traceback.print_exc()
    sys.exit(2)

if not client:
    print("No GenAI client configured. Please set GOOGLE_API_KEY in environment.")
    sys.exit(1)

print("GenAI client created. Attempting a lightweight API call to verify key and model access...")

# Try to call a safe method. We attempt to use models.list() or models.get()
# if available, otherwise fall back to a small generate_content call.
try:
    models_api = getattr(client, "models", None)
    if models_api is None:
        print("Client has no .models attribute; cannot proceed.")
        sys.exit(3)

    # Prefer listing models if available (cheapest check).
    if hasattr(models_api, "list"):
        try:
            res = models_api.list()
            print("Models list call succeeded. Sample output:")
            print(res)
            sys.exit(0)
        except Exception as e:
            print("models.list() failed (continuing to try other checks):", e)

    if hasattr(models_api, "get"):
        try:
            # Try to fetch a commonly available model
            model_name = "gemini-2.0-flash"
            res = models_api.get(model=model_name)
            print(f"models.get('{model_name}') succeeded:")
            print(res)
            sys.exit(0)
        except Exception as e:
            print("models.get() failed (continuing to try generate_content):", e)

    # Fallback: small generate_content call (may consume quota)
    # Use a very small prompt
    try:
        prompt = "Ping"
        print("Calling models.generate_content with a tiny prompt (may hit quota)...")
        resp = models_api.generate_content(model="gemini-2.0-flash", contents=[{"type": "text", "text": prompt}])
        text = getattr(resp, "text", None) or str(resp)
        print("generate_content succeeded, response preview:")
        print(text[:1000])
        sys.exit(0)
    except Exception as e:
        print("generate_content failed:", e)
        traceback.print_exc()
        # Inspect exception for structured error info
        try:
            err_str = str(e)
            print("Exception text:\n", err_str)
        except Exception:
            pass
        sys.exit(3)

# NOTE: sys.exit raises SystemExit (a BaseException), so the successful exits
# above are NOT swallowed by this handler.
except Exception as e:
    print("Unexpected error while validating key:", e)
    traceback.print_exc()
    sys.exit(2)
+ sys.exit(2)