Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| Visual RAG Toolkit CLI | |
| Provides command-line interface for: | |
| - Processing PDFs (embedding, Cloudinary upload, Qdrant indexing) | |
| - Searching documents | |
| - Managing collections | |
| Usage: | |
| # Process PDFs (like process_pdfs_saliency_v2.py) | |
| visual-rag process --reports-dir ./pdfs --metadata-file metadata.json | |
| # Search | |
| visual-rag search --query "budget allocation" --collection my_docs | |
| # Show collection info | |
| visual-rag info --collection my_docs | |
| """ | |
| import argparse | |
| import logging | |
| import os | |
| import sys | |
| from pathlib import Path | |
| from urllib.parse import urlparse | |
| from dotenv import load_dotenv | |
| logger = logging.getLogger(__name__) | |
def setup_logging(debug: bool = False):
    """Initialize the root logging configuration for the CLI.

    Args:
        debug: When True, log at DEBUG verbosity; otherwise INFO.
    """
    # force=True replaces any handlers installed by earlier basicConfig calls,
    # so repeated invocations (e.g. in tests) behave predictably.
    logging.basicConfig(
        level=logging.DEBUG if debug else logging.INFO,
        format="%(asctime)s - %(levelname)s - %(message)s",
        force=True,
    )
def _resolve_qdrant_env():
    """Resolve Qdrant connection settings from environment variables.

    Several env-var aliases are supported (project-specific names first,
    generic names last); the first one set wins.

    Returns:
        Tuple of ``(url, api_key)``; either element may be ``None`` when unset.
    """
    url = (
        os.getenv("SIGIR_QDRANT_URL") or os.getenv("DEST_QDRANT_URL") or os.getenv("QDRANT_URL")
    )
    api_key = (
        os.getenv("SIGIR_QDRANT_KEY")
        or os.getenv("SIGIR_QDRANT_API_KEY")
        or os.getenv("DEST_QDRANT_API_KEY")
        or os.getenv("QDRANT_API_KEY")
    )
    return url, api_key


def _infer_payload_fields(metadata_mapping):
    """Build the payload-index field specs for Qdrant.

    Always indexes ``filename``/``page_number``/``has_text``; further fields
    are inferred from the first non-None value seen across the metadata
    entries. ``bool`` is checked before ``int`` because ``bool`` is an ``int``
    subclass in Python; values of any other type fall back to ``keyword``.

    Args:
        metadata_mapping: Mapping of filename stem -> metadata dict (entries
            that are not dicts are ignored).

    Returns:
        List of ``{"field": name, "type": qdrant_type}`` dicts.
    """
    fields = [
        {"field": "filename", "type": "keyword"},
        {"field": "page_number", "type": "integer"},
        {"field": "has_text", "type": "bool"},
    ]
    if not metadata_mapping:
        return fields

    keys = set()
    for meta in metadata_mapping.values():
        if isinstance(meta, dict):
            keys.update(meta.keys())

    for k in sorted(keys):
        if k in ("filename", "page_number", "has_text"):
            continue
        inferred_type = "keyword"
        for meta in metadata_mapping.values():
            if not isinstance(meta, dict):
                continue
            v = meta.get(k)
            if isinstance(v, bool):
                inferred_type = "bool"
                break
            if isinstance(v, int):
                inferred_type = "integer"
                break
            if isinstance(v, float):
                inferred_type = "float"
                break
        fields.append({"field": k, "type": inferred_type})
    return fields


def cmd_process(args):
    """
    Process PDFs: convert → embed → upload to Cloudinary → index in Qdrant.
    Equivalent to process_pdfs_saliency_v2.py
    """
    from visual_rag import CloudinaryUploader, QdrantIndexer, VisualEmbedder, load_config
    from visual_rag.indexing.pipeline import ProcessingPipeline

    # Load environment (.env file, if present)
    load_dotenv()

    # Load config; a missing/unspecified file silently yields defaults
    config = {}
    if args.config and Path(args.config).exists():
        config = load_config(args.config)

    # Collect PDFs (both lowercase and uppercase extensions, each sorted)
    reports_dir = Path(args.reports_dir)
    if not reports_dir.exists():
        logger.error(f"❌ Reports directory not found: {reports_dir}")
        sys.exit(1)
    pdf_paths = sorted(reports_dir.glob("*.pdf")) + sorted(reports_dir.glob("*.PDF"))
    if not pdf_paths:
        logger.error(f"❌ No PDF files found in: {reports_dir}")
        sys.exit(1)
    logger.info(f"📁 Found {len(pdf_paths)} PDF files")

    # Load metadata mapping (filename stem -> metadata dict)
    metadata_mapping = {}
    if args.metadata_file:
        metadata_mapping = ProcessingPipeline.load_metadata_mapping(Path(args.metadata_file))

    # Dry run - just show summary, touch nothing
    if args.dry_run:
        logger.info("🏃 DRY RUN MODE")
        logger.info(f"   PDFs: {len(pdf_paths)}")
        logger.info(f"   Metadata entries: {len(metadata_mapping)}")
        logger.info(f"   Collection: {args.collection}")
        logger.info(f"   Cloudinary: {'ENABLED' if not args.no_cloudinary else 'DISABLED'}")
        for pdf in pdf_paths[:10]:
            has_meta = "✓" if pdf.stem.lower() in metadata_mapping else "✗"
            logger.info(f"   {has_meta} {pdf.name}")
        if len(pdf_paths) > 10:
            logger.info(f"   ... and {len(pdf_paths) - 10} more")
        return

    # Resolve settings: CLI flag wins, then config file, then hard default
    model_name = args.model or config.get("model", {}).get("name", "vidore/colSmol-500M")
    collection_name = args.collection or config.get("qdrant", {}).get(
        "collection_name", "visual_documents"
    )

    # Map the CLI dtype name to a torch dtype; "auto" leaves it to the embedder
    torch_dtype = None
    if args.torch_dtype != "auto":
        import torch

        torch_dtype = {
            "float32": torch.float32,
            "float16": torch.float16,
            "bfloat16": torch.bfloat16,
        }[args.torch_dtype]

    logger.info(f"🤖 Initializing embedder: {model_name}")
    embedder = VisualEmbedder(
        model_name=model_name,
        batch_size=args.batch_size,
        torch_dtype=torch_dtype,
        processor_speed=str(getattr(args, "processor_speed", "fast")),
    )

    # Initialize Qdrant indexer
    qdrant_url, qdrant_api_key = _resolve_qdrant_env()
    if not qdrant_url:
        logger.error("❌ QDRANT_URL environment variable not set")
        sys.exit(1)

    logger.info(f"🔌 Connecting to Qdrant: {qdrant_url}")
    indexer = QdrantIndexer(
        url=qdrant_url,
        api_key=qdrant_api_key,
        collection_name=collection_name,
        prefer_grpc=args.prefer_grpc,
        vector_datatype=args.qdrant_vector_dtype,
    )

    # Create collection if needed, then index the payload fields we know about
    indexer.create_collection(force_recreate=args.force_recreate)
    indexer.create_payload_indexes(fields=_infer_payload_fields(metadata_mapping))

    # Initialize Cloudinary uploader (optional; missing credentials are non-fatal)
    cloudinary_uploader = None
    if not args.no_cloudinary:
        try:
            project_name = config.get("project_name", "visual_docs")
            cloudinary_uploader = CloudinaryUploader(folder=project_name)
        except ValueError as e:
            logger.warning(f"⚠️ Cloudinary not configured: {e}")
            logger.warning("   Continuing without Cloudinary uploads")

    # Create pipeline
    pipeline = ProcessingPipeline(
        embedder=embedder,
        indexer=indexer,
        cloudinary_uploader=cloudinary_uploader,
        metadata_mapping=metadata_mapping,
        config=config,
        embedding_strategy=args.strategy,
        crop_empty=bool(getattr(args, "crop_empty", False)),
        crop_empty_percentage_to_remove=float(
            getattr(args, "crop_empty_percentage_to_remove", 0.9)
        ),
        crop_empty_remove_page_number=bool(getattr(args, "crop_empty_remove_page_number", False)),
    )

    # Process PDFs one at a time, accumulating per-page outcome counts
    total_uploaded = 0
    total_skipped = 0
    total_failed = 0
    skip_existing = not args.no_skip_existing
    for pdf_idx, pdf_path in enumerate(pdf_paths, 1):
        logger.info(f"\n{'='*60}")
        logger.info(f"📄 [{pdf_idx}/{len(pdf_paths)}] {pdf_path.name}")
        logger.info(f"{'='*60}")
        result = pipeline.process_pdf(
            pdf_path,
            skip_existing=skip_existing,
            upload_to_cloudinary=(not args.no_cloudinary),
            upload_to_qdrant=True,
        )
        total_uploaded += result["uploaded"]
        total_skipped += result["skipped"]
        total_failed += result["failed"]

    # Summary
    logger.info(f"\n{'='*60}")
    logger.info("📊 SUMMARY")
    logger.info(f"{'='*60}")
    logger.info(f"   Total PDFs: {len(pdf_paths)}")
    logger.info(f"   Uploaded: {total_uploaded}")
    logger.info(f"   Skipped: {total_skipped}")
    logger.info(f"   Failed: {total_failed}")

    info = indexer.get_collection_info()
    if info:
        logger.info(f"   Collection points: {info.get('points_count', 'N/A')}")
def cmd_search(args):
    """Search documents in Qdrant and log the top results.

    Embeds the query text, optionally builds a metadata filter
    (year/source/district), runs either a single-stage or two-stage
    retrieval depending on ``args.strategy``, and logs one entry per hit.
    Exits with status 1 when no Qdrant URL is configured.
    """
    from qdrant_client import QdrantClient

    from visual_rag import VisualEmbedder
    from visual_rag.retrieval import SingleStageRetriever, TwoStageRetriever

    load_dotenv()

    # Env-var aliases for Qdrant credentials; first one set wins.
    qdrant_url = (
        os.getenv("SIGIR_QDRANT_URL") or os.getenv("DEST_QDRANT_URL") or os.getenv("QDRANT_URL")
    )
    qdrant_api_key = (
        os.getenv("SIGIR_QDRANT_KEY")
        or os.getenv("SIGIR_QDRANT_API_KEY")
        or os.getenv("DEST_QDRANT_API_KEY")
        or os.getenv("QDRANT_API_KEY")
    )
    if not qdrant_url:
        logger.error("❌ QDRANT_URL not set")
        sys.exit(1)

    # Initialize
    logger.info(f"🤖 Loading model: {args.model}")
    embedder = VisualEmbedder(
        model_name=args.model, processor_speed=str(getattr(args, "processor_speed", "fast"))
    )

    logger.info("🔌 Connecting to Qdrant")
    # URLs pointing at the default HTTP port (6333) get the conventional gRPC port.
    grpc_port = 6334 if args.prefer_grpc and urlparse(qdrant_url).port == 6333 else None
    client = QdrantClient(
        url=qdrant_url,
        api_key=qdrant_api_key,
        prefer_grpc=args.prefer_grpc,
        grpc_port=grpc_port,
        check_compatibility=False,
    )
    two_stage = TwoStageRetriever(client, args.collection)
    single_stage = SingleStageRetriever(client, args.collection)

    # Embed query
    logger.info(f"🔍 Query: {args.query}")
    query_embedding = embedder.embed_query(args.query)

    # Build filter (only when at least one filter flag was provided)
    filter_obj = None
    if args.year or args.source or args.district:
        filter_obj = two_stage.build_filter(
            year=args.year,
            source=args.source,
            district=args.district,
        )

    # Search — map CLI strategy names onto single-stage retrieval modes;
    # anything else falls through to the two-stage retriever.
    query_np = query_embedding.detach().cpu().float().numpy()  # .float() for BFloat16
    single_stage_modes = {
        "single_full": "multi_vector",
        "single_tiles": "tiles_maxsim",
        "single_global": "pooled_global",
    }
    if args.strategy in single_stage_modes:
        results = single_stage.search(
            query_embedding=query_np,
            top_k=args.top_k,
            strategy=single_stage_modes[args.strategy],
            filter_obj=filter_obj,
        )
    else:
        results = two_stage.search(
            query_embedding=query_np,
            top_k=args.top_k,
            prefetch_k=args.prefetch_k,
            filter_obj=filter_obj,
            stage1_mode=args.stage1_mode,
        )

    # Display results
    logger.info(f"\n📊 Results ({len(results)}):")
    for i, result in enumerate(results, 1):
        payload = result.get("payload", {})
        score = result.get("score_final", result.get("score_stage1", 0))
        filename = payload.get("filename", "N/A")
        page_num = payload.get("page_number", "N/A")
        year = payload.get("year", "N/A")
        source = payload.get("source", "N/A")
        # BUG FIX: previously printed a hard-coded "(unknown)" instead of the
        # filename that was already fetched from the payload above.
        logger.info(f"   {i}. {filename} p.{page_num}")
        logger.info(f"      Score: {score:.4f} | Year: {year} | Source: {source}")
        # Text snippet (only when indexed text exists and --show-text was given)
        text = payload.get("text", "")
        if text and args.show_text:
            snippet = text[:200].replace("\n", " ")
            logger.info(f"      Text: {snippet}...")
def cmd_info(args):
    """Show collection info."""
    from qdrant_client import QdrantClient

    load_dotenv()

    # Same env-var alias chain used by the other commands; first match wins.
    qdrant_url = (
        os.getenv("SIGIR_QDRANT_URL") or os.getenv("DEST_QDRANT_URL") or os.getenv("QDRANT_URL")
    )
    qdrant_api_key = (
        os.getenv("SIGIR_QDRANT_KEY")
        or os.getenv("SIGIR_QDRANT_API_KEY")
        or os.getenv("DEST_QDRANT_API_KEY")
        or os.getenv("QDRANT_API_KEY")
    )
    if not qdrant_url:
        logger.error("❌ QDRANT_URL not set")
        sys.exit(1)

    # A URL on the default HTTP port (6333) gets the conventional gRPC port.
    rpc_port = None
    if args.prefer_grpc and urlparse(qdrant_url).port == 6333:
        rpc_port = 6334

    client = QdrantClient(
        url=qdrant_url,
        api_key=qdrant_api_key,
        prefer_grpc=args.prefer_grpc,
        grpc_port=rpc_port,
        check_compatibility=False,
    )

    try:
        info = client.get_collection(args.collection)

        # Enum-like statuses expose .value; fall back to the raw object.
        status = getattr(info.status, "value", info.status)

        # indexed_vectors_count may be a per-vector-name dict; sum it if so.
        raw_indexed = getattr(info, "indexed_vectors_count", 0) or 0
        indexed_count = (
            sum(raw_indexed.values()) if isinstance(raw_indexed, dict) else raw_indexed
        )

        logger.info(f"📊 Collection: {args.collection}")
        logger.info(f"   Status: {status}")
        logger.info(f"   Points: {info.points_count}")
        logger.info(f"   Indexed vectors: {indexed_count}")

        # Show vector config when the client exposes one
        if hasattr(info, "config") and hasattr(info.config, "params"):
            vectors = getattr(info.config.params, "vectors", {})
            if vectors:
                logger.info(f"   Vectors: {list(vectors.keys())}")
    except Exception as e:
        logger.error(f"❌ Could not get collection info: {e}")
        sys.exit(1)
def _add_prefer_grpc_flags(parser):
    """Attach the mutually-exclusive ``--prefer-grpc`` / ``--no-prefer-grpc``
    flag pair to *parser*.

    Extracted because the identical group was copy-pasted onto all three
    subcommands. gRPC is on by default; ``--no-prefer-grpc`` turns it off.
    """
    group = parser.add_mutually_exclusive_group()
    group.add_argument(
        "--prefer-grpc",
        dest="prefer_grpc",
        action="store_true",
        default=True,
        help="Use gRPC for Qdrant client (recommended).",
    )
    group.add_argument(
        "--no-prefer-grpc",
        dest="prefer_grpc",
        action="store_false",
        help="Disable gRPC for Qdrant client.",
    )


def main():
    """Main CLI entry point: build the argument parser and dispatch.

    Each subcommand registers its handler via ``set_defaults(func=...)``;
    running with no subcommand prints help and exits 0.
    """
    parser = argparse.ArgumentParser(
        prog="visual-rag",
        description="Visual RAG Toolkit - Visual document retrieval with ColPali",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Process PDFs (like process_pdfs_saliency_v2.py)
  visual-rag process --reports-dir ./pdfs --metadata-file metadata.json

  # Process without Cloudinary
  visual-rag process --reports-dir ./pdfs --no-cloudinary

  # Search
  visual-rag search --query "budget allocation" --collection my_docs

  # Search with filters
  visual-rag search --query "budget" --year 2023 --source "Local Government"

  # Show collection info
  visual-rag info --collection my_docs
""",
    )
    parser.add_argument("--debug", action="store_true", help="Enable debug logging")
    subparsers = parser.add_subparsers(dest="command", help="Command")

    # =========================================================================
    # PROCESS command
    # =========================================================================
    process_parser = subparsers.add_parser(
        "process",
        help="Process PDFs: embed, upload to Cloudinary, index in Qdrant",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    process_parser.add_argument(
        "--reports-dir", type=str, required=True, help="Directory containing PDF files"
    )
    process_parser.add_argument(
        "--metadata-file",
        type=str,
        help="JSON file with filename → metadata mapping (like filename_metadata.json)",
    )
    process_parser.add_argument(
        "--collection", type=str, default="visual_documents", help="Qdrant collection name"
    )
    process_parser.add_argument(
        "--model",
        type=str,
        default="vidore/colSmol-500M",
        help="Model name (vidore/colSmol-500M, vidore/colpali-v1.3, etc.)",
    )
    process_parser.add_argument("--batch-size", type=int, default=8, help="Embedding batch size")
    process_parser.add_argument("--config", type=str, help="Path to config.yaml file")
    process_parser.add_argument(
        "--no-cloudinary", action="store_true", help="Skip Cloudinary uploads"
    )
    process_parser.add_argument(
        "--crop-empty",
        action="store_true",
        help="Crop empty whitespace from page images before embedding (default: off).",
    )
    process_parser.add_argument(
        "--crop-empty-percentage-to-remove",
        type=float,
        default=0.9,
        help="Kept for traceability; currently does not affect cropping behavior (default: 0.9).",
    )
    process_parser.add_argument(
        "--crop-empty-remove-page-number",
        action="store_true",
        help="If set, attempts to crop away the bottom region that contains sparse page numbers (default: off).",
    )
    process_parser.add_argument(
        "--no-skip-existing",
        action="store_true",
        help="Process all pages even if they exist in Qdrant",
    )
    process_parser.add_argument(
        "--force-recreate", action="store_true", help="Delete and recreate collection"
    )
    process_parser.add_argument(
        "--dry-run", action="store_true", help="Show what would be processed without doing it"
    )
    process_parser.add_argument(
        "--strategy",
        type=str,
        default="pooling",
        choices=["pooling", "standard", "all"],
        help="Embedding strategy: 'pooling' (NOVEL), 'standard' (BASELINE), "
        "'all' (embed once, store BOTH for comparison)",
    )
    process_parser.add_argument(
        "--torch-dtype",
        type=str,
        default="auto",
        choices=["auto", "float32", "float16", "bfloat16"],
        help="Torch dtype for model weights (default: auto; CUDA->bfloat16, else float32).",
    )
    process_parser.add_argument(
        "--qdrant-vector-dtype",
        type=str,
        default="float16",
        choices=["float16", "float32"],
        help="Datatype for vectors stored in Qdrant (default: float16).",
    )
    process_parser.add_argument(
        "--processor-speed",
        type=str,
        default="fast",
        choices=["fast", "slow", "auto"],
        help="Processor implementation: fast (default, with fallback to slow), slow, or auto.",
    )
    _add_prefer_grpc_flags(process_parser)
    process_parser.set_defaults(func=cmd_process)

    # =========================================================================
    # SEARCH command
    # =========================================================================
    search_parser = subparsers.add_parser(
        "search",
        help="Search documents",
    )
    search_parser.add_argument("--query", type=str, required=True, help="Search query")
    search_parser.add_argument(
        "--collection", type=str, default="visual_documents", help="Qdrant collection name"
    )
    search_parser.add_argument(
        "--model", type=str, default="vidore/colSmol-500M", help="Model name"
    )
    search_parser.add_argument(
        "--processor-speed",
        type=str,
        default="fast",
        choices=["fast", "slow", "auto"],
        help="Processor implementation: fast (default, with fallback to slow), slow, or auto.",
    )
    search_parser.add_argument("--top-k", type=int, default=10, help="Number of results")
    search_parser.add_argument(
        "--strategy",
        type=str,
        default="single_full",
        choices=["single_full", "single_tiles", "single_global", "two_stage"],
        help="Search strategy",
    )
    search_parser.add_argument(
        "--prefetch-k", type=int, default=200, help="Prefetch candidates for two-stage retrieval"
    )
    search_parser.add_argument(
        "--stage1-mode",
        type=str,
        default="pooled_query_vs_tiles",
        choices=["pooled_query_vs_tiles", "tokens_vs_tiles", "pooled_query_vs_global"],
        help="Stage 1 mode for two-stage retrieval",
    )
    search_parser.add_argument("--year", type=int, help="Filter by year")
    search_parser.add_argument("--source", type=str, help="Filter by source")
    search_parser.add_argument("--district", type=str, help="Filter by district")
    search_parser.add_argument(
        "--show-text", action="store_true", help="Show text snippets in results"
    )
    _add_prefer_grpc_flags(search_parser)
    search_parser.set_defaults(func=cmd_search)

    # =========================================================================
    # INFO command
    # =========================================================================
    info_parser = subparsers.add_parser(
        "info",
        help="Show collection info",
    )
    info_parser.add_argument(
        "--collection", type=str, default="visual_documents", help="Qdrant collection name"
    )
    _add_prefer_grpc_flags(info_parser)
    info_parser.set_defaults(func=cmd_info)

    # Parse and execute
    args = parser.parse_args()
    setup_logging(args.debug)

    if not args.command:
        parser.print_help()
        sys.exit(0)

    args.func(args)
# Allow running this module directly as a script (in addition to the
# `visual-rag` console entry point).
if __name__ == "__main__":
    main()