Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from huggingface_hub import InferenceClient | |
| import os | |
| from populate_db import main # Import the main function from populate_db.py | |
| # Embeddings - with fallback for older versions | |
| try: | |
| from langchain_huggingface import HuggingFaceEmbeddings | |
| from langchain_chroma import Chroma | |
| except ImportError: | |
| # Fallback to older imports | |
| from langchain_community.embeddings import HuggingFaceEmbeddings | |
| from langchain_community.vectorstores import Chroma | |
| from langchain.prompts import PromptTemplate | |
| from langchain_community.llms import Ollama | |
| """ | |
| For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference | |
| """ | |
| PROMPT_TEMPLATE = """You are a helpful academic assistant specialised in competence standard and disability support in higher education. Use the provided documents to answer questions accurately and cite your sources. Answer the question based only on the following context: | |
| {context} | |
| ---- | |
| Answer the question based on the above context: {question} | |
| If the context does not contain enough information to answer the question, say "I don't know". Do not make up an answer. | |
| """ | |
| DEFAULT_SYSTEM_MESSAGE = "You are a helpful academic assistant specialised in competence standard and disability support in higher education. Use the provided documents to answer questions accurately and cite your sources." | |
| model_name = "sentence-transformers/all-mpnet-base-v2" | |
| model_kwargs = {'device': 'cpu'} | |
| encode_kwargs = {'normalize_embeddings': False} | |
def get_embedding_function():
    """Build the HuggingFace embedding function used by the Chroma vector store.

    Uses the module-level ``model_name`` / ``model_kwargs`` / ``encode_kwargs``
    configuration (all-mpnet-base-v2 on CPU).
    """
    return HuggingFaceEmbeddings(
        model_name=model_name,
        model_kwargs=model_kwargs,
        encode_kwargs=encode_kwargs,
    )
# Hosted inference client (Nebius provider) used when a local Ollama server is
# unavailable. Reads the HF access token from the ACCESS_TOKEN env var; if the
# variable is unset the token is None and authenticated calls will fail at
# request time, not here.
client = InferenceClient(provider="nebius", model="meta-llama/Meta-Llama-3.1-8B-Instruct", token=os.getenv("ACCESS_TOKEN"))
def query_rag(query: str, top_k: int = 5, persist_directory: str = "chroma_db"):
    """Answer *query* using retrieval-augmented generation.

    Retrieves the ``top_k`` most similar chunks from the Chroma store at
    ``persist_directory``, builds a prompt from them, and generates an answer
    with a local Ollama model, falling back to the hosted HF Inference client
    when Ollama is unavailable.

    Args:
        query: The user's question.
        top_k: Number of chunks to retrieve.
        persist_directory: Location of the persisted Chroma database.

    Returns:
        The formatted answer string with a source list appended, or an
        "I don't know" / error message. This function never raises.
    """
    try:
        vector_store = Chroma(
            embedding_function=get_embedding_function(),
            persist_directory=persist_directory,
        )
        results = vector_store.similarity_search_with_score(query, k=top_k)
        if not results:
            return "I don't know - no relevant documents found."

        context_texts = "\n\n --- \n\n".join(doc.page_content for doc, _score in results)
        prompt = PromptTemplate.from_template(PROMPT_TEMPLATE).format(
            context=context_texts, question=query
        )

        # Prefer a local Ollama model; any failure (not installed, not running)
        # falls back to the hosted HF Inference client.
        try:
            response_text = Ollama(model="llama2").invoke(prompt)
        except Exception as ollama_error:
            print(f"Ollama error: {ollama_error}")
            response_text = fallback_to_hf_client(prompt)

        formatted_response = _append_sources(response_text, results)
        print(f"Formatted response: {formatted_response}")
        return formatted_response
    except Exception as e:
        # Top-level boundary: report the failure to the caller instead of raising
        # so the Gradio handler can fall back to plain chat.
        print(f"Error in query_rag: {e}")
        return f"I encountered an error while processing your query: {str(e)}"


