| |
| from rag_utils import RAGSystem |
| import argparse |
| import os |
| import logging |
| import shutil |
|
|
| |
# Shared log-line layout: timestamp, logger name, severity, then the message.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)
|
|
def main():
    """Build the RAG index from the LinkedIn PDF and summary, then optionally test it.

    Command-line options:
        --pdf      LinkedIn profile PDF (default ``me/linkedin.pdf``)
        --summary  professional summary text file (default ``me/summary.txt``)
        --output   directory the index is saved to (default ``me/rag_index``)
        --test     run the sample queries from ``test_index`` afterwards
        --force    replace an existing --output path without prompting

    If --output already exists and --force is not given, the user is asked
    whether to overwrite, skip straight to testing, or cancel.

    Failures are logged and end the run early; the function always returns
    None. An existing index is only removed once a replacement index has
    been built in memory, so bad or missing input files never destroy a
    previously working index.
    """
    parser = argparse.ArgumentParser(description="Train a RAG system on your documents")
    parser.add_argument("--pdf", type=str, help="Path to LinkedIn PDF file", default="me/linkedin.pdf")
    parser.add_argument("--summary", type=str, help="Path to summary text file", default="me/summary.txt")
    parser.add_argument("--output", type=str, help="Directory to save the RAG index", default="me/rag_index")
    parser.add_argument("--test", action="store_true", help="Run a test query after training")
    parser.add_argument("--force", action="store_true", help="Force rebuild even if index exists")
    args = parser.parse_args()

    action = _resolve_existing_output(args.output, args.force)
    if action == "cancel":
        return
    if action == "skip":
        logger.info("Skipping build, using existing index.")
        if args.test:
            _run_test_safely(args.output)
        return

    rag = _build_rag(args.pdf, args.summary)
    if rag is None:
        return

    if not _save_rag(rag, args.output):
        return

    if args.test:
        # Same error handling as the skip path: a failing test is reported,
        # but the freshly saved index is still valid.
        _run_test_safely(args.output)

    logger.info("\nRAG system training complete!")
    logger.info(f"To use this RAG system in your application, load it from: {args.output}")


def _resolve_existing_output(output_dir, force):
    """Decide how to treat a pre-existing output path: return 'build', 'skip' or 'cancel'."""
    if not os.path.exists(output_dir):
        return "build"
    if force:
        # Removal is deferred to _save_rag, after the new index has been built.
        logger.info(f"Force flag set. Existing directory {output_dir} will be replaced.")
        return "build"

    logger.warning(f"Output directory {output_dir} already exists.")
    try:
        answer = input("Do you want to (o)verwrite, (s)kip to testing, or (c)ancel? [o/s/c]: ")
    except EOFError:
        # Non-interactive run (stdin closed/piped): never overwrite implicitly.
        logger.info("No answer received (stdin closed). Operation cancelled.")
        return "cancel"

    choice = answer.strip().lower()
    if choice == "o":
        return "build"
    if choice == "s":
        return "skip"
    if choice == "c":
        logger.info("Operation cancelled.")
    else:
        logger.error("Invalid choice. Exiting.")
    return "cancel"


def _build_rag(pdf_path, summary_path):
    """Validate the input files and build an in-memory RAGSystem; return None on any failure."""
    if not os.path.exists(pdf_path):
        logger.error(f"Error: PDF file not found at {pdf_path}")
        return None
    if not os.path.exists(summary_path):
        logger.error(f"Error: Summary file not found at {summary_path}")
        return None

    logger.info("Initializing RAG system...")
    rag = RAGSystem()

    logger.info(f"Processing LinkedIn profile from {pdf_path}...")
    try:
        linkedin_count = rag.add_document(pdf_path, "LinkedIn Profile")
    except Exception as e:
        logger.error(f"Error processing LinkedIn PDF: {e}")
        return None
    logger.info(f"Added {linkedin_count} chunks from LinkedIn profile")

    logger.info(f"Processing professional summary from {summary_path}...")
    try:
        summary_count = rag.add_document(summary_path, "Professional Summary")
    except Exception as e:
        logger.error(f"Error processing summary file: {e}")
        return None
    logger.info(f"Added {summary_count} chunks from professional summary")

    # Refuse to replace an existing index with an empty one.
    if len(rag.chunks) == 0:
        logger.error("No chunks were created. Check your input files.")
        return None
    if rag.index is None or rag.index.ntotal == 0:
        logger.error("No index was created. Check your input files.")
        return None
    return rag


def _save_rag(rag, output_dir):
    """Replace whatever is at *output_dir* with *rag*'s index; return True on success."""
    logger.info(f"Saving RAG index to {output_dir}...")
    try:
        if os.path.exists(output_dir):
            # Only reached after a successful build, so the old index is
            # never discarded for nothing. A plain file at this path makes
            # rmtree raise, which is reported below instead of crashing.
            logger.info(f"Removing existing directory {output_dir}...")
            shutil.rmtree(output_dir)
        rag.save_index(output_dir)
    except Exception as e:
        logger.error(f"Error saving index: {e}")
        return False
    logger.info("RAG index saved successfully!")
    return True


def _run_test_safely(index_dir):
    """Run test_index on *index_dir*, logging rather than propagating any failure."""
    try:
        test_index(index_dir)
    except Exception as e:
        logger.error(f"Error testing index: {e}")
|
|
def test_index(index_dir, queries=None):
    """Load the index saved in *index_dir* and log the context retrieved for sample queries.

    Args:
        index_dir: directory to pass to ``RAGSystem.load_index``.
        queries: optional iterable of query strings. When omitted, the original
            three profile questions (skills, work experience, education) are used.

    Raises:
        Exception: any error raised while loading or querying is logged and re-raised.
    """
    if queries is None:
        queries = [
            "What are Sagarnil's technical skills?",
            "What is Sagarnil's work experience?",
            "What educational background does Sagarnil have?"
        ]

    try:
        logger.info("Loading index for testing...")
        rag = RAGSystem()
        rag.load_index(index_dir)

        logger.info(f"Loaded index with {len(rag.chunks)} chunks")

        logger.info("\nTesting RAG system with sample queries:")
        for query in queries:
            logger.info(f"\nQUERY: {query}")
            context = rag.get_context_for_query(query)
            logger.info(context)

        logger.info("Testing complete!")
    except Exception as e:
        # Log here for context, then re-raise so callers decide how to react.
        logger.error(f"Error during testing: {e}")
        raise
|
|
if __name__ == "__main__":
    # Script entry point; main() returns None, so this exits with status 0
    # exactly as falling off the end of the script would.
    raise SystemExit(main())