Flutra committed on
Commit
66c3a38
·
1 Parent(s): 0760d7d

add streaming and memory management

Browse files
Files changed (1) hide show
  1. app.py +69 -45
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import gradio as gr
2
  import os
3
  from langchain.chains import ConversationalRetrievalChain
@@ -5,19 +6,26 @@ from langchain.memory import ConversationBufferMemory
5
  from langchain_openai import ChatOpenAI, OpenAIEmbeddings
6
  from langchain.prompts import PromptTemplate
7
  from langchain_community.vectorstores import Chroma
 
 
8
 
9
- def create_qa_chain():
 
 
 
 
 
 
 
10
  """
11
- Create the QA chain with the loaded vectorstore
12
  """
13
- # Initialize embeddings and load vectorstore
14
  embeddings = OpenAIEmbeddings()
15
  vectorstore = Chroma(
16
  persist_directory="./vectorstore",
17
  embedding_function=embeddings
18
  )
19
 
20
- # Set up retriever
21
  retriever = vectorstore.as_retriever(
22
  search_type="mmr",
23
  search_kwargs={
@@ -27,14 +35,12 @@ def create_qa_chain():
27
  }
28
  )
29
 
30
- # Set up memory
31
  memory = ConversationBufferMemory(
32
  memory_key="chat_history",
33
  return_messages=True,
34
  output_key='answer'
35
  )
36
 
37
- # Create prompt template
38
  qa_prompt = PromptTemplate.from_template("""You are an expert technical writer specializing in API documentation.
39
  When describing API endpoints, structure your response in this exact format:
40
 
@@ -61,11 +67,12 @@ Question: {question}
61
 
62
  Technical answer (following the exact structure above):""")
63
 
