# deploy_rag_local.py (Gemini version with chat history)
import os
import gradio as gr
import google.generativeai as genai
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from langchain.llms.base import LLM
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableLambda
from typing import ClassVar
import unittest

# Constants
EMBEDDING_DIR = "local_embedding"
CHROMA_DIR = "chroma_store"
STATIC_CONTEXT_HEADER = "Use only the academic documents provided. Do not use external knowledge."

# Configure Gemini API
GENAI_API_KEY = os.getenv("GEMINI_API_KEY")
if not GENAI_API_KEY:
    raise ValueError("Missing GEMINI_API_KEY in environment variables.")
genai.configure(api_key=GENAI_API_KEY)
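# The key is read from the environment; one way to set it (shell), with the
# actual value elided:
#   export GEMINI_API_KEY="..."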

# Custom Gemini LLM wrapper
class GeminiLLM(LLM):
    model: ClassVar[str] = "models/gemini-2.0-flash"

    def _call(self, prompt: str, stop=None) -> str:
        try:
            response = genai.GenerativeModel(self.model).generate_content(prompt)
            return response.text
        except Exception as e:
            return f"[Gemini API error] {str(e)}"

    @property
    def _llm_type(self) -> str:
        return "gemini"

# Load embedding model
embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_DIR)

# Load vector DB and expose a top-3 similarity retriever
db = Chroma(persist_directory=CHROMA_DIR, embedding_function=embeddings)
retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 3})

# Prompt
prompt_template = """
You are a robust AI assistant trained only to answer based on provided context.
You must quote or directly cite the context in your answer.
If the answer cannot be found in the context or quoted, respond with:
"Sorry, I can only answer questions related to the AI material provided."

Context:
{context}

User:
{question}

Assistant:
"""
prompt = PromptTemplate(input_variables=["context", "question"], template=prompt_template)

def safe_context_retrieval(question):
    # Retrieve supporting chunks; refuse to answer when retrieval comes back
    # empty or too thin to quote from.
    docs = retriever.invoke(question)
    if not docs:
        return {"context": "Sorry, I can't answer that based on the provided documents.", "question": question}
    context_body = "\n".join(doc.page_content for doc in docs)
    if len(context_body.strip()) < 100:
        return {"context": "Sorry, I can't answer that based on the provided documents.", "question": question}
    full_context = f"{STATIC_CONTEXT_HEADER}\n\n{context_body}"
    return {"context": full_context, "question": question}

llm = GeminiLLM()
rag_chain = RunnableLambda(safe_context_retrieval) | prompt | llm | StrOutputParser()

def generate_answer_with_history(question, history):
    response = rag_chain.invoke(question)
    history.append((question, response.strip()))
    return history, history

# Gradio UI with history
demo = gr.Interface(
    fn=generate_answer_with_history,
    inputs=[gr.Textbox(label="Ask a Question"), gr.State([])],
    outputs=[gr.State(), gr.Chatbot(label="Chat History")],
    title="StudyBuddy",
)

# Unit test
class TestGeminiRAG(unittest.TestCase):
    def test_valid_question(self):
        result, _ = generate_answer_with_history("What is artificial intelligence?", [])
        self.assertIsInstance(result, list)
        self.assertTrue(len(result) > 0)
        self.assertTrue(len(result[0][1]) > 10)

if __name__ == "__main__":
    demo.launch()
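
# demo.launch() serves the app locally (Gradio defaults to http://127.0.0.1:7860).
# To run the unit test instead of the UI, one option, assuming this file is
# saved as deploy_rag_local.py, is:
#   python -m unittest deploy_rag_local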