Spaces:
Sleeping
Sleeping
| import os | |
| import gradio as gr | |
| import numpy as np | |
| import google.generativeai as genai | |
| import faiss | |
| from sentence_transformers import SentenceTransformer | |
| from datasets import load_dataset | |
| import warnings | |
# Suppress all Python warnings process-wide (blanket filter, applies to every
# category and module).
warnings.filterwarnings("ignore")

# Configuration
MODEL_NAME = "all-MiniLM-L6-v2"  # sentence-transformers embedding model
GENAI_MODEL = "models/gemini-pro"  # Gemini model used for answer generation
DATASET_NAME = "midrees2806/7K_Dataset"  # Hugging Face dataset to index
CHUNK_SIZE = 500  # NOTE(review): not referenced in the visible code
TOP_K = 3  # number of nearest chunks retrieved per query

# Initialize Gemini.
# The API key is read from the environment instead of being hard-coded, so it
# never ends up in version control (e.g. configure it as a Space secret or
# export it in the shell). Any key that was previously committed in this file
# must be treated as leaked and rotated.
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY")
if not GEMINI_API_KEY:
    # Do not abort start-up: the UI can still come up, and API errors are
    # reported per request by the response generator.
    print("⚠️ GEMINI_API_KEY is not set; Gemini requests will fail until it is configured.")
genai.configure(api_key=GEMINI_API_KEY)
class GeminiRAGSystem:
    """Retrieval-augmented question answering over a Hugging Face dataset.

    On construction the system loads the sentence-embedding model named by
    ``MODEL_NAME``, loads ``DATASET_NAME`` and builds an in-memory FAISS L2
    index over its first chunks. A question is answered by retrieving the
    nearest chunks and sending them, together with the question, to the
    Gemini model named by ``GENAI_MODEL``.
    """

    def __init__(self):
        """Load the embedding model and index the dataset.

        Raises:
            RuntimeError: If the embedding model cannot be initialized.
                Dataset failures do not raise; they are recorded in
                ``loading_error`` and leave ``dataset_loaded`` False.
        """
        self.index = None            # FAISS index; None until indexing succeeds
        self.chunks = []             # text chunks, positionally aligned with the index
        self.dataset_loaded = False  # True once chunks and index are both ready
        self.loading_error = None    # message of the last dataset-loading failure

        # Initialize embedding model
        try:
            self.embedding_model = SentenceTransformer(MODEL_NAME)
        except Exception as e:
            # Chain the original exception so the real cause stays in the traceback.
            raise RuntimeError(f"Failed to initialize embedding model: {e}") from e

        # Load dataset
        self.load_dataset()

    def load_dataset(self, max_chunks: int = 1000):
        """Load the dataset synchronously and build the FAISS index.

        The text is taken from the dataset's ``text`` column or, failing that,
        its ``context`` column. Rows that are not non-blank strings are
        skipped so they cannot break the encoder.

        Args:
            max_chunks: Maximum number of leading dataset rows to consider.

        Returns:
            None. On success ``chunks``, ``index`` and ``dataset_loaded`` are
            updated together; on failure the error message is stored in
            ``loading_error`` and printed, and previous state is left intact.
        """
        try:
            # Inside a method body the bare name below resolves to the
            # module-level `datasets.load_dataset` import, not to this method.
            # NOTE(review): "force_redownload" fetches the dataset again on
            # every start; confirm this is still required.
            dataset = load_dataset(
                DATASET_NAME,
                split='train',
                download_mode="force_redownload",
            )

            if 'text' in dataset.features:
                field = 'text'
            elif 'context' in dataset.features:
                field = 'context'
            else:
                raise ValueError("Dataset must have 'text' or 'context' field")

            # Slice rows first so only the rows actually used are materialized.
            rows = dataset[:max_chunks][field]
            chunks = [row for row in rows if isinstance(row, str) and row.strip()]
            if not chunks:
                raise ValueError(f"Dataset field '{field}' contains no usable text")

            embeddings = self.embedding_model.encode(
                chunks,
                show_progress_bar=False,
                convert_to_numpy=True,
            )
            index = faiss.IndexFlatL2(embeddings.shape[1])
            index.add(embeddings.astype('float32'))
        except Exception as e:
            self.loading_error = str(e)
            print(f"Dataset loading failed: {e}")
            return

        # Publish only after every step succeeded, so chunks and index can
        # never get out of sync.
        self.chunks = chunks
        self.index = index
        self.loading_error = None
        self.dataset_loaded = True

    def get_relevant_context(self, query: str, top_k: "int | None" = None) -> str:
        """Retrieve the chunks most similar to *query*.

        Args:
            query: The user question to embed and search for.
            top_k: Number of neighbours to request; defaults to ``TOP_K``.

        Returns:
            The matching chunks joined by blank lines, or ``""`` when no index
            is available, nothing matched, or the search failed (the error is
            printed).
        """
        if self.index is None:
            return ""
        try:
            k = TOP_K if top_k is None else top_k
            query_embed = self.embedding_model.encode(
                [query],
                convert_to_numpy=True,
            ).astype('float32')
            _, indices = self.index.search(query_embed, k=k)
            # FAISS pads missing neighbours with -1; without the lower bound a
            # -1 would silently select the *last* chunk via negative indexing.
            hits = [self.chunks[i] for i in indices[0] if 0 <= i < len(self.chunks)]
            return "\n\n".join(hits)
        except Exception as e:
            print(f"Search error: {e}")
            return ""

    def generate_response(self, query: str) -> str:
        """Answer *query* using retrieved context and the Gemini model.

        Args:
            query: The user question.

        Returns:
            The model's answer, or a human-readable status/error message when
            the dataset is not loaded, the question is blank, no context was
            found, or the API call failed. This method does not raise for
            those conditions.
        """
        if not self.dataset_loaded:
            if self.loading_error:
                return f"⚠️ Dataset loading failed: {self.loading_error}"
            return "⚠️ System initializing..."

        # Don't spend an embedding pass and an API call on an empty question.
        if not query or not query.strip():
            return "Please enter a question."

        context = self.get_relevant_context(query)
        if not context:
            return "No relevant context found"

        prompt = (
            "Answer based on this context:\n"
            f"{context}\n"
            f"Question: {query}\n"
            "Answer concisely:"
        )
        try:
            model = genai.GenerativeModel(GENAI_MODEL)
            response = model.generate_content(prompt)
            return response.text
        except Exception as e:
            return f"⚠️ API Error: {e}"
