# Spaces: Sleeping  (page header captured along with this file; not part of the program)
import os
import warnings

import faiss
import google.generativeai as genai
import gradio as gr
import numpy as np
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
# Suppress all Python warnings emitted by this process.
warnings.filterwarnings("ignore")

# Configuration.
# SECURITY: the Gemini API key is read from the GEMINI_API_KEY environment
# variable instead of being committed to the source. Any key that has ever
# been committed or published must be treated as leaked: revoke it and issue
# a new one.
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY", "")
MODEL_NAME = "all-MiniLM-L6-v2"  # passed to SentenceTransformer(...)
GENAI_MODEL = "gemini-pro"  # passed to genai.GenerativeModel(...)
DATASET_NAME = "midrees2806/7K_Dataset"  # passed to datasets.load_dataset(...)
CHUNK_SIZE = 500
TOP_K = 3  # number of nearest chunks requested from the FAISS index per query

if not GEMINI_API_KEY:
    # Do not abort: the UI can still start, and API failures are reported
    # to the user by the response-generation code.
    print("⚠️ GEMINI_API_KEY environment variable is not set; Gemini requests will fail")

# Initialize Gemini over the REST transport.
# No `client_options['api_endpoint']` override is passed: that option expects
# an API host, not a full `...:generateContent` URL, and the library already
# targets the public Generative Language endpoint by default.
genai.configure(
    api_key=GEMINI_API_KEY,
    transport='rest',  # Force REST API
)
class GeminiRAGSystem:
    """Retrieval-augmented question answering backed by FAISS and Gemini.

    Construction loads the sentence-embedding model, then downloads the
    dataset, embeds up to ``MAX_CHUNKS`` text rows and stores them in a flat
    L2 FAISS index. Dataset failures do not raise: they are recorded in
    ``loading_error`` and ``dataset_loaded`` stays ``False``.

    Attributes are documented inline in ``__init__``.
    """

    # Upper bound on the number of dataset rows that are embedded and indexed.
    MAX_CHUNKS = 1000

    def __init__(self):
        """Load the embedding model and build the search index.

        Raises:
            RuntimeError: if the embedding model cannot be initialized.
        """
        self.index = None            # FAISS index, or None until built
        self.chunks = []             # text chunks, aligned with index rows
        self.dataset_loaded = False  # True once chunks and index are ready
        self.loading_error = None    # message of the last loading failure

        print("Initializing embedding model...")
        try:
            self.embedding_model = SentenceTransformer(MODEL_NAME)
        except Exception as e:
            error_msg = f"Failed to initialize embedding model: {str(e)}"
            print(error_msg)
            # Chain the cause so the original traceback is not lost.
            raise RuntimeError(error_msg) from e
        else:
            print("Embedding model initialized successfully")

        print("Loading dataset...")
        self.load_dataset()

    @staticmethod
    def _clean_chunks(raw_chunks):
        """Return only the entries of *raw_chunks* that are non-blank strings."""
        return [c for c in raw_chunks if isinstance(c, str) and c.strip()]

    @staticmethod
    def _select_chunks(chunks, indices):
        """Return ``chunks[i]`` for every valid position ``i`` in *indices*.

        FAISS pads its result with ``-1`` when fewer neighbours than requested
        exist; a plain ``i < len(chunks)`` check would let ``-1`` through and
        silently return the last chunk, so the lower bound is checked too.
        """
        return [chunks[i] for i in indices if 0 <= i < len(chunks)]

    def load_dataset(self):
        """Download the dataset, embed its text and build the FAISS index.

        Reads the ``'text'`` column if present, otherwise ``'context'``.
        Never raises: on any failure the message is printed and stored in
        ``self.loading_error``. Instance state (``chunks``, ``index``,
        ``dataset_loaded``) is only updated after every step succeeded, so a
        failed load never leaves chunks without a matching index.
        """
        try:
            print(f"Downloading dataset: {DATASET_NAME}")
            # NOTE(review): force_redownload bypasses the local cache on every
            # start; kept as in the original — confirm this is intended.
            dataset = load_dataset(
                DATASET_NAME,
                split='train',
                download_mode="force_redownload"
            )
            print("Dataset downloaded successfully")

            if 'text' in dataset.features:
                field = 'text'
            elif 'context' in dataset.features:
                field = 'context'
            else:
                raise ValueError("Dataset must have 'text' or 'context' field")

            # Drop None / non-string / blank rows so encoding and the later
            # "\n\n".join cannot fail on them.
            chunks = self._clean_chunks(dataset[field][:self.MAX_CHUNKS])
            print(f"Loaded {len(chunks)} {field} chunks")
            if not chunks:
                raise ValueError(f"Dataset field '{field}' contains no usable text")

            print("Creating embeddings...")
            embeddings = self.embedding_model.encode(
                chunks,
                show_progress_bar=False,
                convert_to_numpy=True
            )
            print(f"Created embeddings with shape {embeddings.shape}")

            index = faiss.IndexFlatL2(embeddings.shape[1])
            index.add(embeddings.astype('float32'))
            print("FAISS index created successfully")

            # Commit state only now that chunks and index are consistent.
            self.chunks = chunks
            self.index = index
            self.loading_error = None
            self.dataset_loaded = True
            print("Dataset loading complete")
        except Exception as e:
            error_msg = f"Dataset loading failed: {str(e)}"
            print(error_msg)
            self.loading_error = error_msg

    def get_relevant_context(self, query: str) -> str:
        """Return the chunks nearest to *query*, joined by blank lines.

        Args:
            query: the user question to embed and search for.

        Returns:
            The concatenated matching chunks, or ``""`` when no index is
            available, nothing matched, or the search failed (the error is
            printed, not raised).
        """
        if self.index is None or not self.chunks:
            print("No index available for search")
            return ""
        try:
            print(f"Processing query: {query}")
            query_embed = self.embedding_model.encode(
                [query],
                convert_to_numpy=True
            ).astype('float32')
            print("Query embedded successfully")

            # Never ask for more neighbours than there are chunks.
            k = min(TOP_K, len(self.chunks))
            distances, indices = self.index.search(query_embed, k=k)
            print(f"Search results - distances: {distances}, indices: {indices}")

            context = "\n\n".join(self._select_chunks(self.chunks, indices[0]))
            print(f"Context length: {len(context)} characters")
            return context
        except Exception as e:
            # Best-effort retrieval: report and fall back to "no context".
            print(f"Search error: {str(e)}")
            return ""

    def generate_response(self, query: str) -> str:
        """Answer *query* with Gemini, grounded in the retrieved context.

        Args:
            query: the user question.

        Returns:
            The model's answer, or a human-readable warning string when the
            dataset is not loaded, the query is blank, no context was found,
            the API returned nothing, or the API call failed. API errors are
            caught and returned as text rather than raised.
        """
        if not self.dataset_loaded:
            msg = f"⚠️ Dataset loading failed: {self.loading_error}" if self.loading_error else "⚠️ System initializing..."
            print(msg)
            return msg

        # A blank question would otherwise be embedded and sent to the API.
        if not query or not query.strip():
            msg = "⚠️ Please enter a question"
            print(msg)
            return msg

        print(f"\n{'='*40}\nNew Query: {query}\n{'='*40}")
        context = self.get_relevant_context(query)
        if not context:
            print("No relevant context found")
            return "No relevant context found"

        prompt = (
            "Answer based on this context:\n"
            f"{context}\n"
            f"Question: {query}\n"
            "Answer concisely:"
        )
        print(f"\nPrompt sent to Gemini:\n{prompt}\n")

        try:
            model = genai.GenerativeModel(GENAI_MODEL)
            response = model.generate_content(
                prompt,
                generation_config=genai.types.GenerationConfig(
                    temperature=0.3,
                    max_output_tokens=1000
                )
            )
            print(f"Raw API response: {response}")
            if response.candidates and response.candidates[0].content.parts:
                answer = response.candidates[0].content.parts[0].text
                print(f"Answer: {answer}")
                return answer
            print("⚠️ Empty response from API")
            return "⚠️ No response from API"
        except Exception as e:
            error_msg = f"⚠️ API Error: {str(e)}"
            print(error_msg)
            return error_msg