def _append_sources(response_text: str, results) -> str:
    """Append a de-duplicated source list to *response_text*.

    ``results`` is the (document, score) list from the similarity search.
    Document metadata ids look like "path/to/file:page:chunk" — TODO confirm
    against populate_db.py; only the filename part is shown to the user.
    """
    clean_sources = []
    for doc, _score in results:
        source = doc.metadata.get("id", "Unknown")
        if source and source != "Unknown":
            try:
                filename = os.path.basename(source.split(":")[0])
                if filename:
                    clean_sources.append(filename)
            except (IndexError, AttributeError, ValueError):
                clean_sources.append(source)  # fall back to the raw id
    if not clean_sources:
        return f"{response_text}\n\n*Note: Sources information not available*"
    # dict.fromkeys de-duplicates while preserving first-seen order; the
    # previous list(set(...)) produced a nondeterministic source ordering.
    unique_sources = list(dict.fromkeys(clean_sources))
    bullets = "\n".join(f"• {source}" for source in unique_sources)
    return f"{response_text}\n\n**📚 Sources:**\n{bullets}"
def fallback_to_hf_client(prompt: str):
    """Generate a completion for *prompt* with the hosted HF Inference client.

    Used when the local Ollama model is unavailable. Streams the chat
    completion and concatenates the delta tokens.

    Returns:
        The full response text, or an ``"Error generating response: ..."``
        message string if the request fails (never raises).
    """
    try:
        messages = [{"role": "user", "content": prompt}]
        # Collect streamed tokens in a list and join once at the end —
        # avoids quadratic repeated string concatenation.
        parts = []
        for message in client.chat_completion(
            messages,
            max_tokens=512,
            stream=True,
            temperature=0.7,
            top_p=0.95,
        ):
            token = message.choices[0].delta.content
            if token:
                parts.append(token)
        return "".join(parts)
    except Exception as e:
        return f"Error generating response: {str(e)}"
def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """Gradio chat handler: try the RAG pipeline first, then plain chat.

    Yields the RAG answer in a single piece when relevant documents are found;
    otherwise streams a regular chat completion, yielding the cumulative
    response after each token.
    """
    # First attempt: answer from the document store.
    try:
        rag_response = query_rag(message)
        rag_unusable = (
            not rag_response
            or rag_response.startswith("I don't know")
            or rag_response.startswith("I encountered an error")
        )
        if not rag_unusable:
            yield rag_response
            return
    except Exception as e:
        print(f"RAG query failed: {e}")

    # Fallback: plain chat completion over the conversation history.
    messages = [{"role": "system", "content": system_message}]
    for user_turn, assistant_turn in history:
        if user_turn:
            messages.append({"role": "user", "content": user_turn})
        if assistant_turn:
            messages.append({"role": "assistant", "content": assistant_turn})
    messages.append({"role": "user", "content": message})

    response = ""
    try:
        for chunk in client.chat_completion(
            messages,
            max_tokens=max_tokens,
            stream=True,
            temperature=temperature,
            top_p=top_p,
        ):
            token = chunk.choices[0].delta.content
            if token:
                response += token
                yield response
    except Exception as e:
        yield f"Error: {str(e)}"
# Build the chat UI. `respond` takes four parameters beyond (message, history)
# — system_message, max_tokens, temperature, top_p — so they must be declared
# as additional_inputs. Without them Gradio calls the handler with only two
# arguments and it crashes; the example rows already carry values for these
# inputs, confirming they were intended.
demo = gr.ChatInterface(
    respond,
    title="📚 CS Query - RAG-Powered Academic Assistant",
    description="Ask questions about competence standards and get answers based on the uploaded academic documents.",
    chatbot=gr.Chatbot(height=500),
    additional_inputs=[
        gr.Textbox(value=DEFAULT_SYSTEM_MESSAGE, label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.05, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
    ],
    examples=[
        [
            "What are reasonable adjustments for students with disabilities?",
            DEFAULT_SYSTEM_MESSAGE,
            512,
            0.7,
            0.95,
        ],
        [
            "What does the Equality Act say about education?",
            DEFAULT_SYSTEM_MESSAGE,
            512,
            0.7,
            0.95,
        ],
    ],
)
if __name__ == "__main__":
    # main()  # from populate_db — uncomment to (re)build the Chroma database before serving
    demo.launch(
        inbrowser=True,  # Open in browser automatically
        height=800,  # Increase overall height
        width="100%",  # Use full width
    )