# Agent_CS / app.py
# (HF Space file header residue, kept as comments so the module stays parseable)
# daniel-was-taken's picture
# Update app.py
# 6ba2b91 verified
import gradio as gr
from huggingface_hub import InferenceClient
import os
from populate_db import main # Import the main function from populate_db.py
# Embeddings - with fallback for older versions
try:
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_chroma import Chroma
except ImportError:
# Fallback to older imports
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from langchain.prompts import PromptTemplate
from langchain_community.llms import Ollama
"""
For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
"""
# RAG prompt: retrieved document chunks are interpolated into {context} and
# the user's query into {question}. Instructs the model to refuse rather than
# hallucinate when the context is insufficient.
PROMPT_TEMPLATE = """You are a helpful academic assistant specialised in competence standard and disability support in higher education. Use the provided documents to answer questions accurately and cite your sources. Answer the question based only on the following context:
{context}
----
Answer the question based on the above context: {question}
If the context does not contain enough information to answer the question, say "I don't know". Do not make up an answer.
"""
# System message used for the plain (non-RAG) chat fallback and the examples.
DEFAULT_SYSTEM_MESSAGE = "You are a helpful academic assistant specialised in competence standard and disability support in higher education. Use the provided documents to answer questions accurately and cite your sources."
# Sentence-transformers model that embeds both documents and queries for retrieval.
model_name = "sentence-transformers/all-mpnet-base-v2"
# Run the embedding model on CPU (no GPU assumed on the Space).
model_kwargs = {'device': 'cpu'}
# Keep raw (un-normalized) embeddings; must match how the store was populated.
encode_kwargs = {'normalize_embeddings': False}
def get_embedding_function():
    """Build the HuggingFace sentence-embedding function used by the Chroma store."""
    return HuggingFaceEmbeddings(
        model_name=model_name,
        model_kwargs=model_kwargs,
        encode_kwargs=encode_kwargs,
    )
client = InferenceClient(provider="nebius", model="meta-llama/Meta-Llama-3.1-8B-Instruct", token=os.getenv("ACCESS_TOKEN"))
def query_rag(query: str, top_k: int = 5):
    """Answer ``query`` with retrieval-augmented generation.

    Retrieves the ``top_k`` most similar chunks from the persisted Chroma
    store, builds a prompt from them, and asks the LLM (local Ollama first,
    hosted HF client as fallback).

    Returns the answer text with a formatted source list appended, an
    "I don't know" message when nothing is retrieved, or an error message
    if any step fails.
    """
    try:
        # Open the persisted vector store (built by populate_db.py).
        vector_store = Chroma(
            embedding_function=get_embedding_function(),
            persist_directory="chroma_db",
        )
        results = vector_store.similarity_search_with_score(query, k=top_k)
        if not results:
            return "I don't know - no relevant documents found."

        context_texts = "\n\n --- \n\n".join([document.page_content for document, _score in results])
        prompt_template = PromptTemplate.from_template(PROMPT_TEMPLATE)
        prompt = prompt_template.format(context=context_texts, question=query)

        # Prefer a local Ollama model; fall back to the hosted client when it
        # is not running (the usual case on a hosted Space).
        try:
            model = Ollama(model="llama2")
            response_text = model.invoke(prompt)
        except Exception as ollama_error:
            print(f"Ollama error: {ollama_error}")
            response_text = fallback_to_hf_client(prompt)

        # Chunk ids are typically "path/to/file:page:chunk"; reduce each to a
        # bare filename for display. Ids that cannot be parsed are kept raw.
        clean_sources = []
        for source in (doc.metadata.get("id", "Unknown") for doc, _score in results):
            if not source or source == "Unknown":
                continue
            try:
                filename = os.path.basename(source.split(":")[0])
                if filename:
                    clean_sources.append(filename)
            except (IndexError, AttributeError, ValueError):
                clean_sources.append(source)  # fall back to the raw id

        # Format the final response with sources.
        if clean_sources:
            # dict.fromkeys dedupes while preserving retrieval order; the
            # previous list(set(...)) made the source order non-deterministic.
            unique_sources = list(dict.fromkeys(clean_sources))
            source_lines = "\n".join(f"โ€ข {source}" for source in unique_sources)
            formatted_response = f"{response_text}\n\n**๐Ÿ“š Sources:**\n{source_lines}"
        else:
            formatted_response = f"{response_text}\n\n*Note: Sources information not available*"

        print(f"Formatted response: {formatted_response}")
        return formatted_response
    except Exception as e:
        print(f"Error in query_rag: {e}")
        return f"I encountered an error while processing your query: {str(e)}"
