#!/usr/bin/env python3
"""
Visual RAG Toolkit CLI

Provides command-line interface for:
- Processing PDFs (embedding, Cloudinary upload, Qdrant indexing)
- Searching documents
- Managing collections

Usage:
    # Process PDFs (like process_pdfs_saliency_v2.py)
    visual-rag process --reports-dir ./pdfs --metadata-file metadata.json

    # Search
    visual-rag search --query "budget allocation" --collection my_docs

    # Show collection info
    visual-rag info --collection my_docs
"""

import argparse
import logging
import os
import sys
from pathlib import Path
from urllib.parse import urlparse

from dotenv import load_dotenv

logger = logging.getLogger(__name__)


def setup_logging(debug: bool = False):
    """Configure root logging; DEBUG level when *debug* is set, else INFO."""
    level = logging.DEBUG if debug else logging.INFO
    logging.basicConfig(
        level=level,
        format="%(asctime)s - %(levelname)s - %(message)s",
        force=True,  # override any handlers installed by imported libraries
    )


def _resolve_qdrant_env():
    """Return (url, api_key) from the first matching Qdrant env variable.

    Several env-var aliases are honored for compatibility with different
    deployments (SIGIR_*, DEST_*, plain QDRANT_*). Either value may be None.
    """
    qdrant_url = (
        os.getenv("SIGIR_QDRANT_URL") or os.getenv("DEST_QDRANT_URL") or os.getenv("QDRANT_URL")
    )
    qdrant_api_key = (
        os.getenv("SIGIR_QDRANT_KEY")
        or os.getenv("SIGIR_QDRANT_API_KEY")
        or os.getenv("DEST_QDRANT_API_KEY")
        or os.getenv("QDRANT_API_KEY")
    )
    return qdrant_url, qdrant_api_key


