# CompifAI / app.py
# Last commit by daniel-was-taken (84379c8): Refactor Dockerfile and compose.yml
# to uncomment entrypoint and CMD, and update environment variables for app service
import os
from typing import Dict, List, Optional
from operator import itemgetter
from dotenv import load_dotenv
import chainlit as cl
from chainlit.types import ThreadDict
from chainlit.data.sql_alchemy import SQLAlchemyDataLayer
from pydantic import SecretStr
# Load environment variables from .env file
load_dotenv()
from langchain_classic.chains import create_retrieval_chain
from langchain_classic.chains.combine_documents import create_stuff_documents_chain
from langchain_nebius import ChatNebius
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from langchain_core.runnables.config import RunnableConfig
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
from pymilvus import MilvusClient
from sentence_transformers import SentenceTransformer
from langchain_nebius import NebiusEmbeddings
from chainlit.input_widget import Select, Switch, Slider
from langchain_core.documents import Document
from typing_extensions import List
# from populate_db import main
# Initialize Milvus client
MILVUS_URI = os.getenv("MILVUS_URI", "http://localhost:19530")
milvus_client = MilvusClient(uri=MILVUS_URI, token=os.getenv("MILVUS_API_KEY"))
collection_name = "my_rag_collection"


def _populate_collection() -> None:
    """Create and fill the Milvus collection by running ``populate_db.main()``.

    The ingestion module is imported lazily so it is only loaded when the
    collection is actually missing or empty.

    Raises:
        RuntimeError: if ``populate_db`` cannot be imported, with a message that
            names the collection instead of an opaque ``NameError``.
    """
    try:
        from populate_db import main as populate_main
    except ImportError as err:
        raise RuntimeError(
            f"Milvus collection {collection_name!r} is missing or empty and "
            "populate_db could not be imported; run the ingestion step first."
        ) from err
    populate_main()


# Initialize collection once at startup
if not milvus_client.has_collection(collection_name):
    _populate_collection()
else:
    # Check if collection has data, populate if empty
    stats = milvus_client.get_collection_stats(collection_name)
    # int() guards against the row count being reported as a string
    if int(stats.get("row_count", 0)) == 0:
        _populate_collection()
milvus_client.load_collection(collection_name=collection_name)
# Both Nebius clients authenticate with the key stored in OPENAI_API_KEY.
def _nebius_api_key() -> SecretStr:
    """Wrap the OPENAI_API_KEY environment variable in a new SecretStr."""
    return SecretStr(os.getenv("OPENAI_API_KEY"))


# Embedding model used (via emb_text) to vectorise queries for Milvus search
embedding_model = NebiusEmbeddings(
    model="Qwen/Qwen3-Embedding-8B",
    api_key=_nebius_api_key(),
)

# Chat model that generates the answers (token streaming enabled)
model = ChatNebius(
    api_key=_nebius_api_key(),
    model="Qwen/Qwen3-235B-A22B-Instruct-2507",
    temperature=0.2,
    top_p=0.95,
    max_tokens=8192,
    streaming=True,
)
# Define application steps
def emb_text(text: str) -> List[float]:
    """Embed *text* as a single query vector using the Nebius embedding model."""
    vector = embedding_model.embed_query(text)
    return vector
def retrieve_relevant_documents(query: str, limit: int = 5) -> List[Dict]:
    """Search Milvus for the passages closest to *query*.

    Args:
        query: Free-text user question; it is embedded with ``emb_text``.
        limit: Maximum number of hits to return.

    Returns:
        One dict per hit with ``text``, ``metadata`` and ``score`` (the Milvus
        ``distance``) keys. Any failure is printed and an empty list is
        returned, so callers always get a list.
    """
    try:
        vector = emb_text(query)
        hits = milvus_client.search(
            collection_name=collection_name,
            data=[vector],
            limit=limit,
            output_fields=["text", "metadata"],
        )
        # A single query vector was sent, so only the first result list matters
        return [
            {
                "text": hit["entity"]["text"],
                "metadata": hit["entity"]["metadata"],
                "score": hit["distance"],
            }
            for hit in hits[0]
        ]
    except Exception as exc:
        print(f"Error retrieving documents: {exc}")
        return []
def format_docs_with_id(docs: List[Dict]) -> str:
    """Render retrieved documents as a numbered, human-readable citation list.

    Each entry is ``[n] Source: <filename>, Page: <page>, Score: <score>``
    followed by ``Content: <text>`` on the next line; entries are separated
    by a blank line.

    Args:
        docs: Dicts shaped like ``retrieve_relevant_documents`` output
            (``text``, ``metadata``, ``score``). Missing keys, or metadata
            that is ``None``, fall back to placeholder values.

    Returns:
        The joined citation text, or ``""`` when *docs* is empty.
    """
    formatted = []
    for position, doc in enumerate(docs, start=1):
        # metadata may be present but None (e.g. a null field); treat as missing
        metadata = doc.get("metadata") or {}
        title = metadata.get("filename", "Unknown Document")  # filename stands in for a title
        page_number = metadata.get("page_number", "Unknown")
        score = doc.get("score", "N/A")
        text_content = doc.get("text", "")
        formatted.append(
            f"[{position}] Source: {title}, Page: {page_number}, Score: {score}\n"
            f"Content: {text_content}"
        )
    print(f"Formatted documents: {formatted}")
    return "\n\n".join(formatted)