# Build the single shared RAG system at import time and derive the text shown
# in the UI status box. Any failure is captured as a status message rather
# than aborting the module import; `rag_system` is then left as None.
rag_system = None
try:
    rag_system = GeminiRAGSystem()
    if rag_system.dataset_loaded:
        init_status = "✅ System ready"
    else:
        failure_detail = rag_system.loading_error or ''
        init_status = f"⚠️ Initializing... {failure_detail}"
except Exception as exc:
    rag_system = None
    init_status = f"❌ Initialization failed: {exc}"
# Create interface
with gr.Blocks(title="Chatbot") as app:
    gr.Markdown("# Chatbot")
    chatbot = gr.Chatbot(height=500)
    query = gr.Textbox(label="Your question", placeholder="Ask something...")
    submit_btn = gr.Button("Submit")
    clear_btn = gr.Button("Clear")
    # Read-only: this box only reports the start-up state.
    status = gr.Textbox(label="Status", value=init_status, interactive=False)

    def respond(message, chat_history):
        """Append one (question, answer) turn to the chat.

        Args:
            message: Text currently in the question box.
            chat_history: Current chatbot value, a list of
                (user, assistant) pairs; may be None or empty.

        Returns:
            A ``(history, textbox_value)`` pair: the updated history and an
            empty string that clears the question box.
        """
        # Copy so the component's previous value is never mutated in place.
        history = list(chat_history or [])

        # Ignore blank submissions instead of sending them to the backend.
        if not message or not message.strip():
            return history, ""

        if not rag_system:
            reply = "System initialization failed"
        else:
            reply = rag_system.generate_response(message)
        history.append((message, reply))
        return history, ""

    def clear_chat():
        """Reset the chatbot to an empty conversation."""
        return []

    # Both the button and pressing Enter submit the question; the second
    # output clears the input box after each turn.
    submit_btn.click(respond, [query, chatbot], [chatbot, query])
    query.submit(respond, [query, chatbot], [chatbot, query])
    clear_btn.click(clear_chat, outputs=chatbot)

if __name__ == "__main__":
    app.launch(share=True)