File size: 7,280 Bytes
35dae13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b37b83c
35dae13
 
 
 
 
 
 
 
 
b37b83c
35dae13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6ba2b91
35dae13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b37b83c
35dae13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b37b83c
35dae13
 
 
b37b83c
0a80497
35dae13
 
 
 
 
 
 
 
 
 
 
 
 
 
41e31c7
35dae13
 
 
 
 
94d2d89
35dae13
 
0a80497
 
35dae13
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
import gradio as gr
from huggingface_hub import InferenceClient
import os
from populate_db import main  # Import the main function from populate_db.py
# Embeddings - with fallback for older versions
try:
    from langchain_huggingface import HuggingFaceEmbeddings
    from langchain_chroma import Chroma
except ImportError:
    
    # Fallback to older imports
    from langchain_community.embeddings import HuggingFaceEmbeddings
    from langchain_community.vectorstores import Chroma
    
from langchain.prompts import PromptTemplate
from langchain_community.llms import Ollama    

"""
For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
"""

# Prompt for the RAG pipeline: {context} is filled with retrieved document
# chunks and {question} with the user's query (see query_rag).
PROMPT_TEMPLATE = """You are a helpful academic assistant specialised in competence standard and disability support in higher education. Use the provided documents to answer questions accurately and cite your sources. Answer the question based only on the following context: 
{context}

----

Answer the question based on the above context: {question}

If the context does not contain enough information to answer the question, say "I don't know". Do not make up an answer. 
"""

# System message used for the non-RAG fallback chat path and as the
# default value in the example rows below.
DEFAULT_SYSTEM_MESSAGE = "You are a helpful academic assistant specialised in competence standard and disability support in higher education. Use the provided documents to answer questions accurately and cite your sources."


# Sentence-transformers embedding configuration (CPU, unnormalized vectors).
# NOTE(review): these must match the settings used when the chroma_db
# directory was populated, or similarity search will be meaningless — confirm
# against populate_db.py.
model_name = "sentence-transformers/all-mpnet-base-v2"
model_kwargs = {'device': 'cpu'}
encode_kwargs = {'normalize_embeddings': False}

def get_embedding_function():
    """Build the HuggingFace embedding model used for the Chroma vector store."""
    return HuggingFaceEmbeddings(
        model_name=model_name,
        model_kwargs=model_kwargs,
        encode_kwargs=encode_kwargs,
    )


client = InferenceClient(provider="nebius", model="meta-llama/Meta-Llama-3.1-8B-Instruct", token=os.getenv("ACCESS_TOKEN"))

def query_rag(query: str, top_k: int = 5):
    """
    Query the RAG system with a given query string and return the top_k results.
    """
    try:
        # Initialize the vector store
        vector_store = Chroma(
            embedding_function=get_embedding_function(),
            persist_directory="chroma_db",
        )

        results = vector_store.similarity_search_with_score(query, k=top_k)
        
        if not results:
            return "I don't know - no relevant documents found."

        context_texts = "\n\n --- \n\n".join([document.page_content for document, _score in results])

        prompt_template = PromptTemplate.from_template(PROMPT_TEMPLATE)
        prompt = prompt_template.format(context=context_texts, question=query)

        # Use the Ollama model if running locally
        try:
            model = Ollama(model="llama2")
            response_text = model.invoke(prompt)
        except Exception as ollama_error:
            print(f"Ollama error: {ollama_error}")
            # Fallback to HuggingFace client
            response_text = fallback_to_hf_client(prompt)

        sources = [doc.metadata.get("id", "Unknown") for doc, _score in results]
        # Clean up source names for better display
        clean_sources = []
        for source in sources:
            if source and source != "Unknown":
                # Extract filename from the source metadata
                # Format is typically: "path/to/file:page:chunk"
                try:
                    file_part = source.split(":")[0]  # Get the file path part
                    filename = os.path.basename(file_part)  # Extract just the filename
                    if filename:
                        clean_sources.append(filename)
                except (IndexError, AttributeError, ValueError):
                    clean_sources.append(source)  # Fallback to original if parsing fails
        
        # Format the final response with sources
        if clean_sources:
            unique_sources = list(set(clean_sources))  # Remove duplicates
            formatted_response = f"{response_text}\n\n**๐Ÿ“š Sources:**\n{chr(10).join([f'โ€ข {source}' for source in unique_sources])}"
        else:
            formatted_response = f"{response_text}\n\n*Note: Sources information not available*"
        
        print(f"Formatted response: {formatted_response}")  
        return formatted_response
        
    except Exception as e:
        print(f"Error in query_rag: {e}")
        return f"I encountered an error while processing your query: {str(e)}"

