|
|
|
|
|
import asyncio
import json
import logging
import os
import re
from pathlib import Path
from typing import Dict, List, Tuple

from vector_db import SimpleVectorDB
|
|
|
|
|
# Configure the root logger when this module is imported so the pipeline's
# INFO-level progress messages are emitted (basicConfig adds a stderr handler
# and does nothing if the root logger already has handlers).
# NOTE(review): this is an import-time side effect for every importer;
# consider moving it under the __main__ guard.
logging.basicConfig(level=logging.INFO)

# Module-level logger used by EnhancedRAGPipeline for progress and errors.
logger = logging.getLogger(__name__)
|
|
|
|
|
class EnhancedRAGPipeline:
    """Answer support tickets with retrieval-augmented generation over Atlan docs.

    Context is retrieved from a ``SimpleVectorDB``. When that database is
    unavailable, returns nothing, or raises, a canned snippet selected by
    keywords in the question is used instead. When a Groq client is configured,
    the context is turned into an answer by the LLM. Otherwise, or if the LLM
    call fails, the (truncated) context itself is returned as the answer.
    """

    # Topic tags that trigger a documentation-grounded answer; tickets with
    # none of these tags get a routing response instead.
    RAG_TOPICS = ("How-to", "Product", "Best practices", "API/SDK", "SSO")

    # Chat-completion settings used by _generate_llm_response.
    LLM_MODEL = "openai/gpt-oss-120b"
    LLM_TEMPERATURE = 0.2
    LLM_MAX_TOKENS = 1000

    # Short keywords must start a word and must not be followed by another
    # letter. Plain substring checks let "rapid" match "api" and
    # "association"/"lesson" match "sso". Questions are lower-cased first.
    _API_SDK_RE = re.compile(r"\b(?:api|sdk)s?(?![a-z])")
    _SSO_RE = re.compile(r"\bsso(?![a-z])")

    _SNOWFLAKE_CONTEXT = """
To connect Snowflake to Atlan:
1. You need the following Snowflake permissions: USAGE on warehouse, database, and schema; SELECT on tables; MONITOR on warehouse
2. Create a service account with these permissions
3. In Atlan, go to Admin > Connectors > Add Snowflake
4. Provide connection details: account URL, username, password, warehouse, database
5. Test the connection and run the crawler

Common issues:
- Authentication failures: Check username/password and network access
- Permission errors: Ensure service account has required privileges
- Network issues: Verify Snowflake account URL and firewall settings
"""

    _API_CONTEXT = """
Atlan provides comprehensive APIs for programmatic access:

REST API endpoints:
- Assets API: Create, read, update assets
- Search API: Search across the catalog
- Lineage API: Retrieve lineage information
- Glossary API: Manage business terms

Authentication: Use API tokens (available in your profile settings)
Base URL: https://your-tenant.atlan.com/api/meta

Python SDK: pip install pyatlan
Java SDK: Available via Maven Central

Common operations:
- Create assets: POST /entity/bulk
- Search assets: POST /search/indexsearch
- Get lineage: GET /lineage/{guid}
"""

    _SSO_CONTEXT = """
Setting up SSO with Atlan:

SAML 2.0 Configuration:
1. In Atlan Admin > Settings > Authentication
2. Enable SAML SSO
3. Configure Identity Provider details:
   - SSO URL, Entity ID, Certificate
4. Map SAML attributes to Atlan user fields
5. Test with a pilot user before full deployment

Supported Identity Providers:
- Okta, Azure AD, Google Workspace
- Generic SAML 2.0 providers

Troubleshooting:
- Attribute mapping issues: Check SAML response format
- Group assignment: Verify group claims in SAML assertions
- Certificate errors: Ensure valid and properly formatted certificates
"""

    _LINEAGE_CONTEXT = """
Data Lineage in Atlan:

Automatic lineage capture:
- dbt: Connects via dbt Cloud or Core metadata
- SQL-based tools: Snowflake, BigQuery, Redshift, etc.
- ETL tools: Airflow, Fivetran, Matillion

Manual lineage:
- Use the lineage editor in the UI
- API endpoints for programmatic lineage creation

Lineage export:
- Currently available through API calls
- UI export features in development

Troubleshooting missing lineage:
- Check connector configuration
- Verify SQL parsing is enabled
- Review crawler logs for errors
"""

    _GENERAL_CONTEXT = """
Atlan is a modern data catalog that helps organizations:
- Discover and understand their data assets
- Implement data governance at scale
- Enable self-service analytics
- Ensure data quality and compliance

Key features:
- Automated metadata discovery
- Data lineage visualization
- Business glossary management
- Data quality monitoring
- Collaborative data stewardship
"""

    _FALLBACK_SOURCES = (
        "https://docs.atlan.com/",
        "https://developer.atlan.com/",
        "https://docs.atlan.com/connectors/",
        "https://docs.atlan.com/guide/",
    )

    def __init__(self, groq_client=None, *,
                 knowledge_base_file: str = "atlan_knowledge_base.json",
                 vector_db_file: str = "atlan_vector_db.pkl"):
        """Create the pipeline and load or build its vector database.

        Args:
            groq_client: Optional client exposing ``chat.completions.create``.
                When ``None``, answers are built directly from retrieved context.
            knowledge_base_file: Knowledge-base JSON used to build the vector
                database when no saved database can be loaded.
            vector_db_file: Stored for reference only. It is not passed to
                ``SimpleVectorDB``, whose load/save calls here take no path.
        """
        self.groq_client = groq_client
        self.vector_db = None
        self.knowledge_base_file = knowledge_base_file
        self.vector_db_file = vector_db_file
        self.initialize_vector_db()

    def initialize_vector_db(self):
        """Load the saved vector DB, or build and save one from the knowledge base.

        If any step raises, the error is logged and ``self.vector_db`` is reset
        to ``None``. The pipeline then serves fallback responses instead of
        failing construction.
        """
        try:
            self.vector_db = SimpleVectorDB()
            if self.vector_db.load_database():
                return

            logger.info("No existing vector database found. Checking for knowledge base...")
            if not Path(self.knowledge_base_file).exists():
                logger.warning("No knowledge base found. RAG will use fallback responses.")
                return

            logger.info("Found knowledge base. Building vector database...")
            if not self.vector_db.load_knowledge_base(self.knowledge_base_file):
                logger.error("Failed to load knowledge base")
                return

            self.vector_db.create_embeddings()
            self.vector_db.save_database()
            logger.info("Vector database built and saved")
        except Exception:
            logger.exception("Vector database initialization failed; RAG will use fallback responses.")
            self.vector_db = None

    def is_rag_available(self) -> bool:
        """Return True when a vector DB is loaded and holds at least one document."""
        documents = getattr(self.vector_db, "documents", None)
        return documents is not None and len(documents) > 0

    def should_use_rag(self, topic_tags: List[str]) -> bool:
        """Return True if any tag is in ``RAG_TOPICS``. ``None`` or empty gives False."""
        return any(tag in self.RAG_TOPICS for tag in topic_tags or ())

    def get_relevant_context(self, question: str, max_chars: int = 3000) -> Tuple[str, List[str]]:
        """Return ``(context, sources)`` for *question*.

        The vector DB is queried when available. The canned fallback pair is
        returned when the DB is unavailable, returns empty context, or raises.
        """
        if not self.is_rag_available():
            return self._fallback_result(question)

        try:
            context, sources = self.vector_db.get_context_for_query(question, max_chars)
        except Exception as e:
            logger.exception("Error retrieving context: %s", e)
            return self._fallback_result(question)

        if not context:
            return self._fallback_result(question)

        # Normalise so callers can always take len() of the sources.
        return context, list(sources) if sources else []

    def _fallback_result(self, question: str) -> Tuple[str, List[str]]:
        """Return the canned ``(context, sources)`` pair for *question*."""
        return self._get_fallback_context(question), self._get_fallback_sources()

    def _get_fallback_context(self, question: str) -> str:
        """Pick a canned documentation snippet by keywords found in *question*.

        Rules are checked in order:
        1. Snowflake + connect
        2. API/SDK
        3. SSO/SAML
        4. lineage
        5. general overview (default)
        """
        text = question.lower()
        if "snowflake" in text and "connect" in text:
            return self._SNOWFLAKE_CONTEXT
        if self._API_SDK_RE.search(text):
            return self._API_CONTEXT
        if self._SSO_RE.search(text) or "saml" in text:
            return self._SSO_CONTEXT
        if "lineage" in text:
            return self._LINEAGE_CONTEXT
        return self._GENERAL_CONTEXT

    def _get_fallback_sources(self) -> List[str]:
        """Return a fresh list of general Atlan documentation URLs."""
        return list(self._FALLBACK_SOURCES)

    @staticmethod
    def _preview(text: str, limit: int) -> str:
        """Cut *text* to *limit* chars, appending '...' only if something was cut."""
        return text if len(text) <= limit else text[:limit] + "..."

    async def generate_answer(self, question: str, topic_tags: List[str]) -> Dict:
        """Produce a routing response or a documentation-grounded answer.

        Returns:
            ``{"type": "routing", "message": ...}`` when no tag is in
            ``RAG_TOPICS``. Otherwise ``{"type": "direct_answer", "answer": ...,
            "sources": [...]}``.
        """
        if not self.should_use_rag(topic_tags):
            primary_tag = topic_tags[0] if topic_tags else 'General'
            return {
                "type": "routing",
                "message": f"This ticket has been classified as a '{primary_tag}' issue and routed to the appropriate team."
            }

        context, sources = self.get_relevant_context(question)

        if not self.groq_client:
            return {
                "type": "direct_answer",
                "answer": f"Based on the documentation, here's information about your question: {self._preview(context, 500)}",
                "sources": sources
            }

        try:
            return await self._generate_llm_response(question, context, sources)
        except Exception as e:
            # Degrade to the raw retrieved context rather than failing the ticket.
            logger.exception("Error generating LLM response: %s", e)
            return {
                "type": "direct_answer",
                "answer": f"Based on the available documentation: {self._preview(context, 800)}",
                "sources": sources
            }

    async def _generate_llm_response(self, question: str, context: str, sources: List[str]) -> Dict:
        """Ask the LLM to answer *question* using *context*.

        Raises:
            ValueError: if the model returns empty content.
            Exception: any error raised by the client propagates to the caller.
        """
        prompt = f"""
You are an expert Atlan support agent. Use the provided documentation context to answer the user's question comprehensively and accurately.

User Question: {question}

Documentation Context:
{context}

Instructions:
- Provide a direct, helpful, and detailed answer
- Use the context to inform your response
- Be specific about steps, requirements, and configurations when applicable
- If the question is about troubleshooting, include common solutions
- If the question is about setup/configuration, provide step-by-step guidance
- Maintain a professional and helpful tone
- Only use information from the provided context
- If the context doesn't fully answer the question, acknowledge the limitation

Format your response as a comprehensive answer that directly addresses the user's question.
"""
        messages = [
            {"role": "system", "content": "You are an expert Atlan support agent. Provide helpful, accurate responses based on the documentation context."},
            {"role": "user", "content": prompt},
        ]
        create = self.groq_client.chat.completions.create

        # The client call is synchronous (its result is used directly). Run it
        # in the default executor so this coroutine does not block the event
        # loop while waiting on the network.
        loop = asyncio.get_running_loop()
        response = await loop.run_in_executor(
            None,
            lambda: create(
                model=self.LLM_MODEL,
                messages=messages,
                temperature=self.LLM_TEMPERATURE,
                max_tokens=self.LLM_MAX_TOKENS,
            ),
        )

        content = response.choices[0].message.content
        answer = (content or "").strip()
        if not answer:
            raise ValueError("LLM returned an empty response")

        return {
            "type": "direct_answer",
            "answer": answer,
            "sources": sources
        }
