Spaces:
Runtime error
Runtime error
| from fastapi import FastAPI, HTTPException | |
| import os | |
| from typing import List, Dict | |
| from dotenv import load_dotenv | |
| import logging | |
| from pathlib import Path | |
| from langchain_community.document_loaders import PyPDFLoader | |
| from langchain_text_splitters import RecursiveCharacterTextSplitter | |
| from langchain_community.vectorstores import Qdrant as QdrantVectorStore | |
| from langchain_google_genai import GoogleGenerativeAIEmbeddings | |
| from langchain_groq import ChatGroq | |
| from qdrant_client import QdrantClient | |
| from qdrant_client.http.models import Distance, VectorParams | |
| from qdrant_client.models import PointIdsList | |
| from langgraph.graph import MessagesState, StateGraph | |
| from langchain_core.messages import SystemMessage, HumanMessage, ToolMessage | |
| from langgraph.prebuilt import ToolNode | |
| from langgraph.graph import END | |
| from langgraph.prebuilt import tools_condition | |
| from langgraph.checkpoint.memory import MemorySaver | |
# Process-wide logging: INFO and above, via a module-scoped logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Pull variables from a local .env file (if any) into the environment
# before the API keys are read.
load_dotenv()
GOOGLE_API_KEY = os.environ.get('GOOGLE_API_KEY')
GROQ_API_KEY = os.environ.get('GROQ_API_KEY')

# Both keys are mandatory; refuse to start when either is missing or empty.
if not (GOOGLE_API_KEY and GROQ_API_KEY):
    raise ValueError("API keys not set in environment variables")