64
- # Create the chain
65
  qa_chain = ConversationalRetrievalChain.from_llm(
66
  llm=ChatOpenAI(
67
  temperature=0.1,
68
- model_name="gpt-4-turbo-preview"
 
 
69
  ),
70
  retriever=retriever,
71
  memory=memory,
@@ -76,50 +83,67 @@ Technical answer (following the exact structure above):""")
76
 
77
  return qa_chain
78
 
79
- def chat(message, history):
80
  """
81
- Process chat messages and return responses
82
  """
83
- # Get or create QA chain
84
- if not hasattr(chat, 'qa_chain'):
85
- chat.qa_chain = create_qa_chain()
86
-
87
- # Get response
88
- result = chat.qa_chain({"question": message})
89
 
90
- # Format sources
91
- sources = "\n\nSources:\n"
92
- seen_components = set()
93
- shown_sources = 0
94
 
95
- for doc in result["source_documents"]:
96
- component = doc.metadata.get('component', '')
97
- title = doc.metadata.get('title', '')
98
- combo = (component, title)
 
 
 
99
 
100
- if combo not in seen_components and shown_sources < 3:
101
- seen_components.add(combo)
102
- shown_sources += 1
103
- sources += f"\nSource {shown_sources}:\n"
104
- sources += f"Title: {title}\n"
105
- sources += f"Component: {component}\n"
106
- sources += f"Content: {doc.page_content[:300]}...\n"
 
 
 
 
 
 
 
 
 
107
 
108
- # Combine response with sources
109
- full_response = result["answer"] + sources
 
110
 
111
- return full_response
 
 
 
 
 
 
 
112
 
113
- demo = gr.ChatInterface(
114
- chat,
115
- title="Apple Music API Documentation Assistant",
116
- description="Ask questions about the Apple Music API documentation.",
117
- examples=[
118
- "How to search for songs on Apple Music API?",
119
- "What are the required parameters for searching songs?",
120
- "Show me an example request with all parameters"
121
- ]
122
- )
 
 
123
 
124
  if __name__ == "__main__":
125
- demo.launch()
 
 
1
2
  import gradio as gr
3
  import os
4
  from langchain.chains import ConversationalRetrievalChain
 
6
  from langchain_openai import ChatOpenAI, OpenAIEmbeddings
7
  from langchain.prompts import PromptTemplate
8
  from langchain_community.vectorstores import Chroma
9
+ from queue import Queue
10
+ from threading import Thread
11
 
12
class StreamHandler:
    """Callback handler that forwards each newly generated LLM token to a queue.

    The queue is drained by the Gradio generator so the response streams
    to the UI as it is produced.

    NOTE(review): LangChain expects callback handlers to subclass
    BaseCallbackHandler; newer releases may reject a bare object —
    confirm against the installed langchain version.
    """

    def __init__(self, queue):
        # Queue shared with the consumer that streams tokens to the UI.
        self.queue = queue

    def on_llm_new_token(self, token, **kwargs):
        """Enqueue one streamed token.

        Accepts **kwargs because LangChain passes extra context
        (run_id, chunk, ...) with every callback; the original
        two-argument signature would raise TypeError on the first token.
        """
        self.queue.put(token)
19
+ def create_qa_chain(streaming_handler=None):
20
  """
21
+ Create the QA chain with streaming capability
22
  """
 
23
  embeddings = OpenAIEmbeddings()
24
  vectorstore = Chroma(
25
  persist_directory="./vectorstore",
26
  embedding_function=embeddings
27
  )
28
 
 
29
  retriever = vectorstore.as_retriever(
30
  search_type="mmr",
31
  search_kwargs={
 
35
  }
36
  )
37
 
 
38
  memory = ConversationBufferMemory(
39
  memory_key="chat_history",
40
  return_messages=True,
41
  output_key='answer'
42
  )
43
 
 
44
  qa_prompt = PromptTemplate.from_template("""You are an expert technical writer specializing in API documentation.
45
  When describing API endpoints, structure your response in this exact format:
46
 
 
67
 
68
  Technical answer (following the exact structure above):""")
69
 
 
70
  qa_chain = ConversationalRetrievalChain.from_llm(
71
  llm=ChatOpenAI(
72
  temperature=0.1,
73
+ model_name="gpt-4-turbo-preview",
74
+ streaming=True,
75
+ callbacks=[streaming_handler] if streaming_handler else None
76
  ),
77
  retriever=retriever,
78
  memory=memory,
 
83
 
84
  return qa_chain
85
 
86
def predict(message, history):
    """Answer one chat message, streaming the response back to Gradio.

    A fresh QA chain (and therefore fresh conversation memory) is built per
    call; the LLM pushes tokens into a queue via StreamHandler while this
    generator drains the queue and yields the growing response text.

    Args:
        message: The user's question.
        history: Gradio chat history (unused here; memory lives in the chain).

    Yields:
        The accumulated response text after each new token, so Gradio
        re-renders the growing message.
    """
    token_queue = Queue()
    stream_handler = StreamHandler(token_queue)

    # Create a new QA chain for each conversation to ensure fresh memory.
    qa_chain = create_qa_chain(stream_handler)

    def get_response():
        """Run the chain in a worker thread, then append formatted sources."""
        try:
            result = qa_chain({"question": message})

            # Build a "Sources" footer from up to 3 unique
            # (component, title) documents returned by the retriever.
            sources = "\n\nSources:\n"
            seen_components = set()
            shown_sources = 0

            for doc in result["source_documents"]:
                component = doc.metadata.get('component', '')
                title = doc.metadata.get('title', '')
                combo = (component, title)

                if combo not in seen_components and shown_sources < 3:
                    seen_components.add(combo)
                    shown_sources += 1
                    sources += f"\nSource {shown_sources}:\n"
                    sources += f"Title: {title}\n"
                    sources += f"Component: {component}\n"
                    sources += f"Content: {doc.page_content[:300]}...\n"

            for char in sources:
                token_queue.put(char)
        except Exception as exc:
            # Surface the failure to the user instead of dying silently
            # in the worker thread.
            for char in f"\n\nError: {exc}":
                token_queue.put(char)
        finally:
            # Always signal end-of-response. Without this, an exception in
            # the chain would leave the consumer blocked on queue.get()
            # forever (the original bug: the sentinel was only enqueued on
            # the success path).
            token_queue.put(None)

    # Start processing in a separate thread so tokens can be consumed
    # while the chain is still generating.
    thread = Thread(target=get_response)
    thread.start()

    # Stream the response: drain the queue until the None sentinel.
    response = ""
    while True:
        token = token_queue.get()
        if token is None:
            break
        response += token
        yield response

    # Reap the worker so a thread is not leaked per request.
    thread.join()
134
# Create the Gradio interface. gr.ChatInterface is itself a complete
# top-level Blocks app; instantiating it inside a `with gr.Blocks()`
# context does not attach its components to the outer Blocks, which would
# launch effectively empty. Assign it directly as the app instead.
demo = gr.ChatInterface(
    predict,
    title="Apple Music API Documentation Assistant",
    description="Ask questions about the Apple Music API documentation.",
    examples=[
        "How to search for songs on Apple Music API?",
        "What are the required parameters for searching songs?",
        "Show me an example request with all parameters"
    ]
)

if __name__ == "__main__":
    # queue() is required so the generator-based `predict` can stream
    # partial responses to the client.
    demo.queue().launch()