# StudyBuddy / app.py
# Uploaded by Diaa-Zaher via huggingface_hub (commit 32c3961, verified)
# deploy_rag_local.py (Gemini version with chat history)
import os
import gradio as gr
import google.generativeai as genai
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from langchain.llms.base import LLM
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableLambda
from typing import ClassVar
import unittest
# Constants
EMBEDDING_DIR = "local_embedding"  # path to the local sentence-embedding model
CHROMA_DIR = "chroma_store"        # persisted Chroma vector store directory
# Prepended to every retrieved context to keep the model grounded in the docs.
# Fixed typo ("procided" -> "provided") and stray space before the period.
STATIC_CONTEXT_HEADER = "Use only the academic documents provided. Do not use external knowledge."

# Configure Gemini API -- fail fast at startup if the key is absent, rather
# than failing on the first user query.
GENAI_API_KEY = os.getenv("GEMINI_API_KEY")
if not GENAI_API_KEY:
    raise ValueError("Missing GEMINI_API_KEY in environment variables.")
genai.configure(api_key=GENAI_API_KEY)
# Custom Gemini LLM wrapper
class GeminiLLM(LLM):
    """Minimal LangChain LLM adapter over the Gemini generate_content API."""

    # Model identifier shared by every instance of the wrapper.
    model: ClassVar[str] = "models/gemini-2.0-flash"

    def _call(self, prompt: str, stop=None):
        """Send *prompt* to Gemini and return the text of the reply.

        Any API failure (including a blocked/empty response raising on
        ``.text``) is folded into the returned string so the chat UI keeps
        working instead of crashing.
        """
        try:
            client = genai.GenerativeModel(self.model)
            return client.generate_content(prompt).text
        except Exception as e:
            return f"[Gemini API error] {str(e)}"

    @property
    def _llm_type(self) -> str:
        """Identifier LangChain uses to tag this LLM implementation."""
        return "gemini"
# Load embedding model
# NOTE(review): assumes a sentence-transformers model saved under
# "local_embedding" ships with the app -- confirm the folder exists at deploy.
embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_DIR)
# Load vector DB
# Reopen the persisted Chroma store and expose a top-3 similarity retriever
# used by safe_context_retrieval below.
db = Chroma(persist_directory=CHROMA_DIR, embedding_function=embeddings)
retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 3})
# Prompt
# The template forces grounded answers: the model must quote/cite the supplied
# context or reply with a fixed refusal sentence.
prompt_template = """
You are a robust AI assistant trained only to answer based on provided context.
You must quote or directly cite the context in your answer.
If the answer cannot be found in the context or quoted, respond with:
\"Sorry, I can only answer questions related to the AI material provided.\"
Context:
{context}
User:
{question}
Assistant:
"""
prompt = PromptTemplate(input_variables=["context", "question"], template=prompt_template)
llm = GeminiLLM()
# RAG pipeline: question -> {"context", "question"} dict -> prompt -> Gemini
# -> plain string. The lambda defers the lookup of safe_context_retrieval,
# which is defined below this line, until the chain is actually invoked.
rag_chain = RunnableLambda(lambda q: safe_context_retrieval(q)) | (prompt | llm | StrOutputParser())
def safe_context_retrieval(question):
    """Fetch supporting documents for *question* and build the prompt inputs.

    Returns a dict with "context" and "question" keys for the prompt template.
    When retrieval yields nothing, or the combined text is too short to be
    meaningful (< 100 chars after stripping), the context is replaced with a
    fixed refusal message so the model declines to answer.
    """
    fallback = {"context": "Sorry, I can't answer that based on the provided documents.", "question": question}
    docs = retriever.invoke(question)
    if not docs:
        return fallback
    context_body = "\n".join(doc.page_content for doc in docs)
    if len(context_body.strip()) < 100:
        return fallback
    return {"context": f"{STATIC_CONTEXT_HEADER}\n\n{context_body}", "question": question}
def generate_answer_with_history(question, history):
    """Answer *question* via the RAG chain and record the turn in *history*.

    Mutates *history* (a list of (question, answer) pairs) in place and
    returns it twice so Gradio can refresh both the State and the Chatbot
    outputs from the same object.
    """
    answer = rag_chain.invoke(question).strip()
    history.append((question, answer))
    return history, history
# Gradio UI with history
# fn receives (question, state) and returns (state, chat_pairs); the hidden
# State component threads the conversation list across calls.
demo = gr.Interface(
fn=generate_answer_with_history,
inputs=[gr.Textbox(label="Ask a Question"), gr.State([])],
outputs=[gr.State(), gr.Chatbot(label="Chat History")],
title="StudyBuddy"
)
# Unit test
class TestGeminiRAG(unittest.TestCase):
    """Smoke test: a relevant question yields a non-trivial answer.

    NOTE(review): this exercises the live retriever and Gemini API, and the
    __main__ guard below only launches the UI -- run `python -m unittest`
    explicitly to execute it.
    """

    def test_valid_question(self):
        chat, _ = generate_answer_with_history("What is artificial intelligence?", [])
        self.assertIsInstance(chat, list)
        self.assertGreater(len(chat), 0)
        self.assertGreater(len(chat[0][1]), 10)
# Entry point: serve the Gradio app. Note the unittest class above is not
# executed here; it must be run separately via `python -m unittest`.
if __name__ == "__main__":
    demo.launch()