pcschreiber1 committed
Commit 627ec3c · 1 Parent(s): 2eeb3a4

Basic set-up.

.gitignore CHANGED
@@ -9,6 +9,12 @@ wheels/
 # Virtual environments
 .venv
 
+# secrets
+*.env
+
 # application files
 .gradio
-*.log
+*.log
+
+# sandboxes
+*.ipynb
README.md CHANGED
@@ -50,5 +50,9 @@ Alternatively, with a different dependency manager such as `venv` install direct
 
 
 
+## To-Do
+- create basic set-up with Qdrant in memory, ingestion pipeline
+- create basic set-up where retrieval does not yet involve an LLM call
+- connect to frontend with "invoke" and deletion of memory
 
 
app.py CHANGED
@@ -1,13 +1,42 @@
 from typing import Any
 
 import gradio as gr
+from langchain_openai import OpenAIEmbeddings
+from langchain_qdrant import QdrantVectorStore
 import structlog
 
+from qdrant_client import QdrantClient
+from qdrant_client.http.models import Distance, VectorParams, SparseVectorParams
+
 import logging_config as _
+from conversation.main import graph
+from ingestion.main import ingest_document
+
+from config import app_settings
 
 # Create a logger instance
 logger = structlog.get_logger(__name__)
 
+embeddings = OpenAIEmbeddings(
+    model=app_settings.embedding_model,
+    api_key=app_settings.llm_api_key
+)
+
+client = QdrantClient(app_settings.vector_db_url)
+if not client.collection_exists(app_settings.vector_db_collection_name):
+    client.create_collection(
+        collection_name=app_settings.vector_db_collection_name,
+        vectors_config=VectorParams(size=app_settings.embedding_size, distance=Distance.COSINE),
+        sparse_vectors_config={'langchain-sparse': SparseVectorParams(index=None, modifier=None)}
+    )
+vector_store = QdrantVectorStore(
+    client=client,
+    collection_name=app_settings.vector_db_collection_name,
+    embedding=embeddings,
+)
+
+
+
 with open("static/style.css", "r") as f:
     css = f.read()
     logger.info("Successfully loaded styles.")
@@ -18,23 +47,29 @@ def bot(message, history) -> list[Any]:
 
     With multi-modal inputs text and each file is treated as separate message.
     """
+
     logger.info("This is the history", history=history)
 
     # enable message edit
     if isinstance(message, str):
         message = {"text": message}
+
+    # process files
+    for file in message.get("files"):
+        logger.info("Received file", file=file)
+        ingest_document(file, vector_store)
 
     # create text response
-    response = []
-    response.append("You wrote: '" + message.get("text") + "' and uploaded:")
+    # TODO: see how state can be set in chat interface
+    config = {"configurable": {"thread_id": "abc123"}}
+    response = graph.invoke(
+        {"messages": [{"role": "user", "content": message.get("text")}]},
+        config=config,
+    )
 
-    # display files (exemplary)
-    if message.get("files"):
-        for file in message.get("files"):
-            response.append(gr.File(value=file))
+    logger.info("Generated a response", response=response)
 
-    logger.info(response=response)
-    return response
+    return [response["messages"][-1].content]
 
 
 demo = gr.ChatInterface(
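
Note: the reworked `bot` handler first ingests any uploaded files into the Qdrant store and then answers through the LangGraph `graph`, using a hard-coded `thread_id` (per-session state is still an open TODO). A minimal sketch of the same invocation outside Gradio, assuming the modules from this commit are importable and a valid OpenAI key is configured; the messages are placeholders:

from conversation.main import graph

# Illustrative only: reusing one thread_id lets the MemorySaver checkpointer
# carry the chat history across calls.
config = {"configurable": {"thread_id": "demo-thread"}}
graph.invoke({"messages": [{"role": "user", "content": "Hi, my name is Ada."}]}, config=config)
followup = graph.invoke({"messages": [{"role": "user", "content": "What is my name?"}]}, config=config)
print(followup["messages"][-1].content)
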
config.py CHANGED
@@ -31,6 +31,7 @@ class AppSettings(Settings):
     llm_model: str
     embedding_url: str
     embedding_model: str
+    embedding_size: int
     vector_db_url: str
     vector_db_collection_name: str
 
@@ -42,4 +43,6 @@ class AppSettings(Settings):
         env_file=".env",
         env_file_encoding="utf-8",
         extra="ignore",
-    )
+    )
+
+app_settings = AppSettings()
conversation/__init__.py ADDED
File without changes
conversation/generate.py ADDED
@@ -0,0 +1,86 @@
+import structlog
+from langchain.chat_models import init_chat_model
+from langchain_core.messages import SystemMessage, AIMessage, HumanMessage
+from langchain_core.output_parsers import StrOutputParser
+from langchain_core.prompts import (
+    ChatPromptTemplate,
+    MessagesPlaceholder,
+)
+from langchain_core.runnables import RunnableParallel
+from langgraph.graph import MessagesState
+from pydantic import BaseModel
+
+from config import app_settings
+
+logger = structlog.get_logger(__name__)
+
+llm = init_chat_model(
+    app_settings.llm_model,
+    model_provider="openai",
+    api_key=app_settings.llm_api_key
+)
+
+
+# RAG answer synthesis prompt
+system_template = """
+Answer the user's questions based on the below context.
+If the context doesn't contain any relevant information to the question, don't make something up and just say "I don't know":
+
+<context>
+{context}
+</context>
+"""
+
+ANSWER_PROMPT = ChatPromptTemplate.from_messages(
+    [
+        ("system", system_template),
+        MessagesPlaceholder(variable_name="chat_history"),
+        ("user", "{question}"),
+    ]
+)
+
+
+# User input
+class ChatHistory(BaseModel):
+    chat_history: list[AIMessage | HumanMessage]
+    question: str
+
+
+_inputs = RunnableParallel(
+    {
+        "question": lambda x: x["question"],
+        # "chat_history": lambda x: _format_chat_history(x["chat_history"]),
+        "chat_history": lambda x: x["chat_history"],
+        "context": lambda x: x["context"]
+    }
+).with_types(input_type=ChatHistory)
+
+chain = _inputs | ANSWER_PROMPT | llm | StrOutputParser()
+
+
+def generate(state: MessagesState):
+    """Generate answer."""
+    # Get generated ToolMessages
+    recent_tool_messages = []
+    for message in reversed(state["messages"]):
+        if message.type == "tool":
+            recent_tool_messages.append(message)
+        else:
+            break
+    tool_messages = recent_tool_messages[::-1]
+    # Format into prompt
+    docs_content = "\n\n".join(doc.content for doc in tool_messages)
+    logger.info("Tool messages", context=docs_content)
+
+    conversation_messages = [
+        message
+        for message in state["messages"]
+        if message.type in ("human", "system")
+        or (message.type == "ai" and not message.tool_calls)
+    ]
+    response = chain.invoke({
+        "question": conversation_messages[-1].content,
+        "chat_history": conversation_messages,
+        "context": docs_content,
+    })
+    return {"messages": [response]}
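
For reference, the `chain` defined above can be exercised on its own; a hedged sketch with made-up question, history, and context values:

from langchain_core.messages import AIMessage, HumanMessage
from conversation.generate import chain

# Illustrative inputs matching the keys expected by _inputs.
answer = chain.invoke({
    "question": "What does the report conclude?",
    "chat_history": [HumanMessage("Hello"), AIMessage("Hi, how can I help?")],
    "context": "The report concludes that revenue grew 12% year over year.",
})
print(answer)  # StrOutputParser yields a plain string
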
conversation/main.py ADDED
@@ -0,0 +1,79 @@
+from langchain.chat_models import init_chat_model
+from langchain_core.tools import tool
+from langchain_openai import OpenAIEmbeddings
+from langchain_qdrant import QdrantVectorStore
+from langgraph.checkpoint.memory import MemorySaver
+from langgraph.graph import MessagesState, StateGraph, END
+from langgraph.prebuilt import ToolNode, tools_condition
+from langgraph.prebuilt import ToolNode
+from qdrant_client import QdrantClient
+from qdrant_client.http.models import Distance, VectorParams, SparseVectorParams
+
+from config import app_settings
+from conversation.generate import generate
+
+
+llm = init_chat_model(
+    app_settings.llm_model,
+    model_provider="openai",
+    api_key=app_settings.llm_api_key
+)
+
+embeddings = OpenAIEmbeddings(
+    model=app_settings.embedding_model,
+    api_key=app_settings.llm_api_key
+)
+
+client = QdrantClient(app_settings.vector_db_url)
+if not client.collection_exists(app_settings.vector_db_collection_name):
+    client.create_collection(
+        collection_name=app_settings.vector_db_collection_name,
+        vectors_config=VectorParams(size=app_settings.embedding_size, distance=Distance.COSINE),
+        sparse_vectors_config={'langchain-sparse': SparseVectorParams(index=None, modifier=None)}
+    )
+
+vector_store = QdrantVectorStore(
+    client=client,
+    collection_name=app_settings.vector_db_collection_name,
+    embedding=embeddings,
+)
+
+@tool(response_format="content_and_artifact")
+def retrieve(query: str):
+    """Retrieve information related to a query."""
+    retrieved_docs = vector_store.similarity_search(query, k=2)
+    serialized = "\n\n".join(
+        (f"Source: {doc.metadata}\n" f"Content: {doc.page_content}")
+        for doc in retrieved_docs
+    )
+    return serialized, retrieved_docs
+
+
+def query_or_respond(state: MessagesState):
+    """Generate tool call for retrieval or respond."""
+    llm_with_tools = llm.bind_tools([retrieve])
+    response = llm_with_tools.invoke(state["messages"])
+    # MessagesState appends messages to state instead of overwriting
+    return {"messages": [response]}
+
+
+graph_builder = StateGraph(MessagesState)
+tools = ToolNode([retrieve])
+memory = MemorySaver()
+
+graph_builder.add_node(query_or_respond)
+graph_builder.add_node(tools)
+graph_builder.add_node(generate)
+
+graph_builder.set_entry_point("query_or_respond")
+
+graph_builder.add_conditional_edges(
+    "query_or_respond",
+    tools_condition,
+    {END: END, "tools": "tools"},
+)
+
+graph_builder.add_edge("tools", "generate")
+graph_builder.add_edge("generate", END)
+
+graph = graph_builder.compile(checkpointer=memory)
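
The `retrieve` tool is a thin wrapper around `vector_store.similarity_search`; a small sketch of calling it directly, assuming documents have already been added to this module's store:

from conversation.main import retrieve

# Invoked outside of a tool call, a content_and_artifact tool returns only the
# serialized content string; the Document artifacts are dropped.
print(retrieve.invoke({"query": "What is the uploaded document about?"}))

Note that this module builds its own QdrantClient rather than reusing the one in app.py; with vector_db_url set to ":memory:", each client is presumably a separate in-process instance, so the two stores would not share ingested documents.
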
ingestion/__init__.py ADDED
File without changes
ingestion/main.py ADDED
@@ -0,0 +1,41 @@
+import structlog
+
+from langchain_community.document_loaders import PDFPlumberLoader
+from langchain_openai import OpenAIEmbeddings
+from langchain_qdrant import QdrantVectorStore
+from langchain_text_splitters import RecursiveCharacterTextSplitter
+from qdrant_client import QdrantClient
+from qdrant_client.http.models import Distance, VectorParams, SparseVectorParams
+
+from config import app_settings
+
+logger = structlog.get_logger(__name__)
+
+# embeddings = OpenAIEmbeddings(
+#     model=app_settings.embedding_model,
+#     api_key=app_settings.llm_api_key
+# )
+
+# client = QdrantClient(app_settings.vector_db_url)
+# if not client.collection_exists(app_settings.vector_db_collection_name):
+#     client.create_collection(
+#         collection_name=app_settings.vector_db_collection_name,
+#         vectors_config=VectorParams(size=app_settings.embedding_size, distance=Distance.COSINE),
+#         sparse_vectors_config={'langchain-sparse': SparseVectorParams(index=None, modifier=None)}
+#     )
+# vector_store = QdrantVectorStore(
+#     client=client,
+#     collection_name=app_settings.vector_db_collection_name,
+#     embedding=embeddings,
+# )
+
+text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
+
+def ingest_document(path, vector_store):
+    logger.info("Load document", path=path)
+    loader = PDFPlumberLoader(path)
+    docs = loader.load()
+    logger.info("Successfully loaded document", path=path)
+    all_splits = text_splitter.split_documents(docs)
+    _ = vector_store.add_documents(documents=all_splits)
+    logger.info("Successfully uploaded to vectorstore", path=path)
pyproject.toml CHANGED
@@ -6,16 +6,25 @@ readme = "README.md"
 requires-python = ">=3.13"
 dependencies = [
     "gradio>=5.33.0",
+    "ipykernel>=6.29.5",
+    "langchain-community>=0.3.24",
+    "langchain-openai>=0.3.21",
+    "langchain-qdrant>=0.2.0",
+    "langchain-text-splitters>=0.3.8",
+    "langgraph>=0.4.8",
+    "pdfplumber>=0.11.6",
    "pydantic>=2.11.5",
    "pydantic-settings>=2.9.1",
+    "qdrant-client>=1.14.2",
    "structlog>=25.4.0",
 ]
 
 
 [tool.app_config]
 # shared
-llm_model = "gpt-4o"
-embedding_model = "BAAI/bge-m3"
+llm_model = "gpt-4o-mini"
+embedding_model = "text-embedding-3-large"
 embedding_url = "http://tei:80"
+embedding_size = 3072
 vector_db_url = ":memory:"
-vector_db_collection_name = "bgem3_store_multilingual"
+vector_db_collection_name = "test_collection"
uv.lock CHANGED
The diff for this file is too large to render. See raw diff