# Spaces: Sleeping
| import os | |
| import json | |
| import requests | |
| import tempfile | |
| from google.oauth2 import service_account | |
| from googleapiclient.discovery import build | |
| from googleapiclient.http import MediaIoBaseDownload | |
| import openai | |
| from dotenv import load_dotenv, dotenv_values | |
| import io | |
| import logging | |
| from typing import List, Dict, Optional | |
| # LangChain imports | |
| from langchain.text_splitter import RecursiveCharacterTextSplitter | |
| from langchain_community.vectorstores import FAISS | |
| from langchain_openai import OpenAIEmbeddings, ChatOpenAI | |
| from langchain_community.vectorstores import FAISS | |
| from langchain.docstore.document import Document | |
| from langchain.chains import RetrievalQA | |
| from langchain.prompts import PromptTemplate | |
| from langchain.memory import ConversationBufferMemory | |
| from langchain.chains import ConversationalRetrievalChain | |
| from langchain.schema import BaseRetriever | |
| import pickle | |
| import hashlib | |
| from openai import OpenAI | |
# Module-wide logger; basicConfig installs a root handler at INFO level.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

# Read the API key from the environment, then rebind the name `openai` from
# the imported module to a client instance built with that key.
openai.api_key = os.getenv('OPENAI_API_KEY')
openai = OpenAI(api_key=openai.api_key)
class EnhancedGPTDriveIntegration:
    """Answer questions from Google Drive files using a FAISS index and a chat LLM.

    The instance searches Drive for files matching a query, extracts their
    text, splits it into chunks, embeds the chunks into an in-memory FAISS
    vector store and runs a conversational retrieval chain over that store.
    Chunked documents are cached on disk (pickle) keyed by a per-file hash so
    unchanged files are not downloaded again.
    """

    # MIME types of Google-native editor files. Only these are sent to
    # `export_media`; uploaded Office files (whose MIME types contain
    # "officedocument") are not.
    _GOOGLE_DOC_MIME = 'application/vnd.google-apps.document'
    _GOOGLE_SHEET_MIME = 'application/vnd.google-apps.spreadsheet'

    # Prefixes of the placeholder strings get_file_content() returns when it
    # has no real text; content starting with any of these is never indexed.
    _UNUSABLE_CONTENT_PREFIXES = (
        'Error',
        'File type not supported',
        'PDF text extraction requires',
    )

    # Drive `q` clause used for each user-facing file type.
    _MIME_FILTERS = {
        'pdf': "mimeType='application/pdf'",
        'doc': "mimeType contains 'document'",
        'docx': "mimeType contains 'document'",
        'xls': "mimeType contains 'spreadsheet'",
        'xlsx': "mimeType contains 'spreadsheet'",
        'txt': "mimeType='text/plain'",
    }

    def __init__(self):
        """Build the Drive client, the LangChain components and load the cache.

        Raises:
            ValueError: If GOOGLE_PROJECT_ID, GOOGLE_PRIVATE_KEY or
                GOOGLE_CLIENT_EMAIL is missing from the environment.
        """
        # Default to '' so a missing key is reported by the check below
        # instead of failing with AttributeError on `.replace`.
        private_key = os.getenv('GOOGLE_PRIVATE_KEY') or ''

        # Build credentials info from individual environment variables
        credentials_info = {
            "type": "service_account",
            "project_id": os.getenv('GOOGLE_PROJECT_ID'),
            "private_key_id": os.getenv('GOOGLE_PRIVATE_KEY_ID'),
            # The key is stored with literal "\n" sequences; restore newlines.
            "private_key": private_key.replace('\\n', '\n'),
            "client_email": os.getenv('GOOGLE_CLIENT_EMAIL'),
            "client_id": os.getenv('GOOGLE_CLIENT_ID'),
            "auth_uri": "https://accounts.google.com/o/oauth2/auth",
            "token_uri": "https://oauth2.googleapis.com/token",
            "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",
            "client_x509_cert_url": os.getenv('GOOGLE_CLIENT_CERT_URL'),
            "universe_domain": "googleapis.com",
        }

        # Check if all required fields are present
        required_fields = ['project_id', 'private_key', 'client_email']
        missing_fields = [field for field in required_fields if not credentials_info[field]]
        if missing_fields:
            raise ValueError(f"Missing required environment variables: {missing_fields}")

        # Initialize Google Drive API (read-only scope)
        self.credentials = service_account.Credentials.from_service_account_info(
            credentials_info,
            scopes=['https://www.googleapis.com/auth/drive.readonly'],
        )
        self.drive_service = build('drive', 'v3', credentials=self.credentials)

        # Initialize OpenAI and LangChain components
        openai.api_key = os.getenv('OPENAI_API_KEY')
        self.embeddings = OpenAIEmbeddings()
        self.llm = ChatOpenAI(temperature=0.7, model="gpt-3.5-turbo")

        # Text splitter for better chunking
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=200,
            length_function=len,
            separators=["\n\n", "\n", " ", ""],
        )

        # Initialize vector store
        self.vector_store = None
        # output_key is required because the chain is built with
        # return_source_documents=True and therefore returns two keys; without
        # it the memory cannot tell which output to store.
        self.conversation_memory = ConversationBufferMemory(
            memory_key="chat_history",
            return_messages=True,
            output_key="answer",
        )

        # Hashes of files whose chunks are already in self.vector_store.
        self._indexed_hashes = set()

        # Cache for processed files: file hash -> list of Document chunks
        self.processed_files = {}
        self.cache_file = "processed_files_cache.pkl"
        self.load_cache()

    def load_cache(self):
        """Load the processed-files cache from ``self.cache_file`` if it exists.

        Any failure is logged and leaves an empty cache.
        """
        try:
            if os.path.exists(self.cache_file):
                # SECURITY: pickle.load executes arbitrary code from the file.
                # Only safe while the cache file is written exclusively by
                # this process; do not point cache_file at untrusted data.
                with open(self.cache_file, 'rb') as f:
                    loaded = pickle.load(f)
                if not isinstance(loaded, dict):
                    raise TypeError(f"unexpected cache payload: {type(loaded).__name__}")
                self.processed_files = loaded
                logger.info("Loaded cache with %d files", len(self.processed_files))
        except Exception as e:
            logger.error("Error loading cache: %s", e)
            self.processed_files = {}

    def save_cache(self):
        """Write the processed-files cache to ``self.cache_file``; log failures."""
        try:
            with open(self.cache_file, 'wb') as f:
                pickle.dump(self.processed_files, f)
            logger.info("Cache saved successfully")
        except Exception as e:
            logger.error("Error saving cache: %s", e)

    def get_file_hash(self, file_id: str, file_size: str, modified_time: str = "") -> str:
        """Return the cache key for one revision of a file.

        Args:
            file_id: Drive file id.
            file_size: Size reported by Drive ('0' when Drive reports none).
            modified_time: Optional modification timestamp. When given it is
                mixed into the key so edits that keep the size (or files that
                have no size at all) still produce a new key.

        Returns:
            Hex MD5 digest. With ``modified_time`` omitted the value is the
            digest of ``"{file_id}_{file_size}"``.
        """
        fingerprint = f"{file_id}_{file_size}"
        if modified_time:
            fingerprint += f"_{modified_time}"
        # MD5 is used only as a cache key, not for security.
        return hashlib.md5(fingerprint.encode()).hexdigest()

    @staticmethod
    def _escape_query_value(value: str) -> str:
        """Escape backslashes and single quotes for a quoted Drive `q` value."""
        return value.replace("\\", "\\\\").replace("'", "\\'")

    @staticmethod
    def _build_search_query(query: str, file_types: Optional[List[str]] = None) -> str:
        """Build the Drive `q` string: every term must match, any type may match.

        Returns an empty string when there is neither a term nor a known type.
        """
        escape = EnhancedGPTDriveIntegration._escape_query_value
        clauses = []
        # Search in file names and content
        for term in query.lower().split():
            safe = escape(term)
            clauses.append(f"(name contains '{safe}' or fullText contains '{safe}')")

        type_clauses = []
        for file_type in file_types or []:
            clause = EnhancedGPTDriveIntegration._MIME_FILTERS.get(file_type.lower())
            # 'doc' and 'docx' map to the same clause; emit it once.
            if clause and clause not in type_clauses:
                type_clauses.append(clause)
        if type_clauses:
            clauses.append(f"({' or '.join(type_clauses)})")

        return " and ".join(clauses)

    def search_files(self, query: str, file_types: Optional[List[str]] = None) -> List[Dict]:
        """Search Google Drive by file name and full text.

        Args:
            query: Whitespace-separated terms; all must match.
            file_types: Optional subset of 'pdf', 'doc', 'docx', 'xls',
                'xlsx', 'txt'. Unknown entries are ignored.

        Returns:
            Up to 20 file dicts with id, name, mimeType, size and
            modifiedTime; an empty list on error or when there is nothing to
            search for.
        """
        search_query = self._build_search_query(query, file_types)
        if not search_query:
            # A blank `q` would not restrict the listing at all.
            return []

        try:
            results = self.drive_service.files().list(
                q=search_query,
                fields="files(id, name, mimeType, size, modifiedTime)",
                pageSize=20,  # Increased to get more results
            ).execute()
            files = results.get('files', [])
            logger.info("Found %d files matching query: %s", len(files), query)
            return files
        except Exception as e:
            logger.error("Error searching files: %s", e)
            return []

    @staticmethod
    def _download(request) -> io.BytesIO:
        """Run a Drive media request to completion and return the filled buffer."""
        buffer = io.BytesIO()
        downloader = MediaIoBaseDownload(buffer, request)
        done = False
        while not done:
            _status, done = downloader.next_chunk()
        return buffer

    @staticmethod
    def _extract_pdf_text(pdf_buffer: io.BytesIO) -> str:
        """Extract text from a PDF with PyPDF2, falling back to pdfplumber.

        Returns the text (one newline after each page) or one of the
        placeholder messages used by get_file_content().
        """
        try:
            import PyPDF2
        except ImportError:
            PyPDF2 = None
            logger.warning("PyPDF2 not available, trying alternative PDF extraction")

        if PyPDF2 is not None:
            try:
                reader = PyPDF2.PdfReader(pdf_buffer)
                # extract_text() may yield None for pages without a text layer.
                return "".join((page.extract_text() or "") + "\n" for page in reader.pages)
            except Exception as e:
                return f"Error extracting PDF text: {str(e)}"

        # Try alternative PDF extraction
        try:
            import pdfplumber
        except ImportError:
            return "PDF text extraction requires PyPDF2 or pdfplumber library"
        try:
            with pdfplumber.open(pdf_buffer) as pdf:
                return "".join((page.extract_text() or "") + "\n" for page in pdf.pages)
        except Exception as e:
            return f"Error extracting PDF text: {str(e)}"

    def get_file_content(self, file_id: str, mime_type: str) -> str:
        """Download a Drive file and return its text.

        Google Docs are exported as plain text, Google Sheets as CSV,
        ``text/*`` files and PDFs are downloaded directly.

        Args:
            file_id: Drive file id.
            mime_type: MIME type reported by Drive for the file.

        Returns:
            The extracted text, or a human-readable placeholder message
            (unsupported type, missing PDF library, or "Error ...") when no
            text could be obtained. This method does not raise.
        """
        try:
            files_api = self.drive_service.files()

            if mime_type == self._GOOGLE_DOC_MIME:
                request = files_api.export_media(fileId=file_id, mimeType='text/plain')
                return self._download(request).getvalue().decode('utf-8', errors='ignore')

            if mime_type == self._GOOGLE_SHEET_MIME:
                request = files_api.export_media(fileId=file_id, mimeType='text/csv')
                return self._download(request).getvalue().decode('utf-8', errors='ignore')

            if mime_type.startswith('text/'):
                request = files_api.get_media(fileId=file_id)
                return self._download(request).getvalue().decode('utf-8', errors='ignore')

            if mime_type == 'application/pdf':
                pdf_buffer = self._download(files_api.get_media(fileId=file_id))
                pdf_buffer.seek(0)
                return self._extract_pdf_text(pdf_buffer)

            return "File type not supported for text extraction"
        except Exception as e:
            logger.error("Error reading file %s: %s", file_id, e)
            return f"Error reading file: {str(e)}"

    def process_documents_to_vector_store(self, files: List[Dict]) -> None:
        """Chunk the given Drive files and add their chunks to the vector store.

        Files already indexed in this session are skipped; files found in the
        on-disk cache are indexed without being downloaded again.

        Args:
            files: File dicts as returned by search_files().

        NOTE: when a file changes, its new chunks are added under a new hash;
        chunks from the earlier revision stay in the in-memory index until the
        process restarts.
        """
        pending_docs = []
        pending_hashes = set()
        new_files_processed = 0

        for file in files:
            file_hash = self.get_file_hash(
                file['id'], file.get('size', '0'), file.get('modifiedTime', '')
            )
            if file_hash in self._indexed_hashes or file_hash in pending_hashes:
                continue

            file_documents = self.processed_files.get(file_hash)
            if file_documents is None:
                # Process new or changed file
                content = self.get_file_content(file['id'], file['mimeType'])
                if (
                    not content
                    or not content.strip()
                    or content.startswith(self._UNUSABLE_CONTENT_PREFIXES)
                ):
                    continue

                # Split content into chunks and wrap each with its metadata
                chunks = self.text_splitter.split_text(content)
                file_documents = [
                    Document(
                        page_content=chunk,
                        metadata={
                            'source': file['name'],
                            'file_id': file['id'],
                            'chunk_id': i,
                            'mime_type': file['mimeType'],
                            'total_chunks': len(chunks),
                        },
                    )
                    for i, chunk in enumerate(chunks)
                ]

                # Cache the processed documents
                self.processed_files[file_hash] = file_documents
                new_files_processed += 1
                logger.info("Processed file: %s (%d chunks)", file['name'], len(chunks))

            if file_documents:
                pending_docs.extend(file_documents)
                pending_hashes.add(file_hash)

        if new_files_processed > 0:
            self.save_cache()
            logger.info("Processed %d new files", new_files_processed)

        if not pending_docs:
            return

        # Create or update vector store
        if self.vector_store is None:
            self.vector_store = FAISS.from_documents(pending_docs, self.embeddings)
            logger.info("Created new vector store with %d documents", len(pending_docs))
        else:
            self.vector_store.add_documents(pending_docs)
            logger.info("Added %d new documents to vector store", len(pending_docs))
        # Mark as indexed only after the store accepted the documents.
        self._indexed_hashes.update(pending_hashes)

    def create_conversational_chain(self) -> "ConversationalRetrievalChain":
        """Build a retrieval chain over the current vector store.

        The chain shares ``self.conversation_memory`` and retrieves the six
        most similar chunks per question.

        Raises:
            ValueError: If no vector store has been created yet.
        """
        if self.vector_store is None:
            raise ValueError("Vector store not initialized. Process documents first.")

        # Create custom prompt template
        prompt_template = """You are Study Buddy, an AI assistant specialized in helping students study anatomy effectively.
Use the following context from the student's study materials to answer their question.
Context: {context}
Question: {question}
Instructions:
1. Answer the question directly and comprehensively using the provided context
2. If the context doesn't contain enough information, say so clearly
3. Provide study tips or exam strategies when relevant
4. Use clear, educational language appropriate for students
5. Always end your response with "Is there anything else I can help you with?"
Answer:"""
        PROMPT = PromptTemplate(
            template=prompt_template,
            input_variables=["context", "question"],
        )

        # Create retrieval chain
        qa_chain = ConversationalRetrievalChain.from_llm(
            llm=self.llm,
            retriever=self.vector_store.as_retriever(
                search_type="similarity",
                search_kwargs={"k": 6},  # Retrieve top 6 relevant chunks
            ),
            memory=self.conversation_memory,
            combine_docs_chain_kwargs={"prompt": PROMPT},
            return_source_documents=True,
            verbose=True,
        )
        return qa_chain

    def process_query(self, user_query: str, search_terms: Optional[List[str]] = None) -> Dict:
        """Search Drive, index the hits and answer the question from them.

        Args:
            user_query: The question to answer.
            search_terms: Optional Drive search terms. When none are usable,
                the first five words of the question are used instead.

        Returns:
            Dict with 'answer', 'sources' and 'confidence'; on success also
            'total_files_searched' and 'chunks_retrieved'. Errors are reported
            in 'answer' rather than raised.
        """
        try:
            # Drop blank terms; fall back to the first 5 words of the query.
            terms = [term for term in (search_terms or []) if term and term.strip()]
            if not terms:
                terms = user_query.lower().split()[:5]

            # Search for relevant files, de-duplicating by id in hit order
            unique_files = []
            seen_ids = set()
            for term in terms:
                for file in self.search_files(term):
                    if file['id'] not in seen_ids:
                        seen_ids.add(file['id'])
                        unique_files.append(file)

            if not unique_files:
                return {
                    'answer': "No relevant files found in your Google Drive for this query. Please check if you have uploaded study materials related to your question.",
                    'sources': [],
                    'confidence': 'low'
                }

            # Process documents and create vector store
            self.process_documents_to_vector_store(unique_files[:10])  # Process top 10 files
            if self.vector_store is None:
                return {
                    'answer': "Unable to process the documents. Please check if the files contain readable text content.",
                    'sources': [],
                    'confidence': 'low'
                }

            # Create conversational chain and get answer
            qa_chain = self.create_conversational_chain()
            result = qa_chain.invoke({"question": user_query})

            # Source names in retrieval order, without duplicates
            source_docs = result.get('source_documents', [])
            sources = list(dict.fromkeys(
                doc.metadata.get('source', 'Unknown') for doc in source_docs
            ))

            # Confidence is a simple function of how many chunks were retrieved
            if len(source_docs) >= 3:
                confidence = 'high'
            elif len(source_docs) >= 1:
                confidence = 'medium'
            else:
                confidence = 'low'

            return {
                'answer': result['answer'],
                'sources': sources,
                'confidence': confidence,
                'total_files_searched': len(unique_files),
                'chunks_retrieved': len(source_docs)
            }
        except Exception as e:
            logger.exception("Error processing query: %s", e)
            return {
                'answer': f"An error occurred while processing your query: {str(e)}. Please try again or rephrase your question.",
                'sources': [],
                'confidence': 'low'
            }

    def clear_memory(self):
        """Clear conversation memory"""
        self.conversation_memory.clear()
        logger.info("Conversation memory cleared")

    def get_vector_store_stats(self) -> Dict:
        """Return chunk, file and cache counts for display.

        Returns:
            Dict with 'total_documents', 'total_files' and 'cache_size'. The
            first two are 0 when no vector store exists and the string
            "Unknown" when the store could not be inspected.
        """
        cache_size = len(self.processed_files)
        if self.vector_store is None:
            return {"total_documents": 0, "total_files": 0, "cache_size": cache_size}
        try:
            # NOTE: relies on the private in-memory docstore mapping.
            stored_docs = list(self.vector_store.docstore._dict.values())
            distinct_sources = {doc.metadata.get('source', 'Unknown') for doc in stored_docs}
            return {
                "total_documents": len(stored_docs),
                "total_files": len(distinct_sources),
                "cache_size": cache_size
            }
        except Exception as e:
            logger.warning("Could not read vector store statistics: %s", e)
            return {
                "total_documents": "Unknown",
                "total_files": "Unknown",
                "cache_size": cache_size
            }
