Spaces:
Sleeping
Sleeping
import os
import sys
import traceback

from dotenv import load_dotenv

from rag_pipeline import load_repo, create_vectorstore, create_qa_chain, query
# Load variables from a local .env file (if one exists) into the process
# environment, so OPENAI_API_KEY can be supplied via .env instead of the shell.
# NOTE: this runs at import time, not only when the script is executed.
load_dotenv()
# Sample questions asked when the caller does not supply its own.
_DEFAULT_QUESTIONS = (
    "How does shared memory tiling work?",
    "What is tree reduction?",
    "How does CUDA optimization work?",
    "Explain the backward pass",
    "What are the main functions?",
)


def test_local_rag(repo_url: str, questions=None, max_sources: int = 2) -> bool:
    """
    Smoke-test the RAG pipeline end to end against one repository.

    Loads ``repo_url`` with ``load_repo``, builds a vector store and a QA
    chain, then asks each question and prints the answer plus the
    ``file_path`` metadata of the first ``max_sources`` retrieved chunks.

    Args:
        repo_url: Repository URL, passed unchanged to ``load_repo``.
        questions: Iterable of question strings to ask. When ``None``, a
            built-in set of sample questions is used.
        max_sources: Maximum number of retrieved chunks to list per answer.

    Returns:
        True if the pipeline was built and every question was answered
        without raising. False if OPENAI_API_KEY is missing, setup failed,
        or at least one question raised.
    """
    openai_api_key = os.getenv("OPENAI_API_KEY")
    if not openai_api_key:
        print("❌ Error: Set OPENAI_API_KEY env variable or add to .env file.")
        return False

    if questions is None:
        questions = _DEFAULT_QUESTIONS

    print(f"🚀 Initializing RAG pipeline for: {repo_url}...")
    try:
        # 1. Load repo
        print("📥 Loading repository...")
        documents = load_repo(repo_url)
        # 2. Create vector store
        print("🧠 Creating vector store and processing embeddings...")
        vectorstore = create_vectorstore(documents)
        # 3. Create QA chain
        print("🔗 Setting up QA chain...")
        qa_chain = create_qa_chain(vectorstore, openai_api_key)
    except Exception as e:
        # Top-level boundary of a manual test script: report with a traceback
        # (the bare message alone is rarely enough to debug) and stop, since
        # no question can be answered without a chain.
        print(f"❌ An error occurred during testing: {e}")
        traceback.print_exc()
        return False

    # 4. Ask the sample questions
    print("\n--- 🧪 Starting local tests ---\n")
    failures = 0
    for q in questions:
        print(f"❓ QUESTION: {q}")
        try:
            answer, sources = query(qa_chain, q)
        except Exception as e:
            # Isolate each question so one failing query does not hide the
            # results of the remaining ones.
            failures += 1
            print(f"❌ An error occurred while answering: {e}")
            traceback.print_exc()
            print("\n" + "=" * 50 + "\n")
            continue

        print("-" * 30)
        print(f"💡 AI ANSWER:\n{answer}")
        print("-" * 15)
        print("📄 Top source chunks retrieved:")
        for i, src in enumerate((sources or [])[:max_sources], start=1):
            path = src.metadata.get("file_path", "Unknown")
            print(f"  - Chunk {i} from {path}")
        print("\n" + "=" * 50 + "\n")

    return failures == 0


# The `test_` prefix would make pytest collect this function and fail on the
# non-fixture `repo_url` parameter; it is a manual script, not a unit test.
test_local_rag.__test__ = False
if __name__ == "__main__":
    # Repository to test against when no URL is given on the command line.
    # Pass a different public repo URL as the first argument to override it,
    # preferably a small one: loading and embedding a large repo is slow.
    DEFAULT_TEST_REPO = "https://github.com/MacroTorch/MacroTorch"
    repo = sys.argv[1] if len(sys.argv) > 1 else DEFAULT_TEST_REPO
    result = test_local_rag(repo)
    # Exit non-zero when the run reports failure (returns False) so shell
    # scripts / CI can detect it; any other result exits with status 0.
    sys.exit(1 if result is False else 0)