pcschreiber1 commited on
Commit
5cacb85
·
1 Parent(s): 627ec3c

Temporarily moved graph to gradio file.

Browse files
Files changed (2) hide show
  1. app.py +56 -1
  2. conversation/main.py +79 -79
app.py CHANGED
@@ -1,15 +1,22 @@
1
  from typing import Any
2
 
3
  import gradio as gr
 
 
4
  from langchain_openai import OpenAIEmbeddings
5
  from langchain_qdrant import QdrantVectorStore
 
 
 
 
6
  import structlog
7
 
8
  from qdrant_client import QdrantClient
9
  from qdrant_client.http.models import Distance, VectorParams, SparseVectorParams
10
 
11
  import logging_config as _
12
- from conversation.main import graph
 
13
  from ingestion.main import ingest_document
14
 
15
  from config import app_settings
@@ -22,6 +29,12 @@ embeddings = OpenAIEmbeddings(
22
  api_key=app_settings.llm_api_key
23
  )
24
 
 
 
 
 
 
 
25
  client = QdrantClient(app_settings.vector_db_url)
26
  if not client.collection_exists(app_settings.vector_db_collection_name):
27
  client.create_collection(
@@ -29,13 +42,55 @@ if not client.collection_exists(app_settings.vector_db_collection_name):
29
  vectors_config=VectorParams(size=app_settings.embedding_size, distance=Distance.COSINE),
30
  sparse_vectors_config={'langchain-sparse': SparseVectorParams(index=None, modifier=None)}
31
  )
 
32
  vector_store = QdrantVectorStore(
33
  client=client,
34
  collection_name=app_settings.vector_db_collection_name,
35
  embedding=embeddings,
36
  )
37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
 
 
39
 
40
  with open("static/style.css", "r") as f:
41
  css = f.read()
 
1
  from typing import Any
2
 
3
  import gradio as gr
4
+ from langchain.chat_models import init_chat_model
5
+ from langchain_core.tools import tool
6
  from langchain_openai import OpenAIEmbeddings
7
  from langchain_qdrant import QdrantVectorStore
8
+ from langgraph.checkpoint.memory import MemorySaver
9
+ from langgraph.graph import MessagesState, StateGraph, END
10
+ from langgraph.prebuilt import ToolNode, tools_condition
11
+ from langgraph.prebuilt import ToolNode
12
  import structlog
13
 
14
  from qdrant_client import QdrantClient
15
  from qdrant_client.http.models import Distance, VectorParams, SparseVectorParams
16
 
17
  import logging_config as _
18
+ # from conversation.main import graph
19
+ from conversation.generate import generate
20
  from ingestion.main import ingest_document
21
 
22
  from config import app_settings
 
29
  api_key=app_settings.llm_api_key
30
  )
31
 
32
+ llm = init_chat_model(
33
+ app_settings.llm_model,
34
+ model_provider="openai",
35
+ api_key=app_settings.llm_api_key
36
+ )
37
+
38
  client = QdrantClient(app_settings.vector_db_url)
39
  if not client.collection_exists(app_settings.vector_db_collection_name):
40
  client.create_collection(
 
42
  vectors_config=VectorParams(size=app_settings.embedding_size, distance=Distance.COSINE),
43
  sparse_vectors_config={'langchain-sparse': SparseVectorParams(index=None, modifier=None)}
44
  )
45
+ # TODO: move to LLM files later
46
  vector_store = QdrantVectorStore(
47
  client=client,
48
  collection_name=app_settings.vector_db_collection_name,
49
  embedding=embeddings,
50
  )
51
 
