juliaturc committed on
Commit
2c96b06
·
1 Parent(s): c010627

Streaming responses.

Browse files
Files changed (1) hide show
  1. sage/chat.py +35 -6
sage/chat.py CHANGED
@@ -4,6 +4,7 @@ You must run `sage-index $GITHUB_REPO` first in order to index the codebase into
4
  """
5
 
6
  import argparse
 
7
  import os
8
 
9
  import gradio as gr
@@ -53,7 +54,8 @@ def build_rag_chain(args):
53
  ("human", "{input}"),
54
  ]
55
  )
56
- history_aware_retriever = create_history_aware_retriever(llm, retriever, contextualize_q_prompt)
 
57
 
58
  qa_system_prompt = (
59
  f"You are my coding buddy, helping me quickly understand a GitHub repository called {args.repo_id}."
@@ -136,21 +138,48 @@ def main():
136
 
137
  rag_chain = build_rag_chain(args)
138
 
139
- def _predict(message, history):
 
 
 
 
140
  """Performs one RAG operation."""
141
  history_langchain_format = []
142
  for human, ai in history:
143
  history_langchain_format.append(HumanMessage(content=human))
144
  history_langchain_format.append(AIMessage(content=ai))
145
  history_langchain_format.append(HumanMessage(content=message))
146
- response = rag_chain.invoke({"input": message, "chat_history": history_langchain_format})
147
- answer = append_sources_to_response(response)
148
- return answer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
 
150
  gr.ChatInterface(
151
  _predict,
152
  title=args.repo_id,
153
- description=f"Code sage for your repo: {args.repo_id}",
154
  examples=["What does this repo do?", "Give me some sample code."],
155
  ).launch(share=args.share)
156
 
 
4
  """
5
 
6
  import argparse
7
+ import logging
8
  import os
9
 
10
  import gradio as gr
 
54
  ("human", "{input}"),
55
  ]
56
  )
57
+ contextualize_q_llm = llm.with_config(tags=["contextualize_q_llm"])
58
+ history_aware_retriever = create_history_aware_retriever(contextualize_q_llm, retriever, contextualize_q_prompt)
59
 
60
  qa_system_prompt = (
61
  f"You are my coding buddy, helping me quickly understand a GitHub repository called {args.repo_id}."
 
138
 
139
  rag_chain = build_rag_chain(args)
140
 
141
+ def source_md(file_path: str, url: str) -> str:
142
+ """Formats a context source in Markdown."""
143
+ return f"[{file_path}]({url})"
144
+
145
+ async def _predict(message, history):
146
  """Performs one RAG operation."""
147
  history_langchain_format = []
148
  for human, ai in history:
149
  history_langchain_format.append(HumanMessage(content=human))
150
  history_langchain_format.append(AIMessage(content=ai))
151
  history_langchain_format.append(HumanMessage(content=message))
152
+
153
+ query_rewrite = ""
154
+ response = ""
155
+ async for event in rag_chain.astream_events(
156
+ {
157
+ "input": message,
158
+ "chat_history": history_langchain_format,
159
+ },
160
+ version="v1",
161
+ ):
162
+ if event["name"] == "retrieve_documents" and "output" in event["data"]:
163
+ sources = [(doc.metadata["file_path"], doc.metadata["url"]) for doc in event["data"]["output"]]
164
+ # Deduplicate while preserving the order.
165
+ sources = list(dict.fromkeys(sources))
166
+ response += "## Sources:\n" + "\n".join([source_md(s[0], s[1]) for s in sources]) + "\n## Response:\n"
167
+
168
+ elif event["event"] == "on_chat_model_stream":
169
+ chunk = event["data"]["chunk"].content
170
+
171
+ if "contextualize_q_llm" in event["tags"]:
172
+ query_rewrite += chunk
173
+ else:
174
+ # This is the actual response to the user query.
175
+ if not response:
176
+ logging.info(f"Query rewrite: {query_rewrite}")
177
+ response += chunk
178
+ yield response
179
 
180
  gr.ChatInterface(
181
  _predict,
182
  title=args.repo_id,
 
183
  examples=["What does this repo do?", "Give me some sample code."],
184
  ).launch(share=args.share)
185