# Shared module-level instance used by the UI callbacks below. Constructing it
# runs EnhancedGPTDriveIntegration.__init__ at import time.
enhanced_gpt_drive = EnhancedGPTDriveIntegration()
def process_user_query(query: str, search_terms_input: str) -> tuple:
    """Run a question through the assistant and format the result for the UI.

    Args:
        query: The user's question.
        search_terms_input: Optional comma-separated search terms. Blank
            entries (e.g. from "a,,b" or a trailing comma) are ignored.

    Returns:
        A tuple ``(answer, sources_text, stats_text)``. For a blank question
        the tuple is ``("Please enter a question.", "", "")``.
    """
    if not query.strip():
        return "Please enter a question.", "", ""

    # Parse search terms if provided, dropping blank entries so no empty
    # search is issued; None lets process_query derive terms itself.
    search_terms = [term.strip() for term in search_terms_input.split(',') if term.strip()]
    if not search_terms:
        search_terms = None

    # Process the query
    result = enhanced_gpt_drive.process_query(query, search_terms)

    # Format the response
    answer = result['answer']
    sources = result['sources']

    # Create detailed sources text
    sources_text = ""
    if sources:
        sources_text = "**Sources used:**\n" + "\n".join([f"β’ {source}" for source in sources])
        sources_text += f"\n\n**Search Details:**\n"
        sources_text += f"β’ Files searched: {result.get('total_files_searched', 0)}\n"
        sources_text += f"β’ Relevant chunks found: {result.get('chunks_retrieved', 0)}\n"
        sources_text += f"β’ Confidence: {result.get('confidence', 'unknown').title()}"

    # Stats for display
    stats = enhanced_gpt_drive.get_vector_store_stats()
    stats_text = f"**Knowledge Base:** {stats['total_documents']} chunks from {stats['total_files']} files"
    return answer, sources_text, stats_text