|
|
|
|
|
def setup_rag_system():
    """Check for the knowledge base and build the vector database if it is missing.

    The scraper is NOT run automatically. If the knowledge-base JSON is absent,
    the user is told to run ``python scraper.py`` and setup fails.

    Returns:
        True if the vector database file already exists or was built
        successfully. False if the knowledge base is missing, or if importing
        or running the builder fails or returns a falsy result.
    """
    print("🤖 Setting up Enhanced RAG System...")
    print("=" * 45)

    kb_file = Path("atlan_knowledge_base.json")
    db_file = Path("atlan_vector_db.pkl")

    if not kb_file.exists():
        print("📚 Knowledge base not found. Please run the scraper first:")
        print(" python scraper.py")
        return False

    if not db_file.exists():
        print("🔧 Vector database not found. Building from knowledge base...")
        try:
            # Imported lazily: only needed when the database must be built.
            from vector_db import build_vector_database
            vector_db = build_vector_database()
        except Exception as exc:
            # Honour the bool contract instead of crashing the caller.
            print(f"❌ Failed to build vector database: {exc}")
            return False
        if not vector_db:
            print("❌ Failed to build vector database")
            return False

    print("✅ RAG system ready!")
    return True
|
|
|
|
|
async def test_rag_pipeline():
    """Run EnhancedRAGPipeline over a few sample tickets and print a summary.

    For each (question, tags) sample, print the response type. Then print
    either the answer length, source count and preview, or the routing message.
    """
    print("\n🧪 Testing Enhanced RAG Pipeline...")
    print("=" * 40)

    pipeline = EnhancedRAGPipeline()

    samples = (
        ("How do I connect Snowflake to Atlan?", ["How-to", "Connector"]),
        ("Show me API documentation for creating assets", ["API/SDK"]),
        ("Our lineage is not showing up", ["Lineage", "Troubleshooting"]),
        ("How to configure SAML SSO?", ["SSO", "How-to"]),
    )

    for sample_question, sample_tags in samples:
        print(f"\nQuestion: {sample_question}")
        print(f"Topics: {sample_tags}")

        outcome = await pipeline.generate_answer(sample_question, sample_tags)
        response_kind = outcome['type']
        print(f"Response Type: {response_kind}")

        if response_kind != 'direct_answer':
            print(f"Routing: {outcome['message']}")
            continue

        answer_text = outcome['answer']
        print(f"Answer Length: {len(answer_text)} characters")
        print(f"Sources: {len(outcome['sources'])}")
        print(f"Answer Preview: {answer_text[:200]}...")
|
|
|
|
|
if __name__ == "__main__":
    # Verify/build the vector DB first and only smoke-test the pipeline when
    # setup succeeded. On failure, exit non-zero so shells/CI see the error.
    if setup_rag_system():
        asyncio.run(test_rag_pipeline())
    else:
        print("❌ RAG system setup failed")
        raise SystemExit(1)
|
|
|