def cmd_process(args):
    """
    Process PDFs: convert → embed → upload to Cloudinary → index in Qdrant.

    Equivalent to process_pdfs_saliency_v2.py
    """
    from visual_rag import CloudinaryUploader, QdrantIndexer, VisualEmbedder, load_config
    from visual_rag.indexing.pipeline import ProcessingPipeline

    # Load environment
    load_dotenv()

    # Load config
    config = {}
    if args.config and Path(args.config).exists():
        config = load_config(args.config)

    # Get PDFs (both lowercase and uppercase extensions)
    reports_dir = Path(args.reports_dir)
    if not reports_dir.exists():
        logger.error(f"❌ Reports directory not found: {reports_dir}")
        sys.exit(1)

    pdf_paths = sorted(reports_dir.glob("*.pdf")) + sorted(reports_dir.glob("*.PDF"))
    if not pdf_paths:
        logger.error(f"❌ No PDF files found in: {reports_dir}")
        sys.exit(1)

    logger.info(f"📁 Found {len(pdf_paths)} PDF files")

    # Load metadata mapping (filename stem → metadata dict)
    metadata_mapping = {}
    if args.metadata_file:
        metadata_mapping = ProcessingPipeline.load_metadata_mapping(Path(args.metadata_file))

    # Dry run - just show summary, touch nothing
    if args.dry_run:
        logger.info("🏃 DRY RUN MODE")
        logger.info(f"   PDFs: {len(pdf_paths)}")
        logger.info(f"   Metadata entries: {len(metadata_mapping)}")
        logger.info(f"   Collection: {args.collection}")
        logger.info(f"   Cloudinary: {'ENABLED' if not args.no_cloudinary else 'DISABLED'}")
        for pdf in pdf_paths[:10]:
            has_meta = "✓" if pdf.stem.lower() in metadata_mapping else "✗"
            logger.info(f"   {has_meta} {pdf.name}")
        if len(pdf_paths) > 10:
            logger.info(f"   ... and {len(pdf_paths) - 10} more")
        return

    # Get settings (CLI flag wins over config file, which wins over default)
    model_name = args.model or config.get("model", {}).get("name", "vidore/colSmol-500M")
    collection_name = args.collection or config.get("qdrant", {}).get(
        "collection_name", "visual_documents"
    )

    torch_dtype = None
    if args.torch_dtype != "auto":
        import torch

        torch_dtype = {
            "float32": torch.float32,
            "float16": torch.float16,
            "bfloat16": torch.bfloat16,
        }[args.torch_dtype]

    logger.info(f"🤖 Initializing embedder: {model_name}")
    embedder = VisualEmbedder(
        model_name=model_name,
        batch_size=args.batch_size,
        torch_dtype=torch_dtype,
        processor_speed=str(getattr(args, "processor_speed", "fast")),
    )

    # Initialize Qdrant indexer
    qdrant_url, qdrant_api_key = _resolve_qdrant_env()
    if not qdrant_url:
        logger.error("❌ QDRANT_URL environment variable not set")
        sys.exit(1)

    logger.info(f"🔌 Connecting to Qdrant: {qdrant_url}")
    indexer = QdrantIndexer(
        url=qdrant_url,
        api_key=qdrant_api_key,
        collection_name=collection_name,
        prefer_grpc=args.prefer_grpc,
        vector_datatype=args.qdrant_vector_dtype,
    )

    # Create collection if needed
    indexer.create_collection(force_recreate=args.force_recreate)

    # Payload indexes: always index the core fields, then infer a Qdrant type
    # for each extra metadata key from the first non-None value seen.
    inferred_fields = [
        {"field": "filename", "type": "keyword"},
        {"field": "page_number", "type": "integer"},
        {"field": "has_text", "type": "bool"},
    ]
    if metadata_mapping:
        keys = set()
        for meta in metadata_mapping.values():
            if isinstance(meta, dict):
                keys.update(meta.keys())
        for k in sorted(keys):
            if k in ("filename", "page_number", "has_text"):
                continue
            inferred_type = "keyword"
            for meta in metadata_mapping.values():
                if not isinstance(meta, dict):
                    continue
                v = meta.get(k)
                # bool must be checked before int (bool is an int subclass)
                if isinstance(v, bool):
                    inferred_type = "bool"
                    break
                if isinstance(v, int):
                    inferred_type = "integer"
                    break
                if isinstance(v, float):
                    inferred_type = "float"
                    break
            inferred_fields.append({"field": k, "type": inferred_type})
    indexer.create_payload_indexes(fields=inferred_fields)

    # Initialize Cloudinary uploader (optional; missing credentials are non-fatal)
    cloudinary_uploader = None
    if not args.no_cloudinary:
        try:
            project_name = config.get("project_name", "visual_docs")
            cloudinary_uploader = CloudinaryUploader(folder=project_name)
        except ValueError as e:
            logger.warning(f"⚠️ Cloudinary not configured: {e}")
            logger.warning("   Continuing without Cloudinary uploads")

    # Create pipeline
    pipeline = ProcessingPipeline(
        embedder=embedder,
        indexer=indexer,
        cloudinary_uploader=cloudinary_uploader,
        metadata_mapping=metadata_mapping,
        config=config,
        embedding_strategy=args.strategy,
        crop_empty=bool(getattr(args, "crop_empty", False)),
        crop_empty_percentage_to_remove=float(
            getattr(args, "crop_empty_percentage_to_remove", 0.9)
        ),
        crop_empty_remove_page_number=bool(getattr(args, "crop_empty_remove_page_number", False)),
    )

    # Process PDFs
    total_uploaded = 0
    total_skipped = 0
    total_failed = 0
    skip_existing = not args.no_skip_existing

    for pdf_idx, pdf_path in enumerate(pdf_paths, 1):
        logger.info(f"\n{'='*60}")
        logger.info(f"📄 [{pdf_idx}/{len(pdf_paths)}] {pdf_path.name}")
        logger.info(f"{'='*60}")

        result = pipeline.process_pdf(
            pdf_path,
            skip_existing=skip_existing,
            upload_to_cloudinary=(not args.no_cloudinary),
            upload_to_qdrant=True,
        )
        total_uploaded += result["uploaded"]
        total_skipped += result["skipped"]
        total_failed += result["failed"]

    # Summary
    logger.info(f"\n{'='*60}")
    logger.info("📊 SUMMARY")
    logger.info(f"{'='*60}")
    logger.info(f"   Total PDFs: {len(pdf_paths)}")
    logger.info(f"   Uploaded: {total_uploaded}")
    logger.info(f"   Skipped: {total_skipped}")
    logger.info(f"   Failed: {total_failed}")

    info = indexer.get_collection_info()
    if info:
        logger.info(f"   Collection points: {info.get('points_count', 'N/A')}")