# System prompt for the answering model; {context} is filled by
# create_stuff_documents_chain with the retrieved documents.
_RAG_SYSTEM_PROMPT = """You are an expert assistant for staff in UK higher education institutions. Help develop inclusive, non-discriminatory competence standards that comply with UK equality legislation (for example: the Equality Act 2010). Advise on reasonable adjustments and support to remove barriers and promote fairness for all students.
Rules:
1. Use ONLY the provided context documents as your source of information: {context}
2. If the context does not contain relevant information, respond exactly:
"I could not find relevant information about this topic in the provided documents."
3. Do not guess or include information from outside the provided documents.
4. Answer in clear, plain English. Define technical or legal terms when needed.
5. Provide practical, actionable guidance and examples for writing competence standards.
6. Emphasise removing barriers via reasonable adjustments and support; treat disability within the broader goal of equality and inclusivity.
7. Do not assume the user's prior knowledge; maintain a neutral, professional, and respectful tone.
Format requirements:
- Structure all responses using the RESPONSE TEMPLATE provided below.
- Use bolding for headers (e.g. **Summary**)
- Ensure there is a blank line before and after lists.
Response template:
**Summary**
[Insert a concise 1-3 sentence answer here]
**Key Guidance**
* [Actionable point 1]
* [Actionable point 2]
* [Actionable point 3]
"""


def setup_rag_chain():
    """Setup the RAG chain with context retrieval.

    The returned runnable takes ``{"question": str}``. At invocation time it
    retrieves the top 5 documents from Milvus and reads the conversation
    history from the Chainlit session. It then answers with the LLM, and the
    chain's output is the answer string.

    The session's ``settings["Think"]`` flag (default True) is read once, when
    the chain is built. When it is false, the system prompt is prefixed with
    ``/no_think``.
    """

    def _build_inputs(inputs):
        """Map ``{"question": ...}`` to the question, context and history variables."""
        query = inputs["question"]
        relevant_docs = retrieve_relevant_documents(query, limit=5)
        print("Relevant documents:", relevant_docs[0] if relevant_docs else "No documents found")
        # create_stuff_documents_chain expects LangChain Documents under "context";
        # a None text/metadata would fail Document validation, so default them.
        context = [
            Document(
                page_content=doc.get("text") or "",
                metadata=doc.get("metadata") or {},
            )
            for doc in relevant_docs
        ]
        return {
            "question": query,
            "context": context,
            "history": cl.user_session.get("messages", []),
        }

    # Get the current settings to check if Think mode is enabled
    settings = cl.user_session.get("settings", {})
    use_think = settings.get("Think", True)  # Default to True as per the initial setting
    system_prompt = _RAG_SYSTEM_PROMPT if use_think else "/no_think " + _RAG_SYSTEM_PROMPT

    prompt = ChatPromptTemplate.from_messages([
        ("system", system_prompt),
        MessagesPlaceholder(variable_name="history"),
        ("human", "{question}"),
    ])
    question_answer_chain = create_stuff_documents_chain(model, prompt)
    return RunnableLambda(_build_inputs) | question_answer_chain
# ============== Application Setup ==============
# Authentication
@cl.password_auth_callback
def auth(username: str, password: str) -> Optional[cl.User]:
    """Authenticate the single ``admin`` account against the PASSWORD env var.

    Returns:
        A ``cl.User`` for ``admin`` when both credentials match, else ``None``.
        Login is always refused when PASSWORD is unset or empty, so a blank
        password can never grant admin access.
    """
    import hmac

    expected_password = os.getenv("PASSWORD")
    if not expected_password:
        return None
    # Constant-time comparisons avoid leaking credential contents via timing;
    # encode to bytes so non-ASCII input cannot raise TypeError.
    username_ok = hmac.compare_digest(username.encode("utf-8"), b"admin")
    password_ok = hmac.compare_digest(
        password.encode("utf-8"), expected_password.encode("utf-8")
    )
    if username_ok and password_ok:
        return cl.User(
            identifier="admin",
            metadata={"role": "admin", "provider": "credentials"}
        )
    return None