# Build the module-level RAG system once at import time and derive the status
# line shown in the UI. Any exception raised during construction is caught so
# the interface can still start; `rag_system` is then None.
print("Initializing RAG system...")
try:
    rag_system = GeminiRAGSystem()
    if rag_system.dataset_loaded:
        init_status = "✅ System ready"
    else:
        failure_detail = rag_system.loading_error or ''
        init_status = f"⚠️ Initializing... {failure_detail}"
    print(init_status)
except Exception as exc:
    init_status = f"❌ Initialization failed: {exc}"
    print(init_status)
    rag_system = None
# Gradio front end: chat history, question box, Submit/Clear buttons and a
# read-only status field pre-filled with `init_status`.
# NOTE(review): the source's indentation was lost, so the Row nesting below is
# reconstructed — confirm the status box is meant to sit outside the button row.
with gr.Blocks(title="Document Chatbot") as app:
    gr.Markdown("# Document Chatbot with Gemini")

    with gr.Row():
        chatbot = gr.Chatbot(height=500, label="Chat History")

    with gr.Row():
        query = gr.Textbox(label="Your question", placeholder="Ask about the documents...")

    with gr.Row():
        submit_btn = gr.Button("Submit", variant="primary")
        clear_btn = gr.Button("Clear", variant="secondary")

    status = gr.Textbox(label="System Status", value=init_status, interactive=False)

    def respond(message, chat_history):
        """Append a (question, answer) pair for *message* to *chat_history*.

        The answer comes from `rag_system.generate_response`; when the system
        failed to initialize, a fixed error text is used instead.
        """
        rule = '=' * 40
        print(f"\n{rule}\nUser Query: {message}\n{rule}")
        if rag_system:
            reply = rag_system.generate_response(message)
        else:
            reply = "System initialization failed"
            print(reply)
        return chat_history + [(message, reply)]

    def clear_chat():
        """Return an empty history, which resets the chat display."""
        print("Chat cleared")
        return []

    # The button click and pressing Enter in the textbox run the same handler.
    for register in (submit_btn.click, query.submit):
        register(respond, [query, chatbot], [chatbot])
    clear_btn.click(clear_chat, outputs=chatbot)

if __name__ == "__main__":
    # Start the web server only when this file is executed as a script.
    print("Launching Gradio interface...")
    app.launch(debug=True)