Flutra committed on
Commit
5286c9a
·
1 Parent(s): c975a90

basic model

Browse files
Files changed (1) hide show
  1. app.py +49 -76
app.py CHANGED
@@ -5,26 +5,19 @@ from langchain.memory import ConversationBufferMemory
5
  from langchain_openai import ChatOpenAI, OpenAIEmbeddings
6
  from langchain.prompts import PromptTemplate
7
  from langchain_community.vectorstores import Chroma
8
- from queue import Queue
9
- from threading import Thread
10
 
11
- class StreamHandler:
12
- def __init__(self, queue):
13
- self.queue = queue
14
-
15
- def on_llm_new_token(self, token):
16
- self.queue.put(token)
17
-
18
- def create_qa_chain(streaming_handler=None):
19
  """
20
- Create the QA chain with streaming capability
21
  """
 
22
  embeddings = OpenAIEmbeddings()
23
  vectorstore = Chroma(
24
  persist_directory="./vectorstore",
25
  embedding_function=embeddings
26
  )
27
 
 
28
  retriever = vectorstore.as_retriever(
29
  search_type="mmr",
30
  search_kwargs={
@@ -34,12 +27,14 @@ def create_qa_chain(streaming_handler=None):
34
  }
35
  )
36
 
 
37
  memory = ConversationBufferMemory(
38
  memory_key="chat_history",
39
  return_messages=True,
40
  output_key='answer'
41
  )
42
 
 
43
  qa_prompt = PromptTemplate.from_template("""You are an expert technical writer specializing in API documentation.
44
  When describing API endpoints, structure your response in this exact format:
45
 
@@ -66,12 +61,11 @@ Question: {question}
66
 
67
  Technical answer (following the exact structure above):""")