def cmd_search(args):
    """Search documents with the selected retrieval strategy and print results."""
    from qdrant_client import QdrantClient

    from visual_rag import VisualEmbedder
    from visual_rag.retrieval import SingleStageRetriever, TwoStageRetriever

    load_dotenv()

    qdrant_url, qdrant_api_key = _resolve_qdrant_env()
    if not qdrant_url:
        logger.error("❌ QDRANT_URL not set")
        sys.exit(1)

    # Initialize
    logger.info(f"🤖 Loading model: {args.model}")
    embedder = VisualEmbedder(
        model_name=args.model, processor_speed=str(getattr(args, "processor_speed", "fast"))
    )

    logger.info("🔌 Connecting to Qdrant")
    # If the URL points at the REST port (6333), gRPC lives on 6334.
    grpc_port = 6334 if args.prefer_grpc and urlparse(qdrant_url).port == 6333 else None
    client = QdrantClient(
        url=qdrant_url,
        api_key=qdrant_api_key,
        prefer_grpc=args.prefer_grpc,
        grpc_port=grpc_port,
        check_compatibility=False,
    )
    two_stage = TwoStageRetriever(client, args.collection)
    single_stage = SingleStageRetriever(client, args.collection)

    # Embed query
    logger.info(f"🔍 Query: {args.query}")
    query_embedding = embedder.embed_query(args.query)

    # Build filter (only when at least one filter flag is given)
    filter_obj = None
    if args.year or args.source or args.district:
        filter_obj = two_stage.build_filter(
            year=args.year,
            source=args.source,
            district=args.district,
        )

    # Search
    query_np = query_embedding.detach().cpu().float().numpy()  # .float() for BFloat16
    if args.strategy == "single_full":
        results = single_stage.search(
            query_embedding=query_np,
            top_k=args.top_k,
            strategy="multi_vector",
            filter_obj=filter_obj,
        )
    elif args.strategy == "single_tiles":
        results = single_stage.search(
            query_embedding=query_np,
            top_k=args.top_k,
            strategy="tiles_maxsim",
            filter_obj=filter_obj,
        )
    elif args.strategy == "single_global":
        results = single_stage.search(
            query_embedding=query_np,
            top_k=args.top_k,
            strategy="pooled_global",
            filter_obj=filter_obj,
        )
    else:
        results = two_stage.search(
            query_embedding=query_np,
            top_k=args.top_k,
            prefetch_k=args.prefetch_k,
            filter_obj=filter_obj,
            stage1_mode=args.stage1_mode,
        )

    # Display results
    logger.info(f"\n📊 Results ({len(results)}):")
    for i, result in enumerate(results, 1):
        payload = result.get("payload", {})
        score = result.get("score_final", result.get("score_stage1", 0))

        filename = payload.get("filename", "N/A")
        page_num = payload.get("page_number", "N/A")
        year = payload.get("year", "N/A")
        source = payload.get("source", "N/A")

        # BUG FIX: `filename` was looked up but never printed — the log line
        # emitted a literal "(unknown)" placeholder for every hit.
        logger.info(f"   {i}. {filename} p.{page_num}")
        logger.info(f"      Score: {score:.4f} | Year: {year} | Source: {source}")

        # Text snippet
        text = payload.get("text", "")
        if text and args.show_text:
            snippet = text[:200].replace("\n", " ")
            logger.info(f"      Text: {snippet}...")