52
+ # ------
53
+ # Move to `conversation/main`` later
54
+ @tool(response_format="content_and_artifact")
55
+ def retrieve(query: str):
56
+ """Retrieve information related to a query."""
57
+ retrieved_docs = vector_store.similarity_search(query, k=2)
58
+ serialized = "\n\n".join(
59
+ (f"Source: {doc.metadata}\n" f"Content: {doc.page_content}")
60
+ for doc in retrieved_docs
61
+ )
62
+ return serialized, retrieved_docs
63
+
64
+
65
+ def query_or_respond(state: MessagesState):
66
+ """Generate tool call for retrieval or respond."""
67
+ llm_with_tools = llm.bind_tools([retrieve])
68
+ response = llm_with_tools.invoke(state["messages"])
69
+ # MessagesState appends messages to state instead of overwriting
70
+ return {"messages": [response]}
71
+
72
+
73
+ graph_builder = StateGraph(MessagesState)
74
+ tools = ToolNode([retrieve])
75
+ memory = MemorySaver()
76
+
77
+ graph_builder.add_node(query_or_respond)
78
+ graph_builder.add_node(tools)
79
+ graph_builder.add_node(generate)
80
+
81
+ graph_builder.set_entry_point("query_or_respond")
82
+
83
+ graph_builder.add_conditional_edges(
84
+ "query_or_respond",
85
+ tools_condition,
86
+ {END: END, "tools": "tools"},
87
+ )
88
+
89
+ graph_builder.add_edge("tools", "generate")
90
+ graph_builder.add_edge("generate", END)
91
 
92
+ graph = graph_builder.compile(checkpointer=memory)
93
+ # -----
94
 
95
  with open("static/style.css", "r") as f:
96
  css = f.read()
conversation/main.py CHANGED
@@ -1,79 +1,79 @@
1
- from langchain.chat_models import init_chat_model
2
- from langchain_core.tools import tool
3
- from langchain_openai import OpenAIEmbeddings
4
- from langchain_qdrant import QdrantVectorStore
5
- from langgraph.checkpoint.memory import MemorySaver
6
- from langgraph.graph import MessagesState, StateGraph, END
7
- from langgraph.prebuilt import ToolNode, tools_condition
8
- from langgraph.prebuilt import ToolNode
9
- from qdrant_client import QdrantClient
10
- from qdrant_client.http.models import Distance, VectorParams, SparseVectorParams
11
-
12
- from config import app_settings
13
- from conversation.generate import generate
14
-
15
-
16
- llm = init_chat_model(
17
- app_settings.llm_model,
18
- model_provider="openai",
19
- api_key=app_settings.llm_api_key
20
- )
21
-
22
- embeddings = OpenAIEmbeddings(
23
- model=app_settings.embedding_model,
24
- api_key=app_settings.llm_api_key
25
- )
26
-
27
- client = QdrantClient(app_settings.vector_db_url)
28
- if not client.collection_exists(app_settings.vector_db_collection_name):
29
- client.create_collection(
30
- collection_name=app_settings.vector_db_collection_name,
31
- vectors_config=VectorParams(size=app_settings.embedding_size, distance=Distance.COSINE),
32
- sparse_vectors_config={'langchain-sparse': SparseVectorParams(index=None, modifier=None)}
33
- )
34
-
35
- vector_store = QdrantVectorStore(
36
- client=client,
37
- collection_name=app_settings.vector_db_collection_name,
38
- embedding=embeddings,
39
- )
40
-
41
- @tool(response_format="content_and_artifact")
42
- def retrieve(query: str):
43
- """Retrieve information related to a query."""
44
- retrieved_docs = vector_store.similarity_search(query, k=2)
45
- serialized = "\n\n".join(
46
- (f"Source: {doc.metadata}\n" f"Content: {doc.page_content}")
47
- for doc in retrieved_docs
48
- )
49
- return serialized, retrieved_docs
50
-
51
-
52
- def query_or_respond(state: MessagesState):
53
- """Generate tool call for retrieval or respond."""
54
- llm_with_tools = llm.bind_tools([retrieve])
55
- response = llm_with_tools.invoke(state["messages"])
56
- # MessagesState appends messages to state instead of overwriting
57
- return {"messages": [response]}
58
-
59
-
60
- graph_builder = StateGraph(MessagesState)
61
- tools = ToolNode([retrieve])
62
- memory = MemorySaver()
63
-
64
- graph_builder.add_node(query_or_respond)
65
- graph_builder.add_node(tools)
66
- graph_builder.add_node(generate)
67
-
68
- graph_builder.set_entry_point("query_or_respond")
69
-
70
- graph_builder.add_conditional_edges(
71
- "query_or_respond",
72
- tools_condition,
73
- {END: END, "tools": "tools"},
74
- )
75
-
76
- graph_builder.add_edge("tools", "generate")
77
- graph_builder.add_edge("generate", END)
78
-
79
- graph = graph_builder.compile(checkpointer=memory)
 