68
 
 
69
  qa_chain = ConversationalRetrievalChain.from_llm(
70
  llm=ChatOpenAI(
71
  temperature=0.1,
72
- model_name="gpt-4-turbo-preview",
73
- streaming=True,
74
- callbacks=[streaming_handler] if streaming_handler else None
75
  ),
76
  retriever=retriever,
77
  memory=memory,
@@ -82,75 +76,54 @@ Technical answer (following the exact structure above):""")
82
 
83
  return qa_chain
84
 
85
- def predict(message, history):
86
  """
87
- Process each message with streaming and handle quit commands
88
  """
89
- # Check for quit command
90
- if message.lower() in ['quit', 'exit', 'bye']:
91
- # Clear the QA chain if it exists
92
- if hasattr(predict, 'qa_chain'):
93
- predict.qa_chain.memory.clear()
94
- delattr(predict, 'qa_chain')
95
- return "Chat history cleared. Goodbye!"
96
-
97
- token_queue = Queue()
98
- stream_handler = StreamHandler(token_queue)
99
 
100
- # Create or get QA chain
101
- if not hasattr(predict, 'qa_chain'):
102
- predict.qa_chain = create_qa_chain(stream_handler)
103
 
104
- # Function to process the message and add to queue
105
- def get_response():
106
- result = predict.qa_chain({"question": message})
107
- # Add sources to queue
108
- sources = "\n\nSources:\n"
109
- seen_components = set()
110
- shown_sources = 0
111
-
112
- for doc in result["source_documents"]:
113
- component = doc.metadata.get('component', '')
114
- title = doc.metadata.get('title', '')
115
- combo = (component, title)
116
-
117
- if combo not in seen_components and shown_sources < 3:
118
- seen_components.add(combo)
119
- shown_sources += 1
120
- sources += f"\nSource {shown_sources}:\n"
121
- sources += f"Title: {title}\n"
122
- sources += f"Component: {component}\n"
123
- sources += f"Content: {doc.page_content[:300]}...\n"
124
 
125
- for char in sources:
126
- token_queue.put(char)
127
- token_queue.put(None) # Signal end of response
 
 
 
 
128
 
129
- # Start processing in a separate thread
130
- thread = Thread(target=get_response)
131
- thread.start()
132
 
133
- # Stream the response
134
- response = ""
135
- while True:
136
- token = token_queue.get()
137
- if token is None:
138
- break
139
- response += token
140
- yield response
141
 
142
- # Create the Gradio interface
143
- with gr.Blocks() as demo:
144
- chatbot = gr.ChatInterface(
145
- predict,
146
- title="Apple Music API Documentation Assistant",
147
- description="Ask questions about the Apple Music API documentation. Type 'quit', 'exit', or 'bye' to clear chat history.",
148
- examples=[
149
- "How to search for songs on Apple Music API?",
150
- "What are the required parameters for searching songs?",
151
- "Show me an example request with all parameters"
152
- ]
153
- )
 
 
154
 
155
  if __name__ == "__main__":
156
- demo.queue().launch()
 
5
  from langchain_openai import ChatOpenAI, OpenAIEmbeddings
6
  from langchain.prompts import PromptTemplate
7
  from langchain_community.vectorstores import Chroma
 
 
8
 
9
+ def create_qa_chain():
 
 
 
 
 
 
 
10
  """
11
+ Create the QA chain with the loaded vectorstore
12
  """
13
+ # Initialize embeddings and load vectorstore
14
  embeddings = OpenAIEmbeddings()
15
  vectorstore = Chroma(
16
  persist_directory="./vectorstore",
17
  embedding_function=embeddings
18
  )
19
 
20
+ # Set up retriever
21
  retriever = vectorstore.as_retriever(
22
  search_type="mmr",
23
  search_kwargs={
 
27
  }
28
  )
29
 
30
+ # Set up memory
31
  memory = ConversationBufferMemory(
32
  memory_key="chat_history",
33
  return_messages=True,
34
  output_key='answer'
35
  )
36
 
37
+ # Create prompt template
38
  qa_prompt = PromptTemplate.from_template("""You are an expert technical writer specializing in API documentation.
39
  When describing API endpoints, structure your response in this exact format:
40
 
 
61
 
62
  Technical answer (following the exact structure above):""")
63
 
64
+ # Create the chain
65
  qa_chain = ConversationalRetrievalChain.from_llm(
66
  llm=ChatOpenAI(
67
  temperature=0.1,
68
+ model_name="gpt-4-turbo-preview"
 
 
69
  ),
70
  retriever=retriever,
71
  memory=memory,
 
76
 
77
  return qa_chain
78
 
79
def chat(message, history):
    """Answer one chat message with the cached ConversationalRetrievalChain.

    Args:
        message: The user's question.
        history: Gradio-supplied chat history (unused here — the chain
            keeps its own ConversationBufferMemory across calls).

    Returns:
        The chain's answer string, followed by a formatted list of up to
        three de-duplicated source documents when any were returned.
    """
    # Build the QA chain lazily and cache it on the function object so the
    # vectorstore, memory, and LLM client survive across Gradio calls.
    if not hasattr(chat, 'qa_chain'):
        chat.qa_chain = create_qa_chain()

    result = chat.qa_chain({"question": message})

    # Format up to 3 sources, de-duplicated by (component, title).
    # NOTE: previously a "Sources:" header was emitted even when there were
    # no source documents; now the whole section is omitted in that case,
    # and a missing "source_documents" key no longer raises KeyError.
    sources = ""
    seen_components = set()
    shown_sources = 0

    for doc in result.get("source_documents", []):
        component = doc.metadata.get('component', '')
        title = doc.metadata.get('title', '')
        combo = (component, title)

        if combo not in seen_components and shown_sources < 3:
            seen_components.add(combo)
            shown_sources += 1
            sources += f"\nSource {shown_sources}:\n"
            sources += f"Title: {title}\n"
            sources += f"Component: {component}\n"
            sources += f"Content: {doc.page_content[:300]}...\n"

    if sources:
        sources = "\n\nSources:\n" + sources

    # Combine the answer with the (possibly empty) sources section.
    return result["answer"] + sources
 
 
 
 
 
 
 
112
 
113
# Assemble the Gradio chat UI. The interface options are gathered in a
# private dict first so they are easy to scan and tweak in one place.
_ui_options = {
    "title": "Apple Music API Documentation Assistant",
    "description": "Ask questions about the Apple Music API documentation.",
    "examples": [
        "How to search for songs on Apple Music API?",
        "What are the required parameters for searching songs?",
        "Show me an example request with all parameters",
    ],
    "retry_btn": None,
    "undo_btn": None,
    "clear_btn": "Clear Chat",
}
demo = gr.ChatInterface(chat, **_ui_options)

if __name__ == "__main__":
    demo.launch()