def cmd_info(args):
    """Show collection status, point counts, and named-vector configuration."""
    from qdrant_client import QdrantClient

    load_dotenv()

    qdrant_url, qdrant_api_key = _resolve_qdrant_env()
    if not qdrant_url:
        logger.error("❌ QDRANT_URL not set")
        sys.exit(1)

    # If the URL points at the REST port (6333), gRPC lives on 6334.
    grpc_port = 6334 if args.prefer_grpc and urlparse(qdrant_url).port == 6333 else None
    client = QdrantClient(
        url=qdrant_url,
        api_key=qdrant_api_key,
        prefer_grpc=args.prefer_grpc,
        grpc_port=grpc_port,
        check_compatibility=False,
    )

    try:
        info = client.get_collection(args.collection)
        status = info.status
        if hasattr(status, "value"):  # unwrap enum to its string value
            status = status.value
        # indexed_vectors_count may be an int or a per-vector-name dict
        indexed_count = getattr(info, "indexed_vectors_count", 0) or 0
        if isinstance(indexed_count, dict):
            indexed_count = sum(indexed_count.values())

        logger.info(f"📊 Collection: {args.collection}")
        logger.info(f"   Status: {status}")
        logger.info(f"   Points: {info.points_count}")
        logger.info(f"   Indexed vectors: {indexed_count}")

        # Show vector config
        if hasattr(info, "config") and hasattr(info.config, "params"):
            vectors = getattr(info.config.params, "vectors", {})
            if vectors:
                logger.info(f"   Vectors: {list(vectors.keys())}")
    except Exception as e:
        logger.error(f"❌ Could not get collection info: {e}")
        sys.exit(1)