1
+ # from langchain.chat_models import init_chat_model
2
+ # from langchain_core.tools import tool
3
+ # from langchain_openai import OpenAIEmbeddings
4
+ # from langchain_qdrant import QdrantVectorStore
5
+ # from langgraph.checkpoint.memory import MemorySaver
6
+ # from langgraph.graph import MessagesState, StateGraph, END
7
+ # from langgraph.prebuilt import ToolNode, tools_condition
8
+ # from langgraph.prebuilt import ToolNode
9
+ # from qdrant_client import QdrantClient
10
+ # from qdrant_client.http.models import Distance, VectorParams, SparseVectorParams
11
+
12
+ # from config import app_settings
13
+ # from conversation.generate import generate
14
+
15
+
16
+ # llm = init_chat_model(
17
+ # app_settings.llm_model,
18
+ # model_provider="openai",
19
+ # api_key=app_settings.llm_api_key
20
+ # )
21
+
22
+ # embeddings = OpenAIEmbeddings(
23
+ # model=app_settings.embedding_model,
24
+ # api_key=app_settings.llm_api_key
25
+ # )
26
+
27
+ # client = QdrantClient(app_settings.vector_db_url)
28
+ # if not client.collection_exists(app_settings.vector_db_collection_name):
29
+ # client.create_collection(
30
+ # collection_name=app_settings.vector_db_collection_name,
31
+ # vectors_config=VectorParams(size=app_settings.embedding_size, distance=Distance.COSINE),
32
+ # sparse_vectors_config={'langchain-sparse': SparseVectorParams(index=None, modifier=None)}
33
+ # )
34
+
35
+ # vector_store = QdrantVectorStore(
36
+ # client=client,
37
+ # collection_name=app_settings.vector_db_collection_name,
38
+ # embedding=embeddings,
39
+ # )
40
+
41
+ # @tool(response_format="content_and_artifact")
42
+ # def retrieve(query: str):
43
+ # """Retrieve information related to a query."""
44
+ # retrieved_docs = vector_store.similarity_search(query, k=2)
45
+ # serialized = "\n\n".join(
46
+ # (f"Source: {doc.metadata}\n" f"Content: {doc.page_content}")
47
+ # for doc in retrieved_docs
48
+ # )
49
+ # return serialized, retrieved_docs
50
+
51
+
52
+ # def query_or_respond(state: MessagesState):
53
+ # """Generate tool call for retrieval or respond."""
54
+ # llm_with_tools = llm.bind_tools([retrieve])
55
+ # response = llm_with_tools.invoke(state["messages"])
56
+ # # MessagesState appends messages to state instead of overwriting
57
+ # return {"messages": [response]}
58
+
59
+
60
+ # graph_builder = StateGraph(MessagesState)
61
+ # tools = ToolNode([retrieve])
62
+ # memory = MemorySaver()
63
+
64
+ # graph_builder.add_node(query_or_respond)
65
+ # graph_builder.add_node(tools)
66
+ # graph_builder.add_node(generate)
67
+
68
+ # graph_builder.set_entry_point("query_or_respond")
69
+
70
+ # graph_builder.add_conditional_edges(
71
+ # "query_or_respond",
72
+ # tools_condition,
73
+ # {END: END, "tools": "tools"},
74
+ # )
75
+
76
+ # graph_builder.add_edge("tools", "generate")
77
+ # graph_builder.add_edge("generate", END)
78
+
79
+ # graph = graph_builder.compile(checkpointer=memory)