app = FastAPI()
class QASystem:
    """Retrieval-augmented question answering over a directory of PDFs.

    ``initialize_system`` loads and chunks the PDFs, indexes the chunks in
    an in-memory Qdrant collection using Google Generative AI embeddings,
    and compiles a two-node LangGraph pipeline (retrieve, then generate
    with a Groq-hosted chat model).  ``process_query`` runs one question
    through that pipeline.
    """

    # Name of the Qdrant collection that holds the PDF chunks.
    COLLECTION_NAME = "pdf_data"
    # Vector size the collection is created with.
    # NOTE(review): presumably the output size of "models/embedding-001";
    # confirm if the embedding model is ever changed.
    EMBEDDING_DIM = 768

    def __init__(self, pdf_dir: str = "pdfss"):
        """Create an uninitialised system.

        Args:
            pdf_dir: Directory scanned (non-recursively) for ``*.pdf``
                files.  Defaults to the original hard-coded location.
        """
        self.vector_store = None
        self.graph = None
        self.memory = None
        self.embeddings = None
        self.client = None
        self.pdf_dir = pdf_dir

    def load_pdf_documents(self):
        """Load every PDF in ``self.pdf_dir`` and split it into chunks.

        A PDF that fails to load is logged and skipped so that one bad
        file does not prevent the others from being indexed.

        Returns:
            The list of chunks produced by the text splitter (1000
            characters per chunk with 100 characters of overlap); empty
            when no PDF could be loaded.

        Raises:
            FileNotFoundError: If ``self.pdf_dir`` does not exist.
        """
        pdf_dir = Path(self.pdf_dir)
        if not pdf_dir.exists():
            raise FileNotFoundError(f"PDF directory not found: {self.pdf_dir}")

        documents = []
        # Sorted so the ingestion order does not depend on directory order.
        for pdf_path in sorted(pdf_dir.glob("*.pdf")):
            try:
                pages = PyPDFLoader(str(pdf_path)).load()
            except Exception as e:
                logger.error(
                    f"Error loading PDF {pdf_path}: {str(e)}", exc_info=True
                )
            else:
                documents.extend(pages)
                logger.info(f"Loaded PDF: {pdf_path}")

        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=100,
        )
        split_docs = text_splitter.split_documents(documents)
        logger.info(f"Split documents into {len(split_docs)} chunks")
        return split_docs

    def _ensure_collection(self):
        """Create the Qdrant collection if looking it up fails."""
        try:
            self.client.get_collection(self.COLLECTION_NAME)
        except Exception:
            self.client.create_collection(
                collection_name=self.COLLECTION_NAME,
                vectors_config=VectorParams(
                    size=self.EMBEDDING_DIM, distance=Distance.COSINE
                ),
            )
            logger.info("Created new collection: pdf_data")

    def _clear_collection(self):
        """Best-effort removal of every point already in the collection.

        Pages through the whole collection instead of looking at only the
        first 100 points.  Failures are logged and swallowed so indexing
        can still proceed.
        """
        try:
            point_ids = []
            offset = None
            while True:
                points, offset = self.client.scroll(
                    collection_name=self.COLLECTION_NAME,
                    limit=100,
                    offset=offset,
                    with_payload=False,
                )
                point_ids.extend(p.id for p in points)
                # scroll() reports the end of the collection with a None
                # next-page offset.
                if offset is None:
                    break
            if point_ids:
                self.client.delete(
                    collection_name=self.COLLECTION_NAME,
                    points_selector=PointIdsList(points=point_ids),
                )
        except Exception as e:
            logger.error(f"Error clearing vectors: {str(e)}", exc_info=True)

    def _build_graph(self, llm):
        """Compile the retrieve -> generate graph around ``llm``.

        Both nodes return only the messages they add: the ``messages``
        channel of ``MessagesState`` appends node output to the existing
        history, so the history itself must not be echoed back.
        """
        graph_builder = StateGraph(MessagesState)

        def retrieve_docs(state: MessagesState):
            """Fetch the chunks most similar to the latest human message."""
            human_messages = [m for m in state["messages"] if m.type == "human"]
            if not human_messages:
                return {"messages": []}
            user_query = human_messages[-1].content
            logger.info(f"Retrieving documents for query: {user_query}")
            try:
                retrieved_docs = self.vector_store.similarity_search(
                    user_query, k=3
                )
            except Exception as e:
                # Degrade to answering without context rather than failing.
                logger.error(
                    f"Error retrieving documents: {str(e)}", exc_info=True
                )
                return {"messages": []}
            # Each retrieved chunk travels through the state as a tool
            # message so the generate node can find it by message type.
            tool_messages = [
                ToolMessage(
                    content=f"Document {i+1}: {doc.page_content}",
                    tool_call_id=f"retrieval_{i}",
                )
                for i, doc in enumerate(retrieved_docs)
            ]
            logger.info(f"Retrieved {len(tool_messages)} relevant documents")
            return {"messages": tool_messages}

        def generate(state: MessagesState):
            """Answer the latest question using this turn's retrieved chunks."""
            history = state["messages"]
            # The checkpointer keeps earlier turns in the state, so only
            # tool messages that follow the most recent human message
            # belong to the current question.
            last_human = max(
                (i for i, m in enumerate(history) if m.type == "human"),
                default=-1,
            )
            tool_messages = [
                m for m in history[last_human + 1:] if m.type == "tool"
            ]
            if tool_messages:
                context = "\n".join([m.content for m in tool_messages])
                logger.info(
                    f"Using context from {len(tool_messages)} retrieved documents"
                )
            else:
                context = "No specific mountain bicycle documentation available."
                logger.info("No relevant documents retrieved, using default context")
            system_prompt = (
                "You are an AI assistant embedded within the Interactive Electronic Technical Manual (IETM) for Mountain Cycles. "
                "Always provide accurate responses with references to provided data. "
                "If the user query is not technical-specific, still respond from a IETM perspective."
                f"\n\nContext from mountain bicycle documentation:\n{context}"
            )
            # Tool messages are already folded into the system prompt, so
            # only the human/AI conversation is sent to the model.
            human_and_ai_messages = [m for m in history if m.type != "tool"]
            messages = [SystemMessage(content=system_prompt)] + human_and_ai_messages
            logger.info(f"Sending query to LLM with {len(messages)} messages")
            response = llm.invoke(messages)
            return {"messages": [response]}

        graph_builder.add_node("retrieve_docs", retrieve_docs)
        graph_builder.add_node("generate", generate)
        graph_builder.set_entry_point("retrieve_docs")
        graph_builder.add_edge("retrieve_docs", "generate")
        graph_builder.add_edge("generate", END)
        return graph_builder.compile(checkpointer=self.memory)

    def initialize_system(self):
        """Build the vector store, index the PDFs and compile the graph.

        Returns:
            True on success.  Any exception is logged (with traceback) and
            reported as False instead of being raised.
        """
        try:
            self.client = QdrantClient(":memory:")
            self._ensure_collection()
            self.embeddings = GoogleGenerativeAIEmbeddings(
                model="models/embedding-001",
                google_api_key=GOOGLE_API_KEY,
            )
            self.vector_store = QdrantVectorStore(
                client=self.client,
                collection_name=self.COLLECTION_NAME,
                embeddings=self.embeddings,
            )
            documents = self.load_pdf_documents()
            if documents:
                self._clear_collection()
                self.vector_store.add_documents(documents)
                logger.info(f"Added {len(documents)} documents to vector store")
            llm = ChatGroq(
                model="llama3-8b-8192",
                api_key=GROQ_API_KEY,
                temperature=0.7,
            )
            self.memory = MemorySaver()
            self.graph = self._build_graph(llm)
            return True
        except Exception as e:
            logger.error(
                f"System initialization error: {str(e)}", exc_info=True
            )
            return False

    def process_query(self, query: str, thread_id: str = "abc123") -> Dict[str, str]:
        """Process a query and return a single final response.

        Args:
            query: The user's question.
            thread_id: Checkpointer thread the conversation is stored
                under.  Calls sharing a thread id share one message
                history; the default keeps the original single shared
                thread.

        Returns:
            ``{'content': ..., 'type': ...}`` where ``type`` is the type of
            the last AI message on success, or ``'error'`` when no AI
            message was produced or processing failed.
        """
        if self.graph is None:
            logger.error("Query processing error: QA system is not initialized")
            return {
                'content': "Query processing error: QA system is not initialized",
                'type': 'error'
            }
        try:
            # invoke (rather than stream) yields only the final state.
            final_state = self.graph.invoke(
                {"messages": [HumanMessage(content=query)]},
                config={"configurable": {"thread_id": thread_id}}
            )
            ai_messages = [m for m in final_state["messages"] if m.type == "ai"]
            if ai_messages:
                return {
                    'content': ai_messages[-1].content,
                    'type': ai_messages[-1].type
                }
            return {
                'content': "No response generated",
                'type': 'error'
            }
        except Exception as e:
            logger.exception(f"Query processing error: {str(e)}")
            return {
                'content': f"Query processing error: {str(e)}",
                'type': 'error'
            }
# Build the single shared QA system at import time; the module refuses to
# load if initialisation reports failure.
qa_system = QASystem()
if not qa_system.initialize_system():
    raise RuntimeError("Failed to initialize QA System")
logger.info("QA System Initialized Successfully")
# NOTE(review): no route decorator (e.g. an `@app.get(...)`) is visible on
# this function in the source as given, so it may not be registered on
# `app` -- confirm against the deployed file.
# NOTE(review): `process_query` is a synchronous call made from a coroutine;
# confirm whether blocking the event loop for its duration is acceptable.
async def query_api(query: str):
    """Answer *query* via the shared QA system.

    Returns a dict with a single ``"response"`` key holding the dict that
    ``qa_system.process_query`` produced.
    """
    return {"response": qa_system.process_query(query)}