def main():
    """Main CLI entry point: build the parser, dispatch to the subcommand."""
    parser = argparse.ArgumentParser(
        prog="visual-rag",
        description="Visual RAG Toolkit - Visual document retrieval with ColPali",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Process PDFs (like process_pdfs_saliency_v2.py)
  visual-rag process --reports-dir ./pdfs --metadata-file metadata.json

  # Process without Cloudinary
  visual-rag process --reports-dir ./pdfs --no-cloudinary

  # Search
  visual-rag search --query "budget allocation" --collection my_docs

  # Search with filters
  visual-rag search --query "budget" --year 2023 --source "Local Government"

  # Show collection info
  visual-rag info --collection my_docs
""",
    )
    parser.add_argument("--debug", action="store_true", help="Enable debug logging")

    subparsers = parser.add_subparsers(dest="command", help="Command")

    # =========================================================================
    # PROCESS command
    # =========================================================================
    process_parser = subparsers.add_parser(
        "process",
        help="Process PDFs: embed, upload to Cloudinary, index in Qdrant",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    process_parser.add_argument(
        "--reports-dir", type=str, required=True, help="Directory containing PDF files"
    )
    process_parser.add_argument(
        "--metadata-file",
        type=str,
        help="JSON file with filename → metadata mapping (like filename_metadata.json)",
    )
    process_parser.add_argument(
        "--collection", type=str, default="visual_documents", help="Qdrant collection name"
    )
    process_parser.add_argument(
        "--model",
        type=str,
        default="vidore/colSmol-500M",
        help="Model name (vidore/colSmol-500M, vidore/colpali-v1.3, etc.)",
    )
    process_parser.add_argument("--batch-size", type=int, default=8, help="Embedding batch size")
    process_parser.add_argument("--config", type=str, help="Path to config.yaml file")
    process_parser.add_argument(
        "--no-cloudinary", action="store_true", help="Skip Cloudinary uploads"
    )
    process_parser.add_argument(
        "--crop-empty",
        action="store_true",
        help="Crop empty whitespace from page images before embedding (default: off).",
    )
    process_parser.add_argument(
        "--crop-empty-percentage-to-remove",
        type=float,
        default=0.9,
        help="Kept for traceability; currently does not affect cropping behavior (default: 0.9).",
    )
    process_parser.add_argument(
        "--crop-empty-remove-page-number",
        action="store_true",
        help="If set, attempts to crop away the bottom region that contains sparse page numbers (default: off).",
    )
    process_parser.add_argument(
        "--no-skip-existing",
        action="store_true",
        help="Process all pages even if they exist in Qdrant",
    )
    process_parser.add_argument(
        "--force-recreate", action="store_true", help="Delete and recreate collection"
    )
    process_parser.add_argument(
        "--dry-run", action="store_true", help="Show what would be processed without doing it"
    )
    process_parser.add_argument(
        "--strategy",
        type=str,
        default="pooling",
        choices=["pooling", "standard", "all"],
        help="Embedding strategy: 'pooling' (NOVEL), 'standard' (BASELINE), "
        "'all' (embed once, store BOTH for comparison)",
    )
    process_parser.add_argument(
        "--torch-dtype",
        type=str,
        default="auto",
        choices=["auto", "float32", "float16", "bfloat16"],
        help="Torch dtype for model weights (default: auto; CUDA->bfloat16, else float32).",
    )
    process_parser.add_argument(
        "--qdrant-vector-dtype",
        type=str,
        default="float16",
        choices=["float16", "float32"],
        help="Datatype for vectors stored in Qdrant (default: float16).",
    )
    process_parser.add_argument(
        "--processor-speed",
        type=str,
        default="fast",
        choices=["fast", "slow", "auto"],
        help="Processor implementation: fast (default, with fallback to slow), slow, or auto.",
    )
    process_grpc_group = process_parser.add_mutually_exclusive_group()
    process_grpc_group.add_argument(
        "--prefer-grpc",
        dest="prefer_grpc",
        action="store_true",
        default=True,
        help="Use gRPC for Qdrant client (recommended).",
    )
    process_grpc_group.add_argument(
        "--no-prefer-grpc",
        dest="prefer_grpc",
        action="store_false",
        help="Disable gRPC for Qdrant client.",
    )
    process_parser.set_defaults(func=cmd_process)

    # =========================================================================
    # SEARCH command
    # =========================================================================
    search_parser = subparsers.add_parser(
        "search",
        help="Search documents",
    )
    search_parser.add_argument("--query", type=str, required=True, help="Search query")
    search_parser.add_argument(
        "--collection", type=str, default="visual_documents", help="Qdrant collection name"
    )
    search_parser.add_argument(
        "--model", type=str, default="vidore/colSmol-500M", help="Model name"
    )
    search_parser.add_argument(
        "--processor-speed",
        type=str,
        default="fast",
        choices=["fast", "slow", "auto"],
        help="Processor implementation: fast (default, with fallback to slow), slow, or auto.",
    )
    search_parser.add_argument("--top-k", type=int, default=10, help="Number of results")
    search_parser.add_argument(
        "--strategy",
        type=str,
        default="single_full",
        choices=["single_full", "single_tiles", "single_global", "two_stage"],
        help="Search strategy",
    )
    search_parser.add_argument(
        "--prefetch-k", type=int, default=200, help="Prefetch candidates for two-stage retrieval"
    )
    search_parser.add_argument(
        "--stage1-mode",
        type=str,
        default="pooled_query_vs_tiles",
        choices=["pooled_query_vs_tiles", "tokens_vs_tiles", "pooled_query_vs_global"],
        help="Stage 1 mode for two-stage retrieval",
    )
    search_parser.add_argument("--year", type=int, help="Filter by year")
    search_parser.add_argument("--source", type=str, help="Filter by source")
    search_parser.add_argument("--district", type=str, help="Filter by district")
    search_parser.add_argument(
        "--show-text", action="store_true", help="Show text snippets in results"
    )
    search_grpc_group = search_parser.add_mutually_exclusive_group()
    search_grpc_group.add_argument(
        "--prefer-grpc",
        dest="prefer_grpc",
        action="store_true",
        default=True,
        help="Use gRPC for Qdrant client (recommended).",
    )
    search_grpc_group.add_argument(
        "--no-prefer-grpc",
        dest="prefer_grpc",
        action="store_false",
        help="Disable gRPC for Qdrant client.",
    )
    search_parser.set_defaults(func=cmd_search)

    # =========================================================================
    # INFO command
    # =========================================================================
    info_parser = subparsers.add_parser(
        "info",
        help="Show collection info",
    )
    info_parser.add_argument(
        "--collection", type=str, default="visual_documents", help="Qdrant collection name"
    )
    info_grpc_group = info_parser.add_mutually_exclusive_group()
    info_grpc_group.add_argument(
        "--prefer-grpc",
        dest="prefer_grpc",
        action="store_true",
        default=True,
        help="Use gRPC for Qdrant client (recommended).",
    )
    info_grpc_group.add_argument(
        "--no-prefer-grpc",
        dest="prefer_grpc",
        action="store_false",
        help="Disable gRPC for Qdrant client.",
    )
    info_parser.set_defaults(func=cmd_info)

    # Parse and execute
    args = parser.parse_args()
    setup_logging(args.debug)

    if not args.command:
        parser.print_help()
        sys.exit(0)

    args.func(args)


if __name__ == "__main__":
    main()