@cl.oauth_callback
def oauth_callback(
    provider_id: str,
    token: str,
    raw_user_data: Dict[str, str],
    default_user: cl.User,
) -> Optional[cl.PersistedUser]:
    """Accept the user Chainlit derived from the OAuth login, unchanged.

    NOTE(review): no provider, domain, or allow-list check happens here, so any
    account that completes a configured OAuth flow is admitted. Confirm this
    is intended.
    """
    authenticated_user = default_user
    return authenticated_user
# Starters
@cl.set_starters
async def set_starters():
    """Return the suggested opening prompts offered to users in a new chat."""
    starter_prompts = (
        (
            "Reviewing Existing Standards",
            "How can we review existing competence standards to ensure they are inclusive?",
        ),
        (
            "When No Adjustments are Possible",
            "What should we do if a competence standard cannot be adjusted for a student?",
        ),
    )
    return [cl.Starter(label=label, message=text) for label, text in starter_prompts]
# Chat lifecycle
@cl.on_chat_start
async def on_chat_start():
    """Initialize chat session with RAG chain.

    Shows the "Use Deep Thinking" switch, stores its initial value in the
    session (``setup_rag_chain`` reads it from there), builds the chain, and
    starts an empty message history.
    """
    initial_settings = await cl.ChatSettings(
        [
            Switch(id="Think", label="Use Deep Thinking", initial=True),
        ]
    ).send()
    # ChatSettings.send() returns the widgets' initial values keyed by id;
    # fall back to the same default if nothing is returned.
    cl.user_session.set("settings", initial_settings or {"Think": True})
    cl.user_session.set("chain", setup_rag_chain())
    cl.user_session.set("messages", [])
@cl.on_settings_update
async def setup_agent(settings):
    """Persist updated chat settings and rebuild the chain so they take effect.

    ``setup_rag_chain`` reads the "Think" flag from the session, so the new
    settings are stored before the chain is rebuilt.
    """
    cl.user_session.set("settings", settings)
    cl.user_session.set("chain", setup_rag_chain())
@cl.on_chat_resume
async def on_chat_resume(thread: ThreadDict):
    """Resume chat with conversation history.

    Root-level steps (no ``parentId``) are replayed into LangChain messages:
    ``user_message`` steps become ``HumanMessage`` and every other root step
    becomes ``AIMessage``. The settings widget and the RAG chain are then
    recreated for the resumed session.
    """
    messages = []
    for step in thread["steps"]:
        # Only root steps are treated as chat turns; nested steps are skipped.
        if step.get("parentId") is not None:
            continue
        # A step with no output would make the message constructors fail.
        content = step.get("output") or ""
        if step.get("type") == "user_message":
            messages.append(HumanMessage(content=content))
        else:
            messages.append(AIMessage(content=content))
    cl.user_session.set("messages", messages)

    restored_settings = await cl.ChatSettings(
        [
            Switch(id="Think", label="Use Deep Thinking", initial=True),
        ]
    ).send()
    # NOTE(review): the thread's earlier "Think" choice is not restored; resumed
    # chats start from the widget default.
    cl.user_session.set("settings", restored_settings or {"Think": True})
    # Rebuild the chain so it picks up the settings stored above
    cl.user_session.set("chain", setup_rag_chain())
@cl.on_message
async def on_message(message: cl.Message):
    """Handle incoming messages with RAG and conversation history.

    Steps:
    1. Stream the chain's answer token by token into a new message.
    2. Attach the retrieved sources as a "References" step.
    3. Append the question/answer pair to the session history used by the
       chain's ``history`` placeholder.

    Any error is reported to the user as a chat message.
    """
    chain = cl.user_session.get("chain")
    messages = cl.user_session.get("messages", [])
    # The handler only traces chain/LLM steps in the UI; the answer itself is
    # streamed explicitly below.
    # NOTE(review): stream_final_answer=True only starts streaming after Chainlit's
    # answer-prefix tokens (default "Final Answer:") appear, which this prompt
    # never produces. Confirm for the installed Chainlit version.
    cb = cl.AsyncLangchainCallbackHandler()
    try:
        # Retrieved here for the References step.
        # NOTE(review): the chain repeats this retrieval internally.
        relevant_docs = retrieve_relevant_documents(message.content, limit=5)
        citations = format_docs_with_id(relevant_docs)

        # The chain ends in a string output parser, so each chunk is a str
        response = cl.Message(content="")
        async for chunk in chain.astream(
            {"question": message.content},
            config=RunnableConfig(callbacks=[cb]),
        ):
            await response.stream_token(chunk)
        await response.send()
        answer = response.content

        async with cl.Step(name="References") as step:
            if relevant_docs:
                step.output = citations
            else:
                step.output = "No relevant documents found for this query."

        messages.append(HumanMessage(content=message.content))
        messages.append(AIMessage(content=answer))
        cl.user_session.set("messages", messages)
    except Exception as e:
        await cl.Message(content=f"Sorry, I encountered an error: {str(e)}").send()