import gradio as gr
from huggingface_hub import InferenceClient
import os
from populate_db import main # Import the main function from populate_db.py
# Embeddings - with fallback for older versions
try:
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_chroma import Chroma
except ImportError:
# Fallback to older imports
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from langchain.prompts import PromptTemplate
from langchain_community.llms import Ollama
"""
For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
"""
PROMPT_TEMPLATE = """You are a helpful academic assistant specialised in competence standard and disability support in higher education. Use the provided documents to answer questions accurately and cite your sources. Answer the question based only on the following context:
{context}
----
Answer the question based on the above context: {question}
If the context does not contain enough information to answer the question, say "I don't know". Do not make up an answer.
"""
DEFAULT_SYSTEM_MESSAGE = "You are a helpful academic assistant specialised in competence standard and disability support in higher education. Use the provided documents to answer questions accurately and cite your sources."
model_name = "sentence-transformers/all-mpnet-base-v2"
model_kwargs = {'device': 'cpu'}
encode_kwargs = {'normalize_embeddings': False}
def get_embedding_function():
    """Build the HuggingFace sentence-transformer embedding function.

    Uses the module-level ``model_name`` / ``model_kwargs`` /
    ``encode_kwargs`` settings (all-mpnet-base-v2, CPU, unnormalised).

    Returns:
        HuggingFaceEmbeddings: embedding function for the Chroma store.
    """
    # Fix: restore the function-body indentation lost in the source and
    # drop the misspelled intermediate local ("embedddings").
    return HuggingFaceEmbeddings(
        model_name=model_name,
        model_kwargs=model_kwargs,
        encode_kwargs=encode_kwargs,
    )
client = InferenceClient(provider="nebius", model="meta-llama/Meta-Llama-3.1-8B-Instruct", token=os.getenv("ACCESS_TOKEN"))
def query_rag(query: str, top_k: int = 5):
    """Answer *query* with retrieval-augmented generation.

    Retrieves the ``top_k`` most similar chunks from the local Chroma store
    (persisted in ``chroma_db``), renders them into PROMPT_TEMPLATE, and asks
    a local Ollama model, falling back to the hosted HF client when Ollama is
    unavailable.

    Args:
        query: The user's question.
        top_k: Number of document chunks to retrieve.

    Returns:
        str: Model answer plus a formatted source list, or an
        "I don't know" / error message string.
    """
    try:
        vector_store = Chroma(
            embedding_function=get_embedding_function(),
            persist_directory="chroma_db",
        )
        results = vector_store.similarity_search_with_score(query, k=top_k)
        if not results:
            return "I don't know - no relevant documents found."

        context_texts = "\n\n --- \n\n".join(
            document.page_content for document, _score in results
        )
        prompt = PromptTemplate.from_template(PROMPT_TEMPLATE).format(
            context=context_texts, question=query
        )

        # Prefer a local Ollama model; fall back to the hosted HF client.
        try:
            response_text = Ollama(model="llama2").invoke(prompt)
        except Exception as ollama_error:
            print(f"Ollama error: {ollama_error}")
            response_text = fallback_to_hf_client(prompt)

        # Chunk ids are expected as "path/to/file:page:chunk" — show only the
        # file name to the user. TODO confirm id format against populate_db.
        clean_sources = []
        for source in (doc.metadata.get("id", "Unknown") for doc, _score in results):
            if not source or source == "Unknown":
                continue
            try:
                filename = os.path.basename(source.split(":")[0])
                if filename:
                    clean_sources.append(filename)
            except (IndexError, AttributeError, ValueError):
                clean_sources.append(source)  # keep the raw id if parsing fails

        if clean_sources:
            # Fix: dedupe with dict.fromkeys so the source list keeps a
            # deterministic first-seen order (list(set(...)) is arbitrary).
            unique_sources = list(dict.fromkeys(clean_sources))
            source_lines = "\n".join(f"โข {source}" for source in unique_sources)
            formatted_response = f"{response_text}\n\n**๐ Sources:**\n{source_lines}"
        else:
            formatted_response = f"{response_text}\n\n*Note: Sources information not available*"
        print(f"Formatted response: {formatted_response}")
        return formatted_response
    except Exception as e:
        print(f"Error in query_rag: {e}")
        return f"I encountered an error while processing your query: {str(e)}"
def fallback_to_hf_client(prompt: str):
    """Generate a completion with the hosted HF InferenceClient.

    Used when the local Ollama model is unavailable. Streams the chat
    completion and concatenates the token deltas into one string.

    Args:
        prompt: Fully rendered prompt, sent as a single user message.

    Returns:
        str: The generated text, or an error message string on failure.
    """
    try:
        messages = [{"role": "user", "content": prompt}]
        response = ""
        for message in client.chat_completion(
            messages,
            max_tokens=512,
            stream=True,
            temperature=0.7,
            top_p=0.95,
        ):
            token = message.choices[0].delta.content
            if token:  # delta.content can be None on control chunks
                response += token
        return response
    except Exception as e:
        return f"Error generating response: {str(e)}"
def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """Gradio chat handler: answer via RAG first, else stream a plain reply.

    Yields:
        str: Either the complete RAG answer (single yield), or the plain
        chat response yielded cumulatively as tokens stream in.
    """
    # First try the document-grounded RAG pipeline.
    try:
        rag_response = query_rag(message)
        # Only use the RAG answer when it is a real answer rather than a
        # "don't know" / error sentinel string.
        if rag_response and not rag_response.startswith("I don't know") and not rag_response.startswith("I encountered an error"):
            yield rag_response
            return
    except Exception as e:
        print(f"RAG query failed: {e}")

    # Fallback: plain chat completion over the running conversation.
    messages = [{"role": "system", "content": system_message}]
    for user_turn, assistant_turn in history:
        if user_turn:
            messages.append({"role": "user", "content": user_turn})
        if assistant_turn:
            messages.append({"role": "assistant", "content": assistant_turn})
    messages.append({"role": "user", "content": message})

    response = ""
    try:
        for message_chunk in client.chat_completion(
            messages,
            max_tokens=max_tokens,
            stream=True,
            temperature=temperature,
            top_p=top_p,
        ):
            token = message_chunk.choices[0].delta.content
            if token:
                response += token
                # NOTE(review): source indentation was lost; yielding the
                # accumulated text per non-empty token delta reconstructs
                # the standard streaming pattern — confirm against history.
                yield response
    except Exception as e:
        yield f"Error: {str(e)}"
# Gradio UI: chat interface backed by the RAG-first `respond` generator.
# NOTE(review): each example carries 5 values (message, system prompt,
# max_tokens, temperature, top_p) matching respond()'s extra parameters,
# but no `additional_inputs` are declared on the interface — verify this
# against the installed gradio version's ChatInterface API.
demo = gr.ChatInterface(
    respond,
    title="๐ CS Query - RAG-Powered Academic Assistant",
    description="Ask questions about competence standards and get answers based on the uploaded academic documents.",
    chatbot=gr.Chatbot(height=500),
    examples=[
        [
            "What are reasonable adjustments for students with disabilities?",
            DEFAULT_SYSTEM_MESSAGE,
            512,
            0.7,
            0.95
        ],
        [
            "What does the Equality Act say about education?",
            DEFAULT_SYSTEM_MESSAGE,
            512,
            0.7,
            0.95
        ]
    ],
)
if __name__ == "__main__":
# main()
demo.launch(
inbrowser=True, # Open in browser automatically
height=800, # Increase overall height
width="100%", # Use full width
)
|