def clear_conversation():
    """Reset the shared assistant's chat history and return a confirmation message."""
    confirmation = "Conversation history cleared. You can start a fresh conversation now."
    enhanced_gpt_drive.clear_memory()
    return confirmation
def get_system_status():
    """Return a multi-line status summary for the status panel.

    The three "Connected"/"Initialized" lines are fixed text; only the counts
    come from ``enhanced_gpt_drive.get_vector_store_stats()``.
    """
    stats = enhanced_gpt_drive.get_vector_store_stats()
    # The stats dict does not carry 'cache_size' on every path (e.g. before
    # any vector store exists), so fall back to 0 instead of raising KeyError.
    cache_size = stats.get('cache_size', 0)
    status_lines = [
        "β Google Drive API: Connected",
        "β OpenAI API: Connected",
        "β LangChain: Initialized",
        f"π Knowledge Base: {stats.get('total_documents', 0)} document chunks",
        f"π Processed Files: {stats.get('total_files', 0)} files",
        f"πΎ Cache Size: {cache_size} entries"
    ]
    return "\n".join(status_lines)
# Create enhanced Gradio interface
# NOTE(review): the nesting of the layout containers below is reconstructed
# from component order; confirm it against the rendered page.
import gradio as gr

with gr.Blocks(title="Enhanced Study Buddy", theme=gr.themes.Soft()) as app:
    gr.Markdown("# π§ Enhanced Anatomy Study Buddy with LangChain")
    gr.Markdown("Study more effectively with advanced AI-powered document analysis and conversational memory!")
    with gr.Row():
        # Wide left column: question form and answer panels
        with gr.Column(scale=3):
            # Main query interface
            with gr.Group():
                gr.Markdown("### π¬ Ask a Question")
                query_input = gr.Textbox(
                    label="Your Question",
                    placeholder="Ask me anything about your anatomy study materials...",
                    lines=3
                )
                search_terms_input = gr.Textbox(
                    label="π Search Terms (Optional)",
                    placeholder="Enter comma-separated terms to focus the search",
                    lines=1
                )
                with gr.Row():
                    submit_btn = gr.Button("π Search & Ask", variant="primary", size="lg")
                    clear_btn = gr.Button("π§Ή Clear Memory", variant="secondary")
            # Results section
            with gr.Group():
                gr.Markdown("### π― Answer")
                answer_output = gr.Textbox(
                    label="AI Response",
                    lines=12,
                    interactive=False
                )
                sources_output = gr.Textbox(
                    label="π Sources & Details",
                    lines=6,
                    interactive=False
                )
        # Narrow right column: status and knowledge-base counters
        with gr.Column(scale=1):
            # System info
            with gr.Group():
                gr.Markdown("### π System Status")
                status_btn = gr.Button("π Refresh Status", size="sm")
                status_output = gr.Textbox(
                    label="System Information",
                    lines=8,
                    interactive=False
                )
                stats_output = gr.Textbox(
                    label="Knowledge Base",
                    lines=2,
                    interactive=False
                )
    # Event handlers
    # Submit: process_user_query returns (answer, sources text, stats text),
    # matching the three outputs in order.
    submit_btn.click(
        fn=process_user_query,
        inputs=[query_input, search_terms_input],
        outputs=[answer_output, sources_output, stats_output]
    )
    # Clear: the confirmation message replaces the answer text.
    clear_btn.click(
        fn=clear_conversation,
        outputs=answer_output
    )
    status_btn.click(
        fn=get_system_status,
        outputs=status_output
    )
    # Enhanced examples
    # Each example is a [question, comma-separated search terms] pair that
    # fills the two input boxes.
    with gr.Row():
        gr.Examples(
            examples=[
                ["What is morbid anatomy and how does it relate to pathology?", "morbid, anatomy, pathology"],
                ["Explain the neural transmission process between neurons", "neuron, transmission, synaptic"],
                ["Describe the complete anatomy of the external ear", "external ear, anatomy, auditory"],
                ["What are the different types of therapeutic massage?", "massage, therapy, treatment"],
                ["Define trauma and its classification in medical terms", "trauma, medical, classification"],
                ["Explain upper limb prosthetics and their applications", "prosthetics, upper limb, rehabilitation"],
                ["How does the nervous system control muscle movement?", "nervous system, muscle, motor control"],
                ["What are the key anatomical landmarks for injection sites?", "injection sites, anatomical landmarks"]
            ],
            inputs=[query_input, search_terms_input]
        )
    # Initial status load
    app.load(
        fn=get_system_status,
        outputs=status_output
    )
# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    launch_options = {
        "server_name": "0.0.0.0",  # listen on all interfaces
        "server_port": 7860,
        "share": True,
        "debug": True,
    }
    app.launch(**launch_options)