def fallback_to_hf_client(prompt: str):
    """Generate a completion for *prompt* via the hosted HuggingFace client.

    Used when the local Ollama model is unavailable. Streams the completion
    and returns the assembled text, or an error message string on failure.
    """
    try:
        stream = client.chat_completion(
            [{"role": "user", "content": prompt}],
            max_tokens=512,
            stream=True,
            temperature=0.7,
            top_p=0.95,
        )
        pieces = []
        for event in stream:
            delta = event.choices[0].delta.content
            if delta:
                pieces.append(delta)
        return "".join(pieces)
    except Exception as e:
        return f"Error generating response: {str(e)}"

def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """Chat handler for gr.ChatInterface: try RAG first, then plain chat.

    Yields progressively longer response strings (the streaming protocol
    ChatInterface expects from a generator callback).
    """
    # First attempt a document-grounded answer via the RAG pipeline.
    try:
        rag_response = query_rag(message)
        rag_unusable = (
            not rag_response
            or rag_response.startswith("I don't know")
            or rag_response.startswith("I encountered an error")
        )
        if not rag_unusable:
            yield rag_response
            return
    except Exception as e:
        print(f"RAG query failed: {e}")

    # Fall back to a streamed chat completion over the full conversation.
    messages = [{"role": "system", "content": system_message}]
    for user_turn, assistant_turn in history:
        if user_turn:
            messages.append({"role": "user", "content": user_turn})
        if assistant_turn:
            messages.append({"role": "assistant", "content": assistant_turn})
    messages.append({"role": "user", "content": message})

    partial = ""
    try:
        for chunk in client.chat_completion(
            messages,
            max_tokens=max_tokens,
            stream=True,
            temperature=temperature,
            top_p=top_p,
        ):
            token = chunk.choices[0].delta.content
            if token:
                partial += token
                yield partial
    except Exception as e:
        yield f"Error: {str(e)}"



# Chat UI. `respond` expects four extra arguments beyond (message, history);
# without additional_inputs, ChatInterface would invoke it with only two
# arguments and the 5-element example rows would not line up with any inputs.
# Defaults mirror the values used in the example rows below.
demo = gr.ChatInterface(
    respond,
    title="๐ŸŽ“ CS Query - RAG-Powered Academic Assistant",
    description="Ask questions about competence standards and get answers based on the uploaded academic documents.",
    chatbot=gr.Chatbot(height=500),
    additional_inputs=[
        gr.Textbox(value=DEFAULT_SYSTEM_MESSAGE, label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
    ],
    examples=[
        [
            "What are reasonable adjustments for students with disabilities?",
            DEFAULT_SYSTEM_MESSAGE,
            512,
            0.7,
            0.95
        ],
        [
            "What does the Equality Act say about education?",
            DEFAULT_SYSTEM_MESSAGE,
            512,
            0.7,
            0.95
        ]
    ],
)


if __name__ == "__main__":
    # main()  # one-off: rebuild the Chroma DB from source docs (populate_db)
    # NOTE(review): height/width are not documented parameters of
    # Blocks.launch() in current Gradio releases (they applied to inline/
    # notebook embedding) — confirm against the installed gradio version.
    demo.launch(
        inbrowser=True,  # Open in browser automatically
        height=800,  # Increase overall height
        width="100%",  # Use full width
    )