def fallback_to_hf_client(prompt: str):
    """Generate a completion with the hosted HF Inference client.

    Streams a chat completion for ``prompt`` and returns the concatenated
    text, or an error string if the request fails.
    """
    try:
        pieces = []
        stream = client.chat_completion(
            [{"role": "user", "content": prompt}],
            max_tokens=512,
            stream=True,
            temperature=0.7,
            top_p=0.95,
        )
        for chunk in stream:
            delta = chunk.choices[0].delta.content
            if delta:
                pieces.append(delta)
        return "".join(pieces)
    except Exception as e:
        return f"Error generating response: {str(e)}"
def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """Gradio chat handler: try a RAG answer first, then plain chat completion."""
    # Attempt a document-grounded answer first.
    try:
        rag_response = query_rag(message)
        rag_unusable = (
            not rag_response
            or rag_response.startswith("I don't know")
            or rag_response.startswith("I encountered an error")
        )
        if not rag_unusable:
            yield rag_response
            return
    except Exception as e:
        print(f"RAG query failed: {e}")

    # Fallback: plain chat completion over the conversation history.
    messages = [{"role": "system", "content": system_message}]
    for user_turn, assistant_turn in history:
        if user_turn:
            messages.append({"role": "user", "content": user_turn})
        if assistant_turn:
            messages.append({"role": "assistant", "content": assistant_turn})
    messages.append({"role": "user", "content": message})

    partial = ""
    try:
        for chunk in client.chat_completion(
            messages,
            max_tokens=max_tokens,
            stream=True,
            temperature=temperature,
            top_p=top_p,
        ):
            delta = chunk.choices[0].delta.content
            if delta:
                partial += delta
            # Yield on every chunk (even empty deltas), matching streaming UI
            # expectations of a progressively growing message.
            yield partial
    except Exception as e:
        yield f"Error: {str(e)}"
# The additional inputs map onto respond()'s extra parameters
# (system_message, max_tokens, temperature, top_p). Without
# additional_inputs, Gradio calls fn(message, history) only, which would
# raise a TypeError for respond's required extra parameters — and the
# 5-element examples below would not line up with the interface.
demo = gr.ChatInterface(
    respond,
    title="๐ŸŽ“ CS Query - RAG-Powered Academic Assistant",
    description="Ask questions about competence standards and get answers based on the uploaded academic documents.",
    chatbot=gr.Chatbot(height=500),
    additional_inputs=[
        gr.Textbox(value=DEFAULT_SYSTEM_MESSAGE, label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
    ],
    examples=[
        [
            "What are reasonable adjustments for students with disabilities?",
            DEFAULT_SYSTEM_MESSAGE,
            512,
            0.7,
            0.95
        ],
        [
            "What does the Equality Act say about education?",
            DEFAULT_SYSTEM_MESSAGE,
            512,
            0.7,
            0.95
        ]
    ],
)
if __name__ == "__main__":
    # main()  # uncomment to (re)build the Chroma store via populate_db before serving
    demo.launch(
        inbrowser=True,  # Open in browser automatically
        height=800,  # Increase overall height
        width="100%",  # Use full width
    )