|
|
import gradio as gr |
|
|
import os |
|
|
from langchain.chains import ConversationalRetrievalChain |
|
|
from langchain.memory import ConversationBufferMemory |
|
|
from langchain_openai import ChatOpenAI, OpenAIEmbeddings |
|
|
from langchain.prompts import PromptTemplate |
|
|
from langchain_community.vectorstores import Chroma |
|
|
|
|
|
def create_qa_chain():
    """
    Assemble and return the ConversationalRetrievalChain that answers
    questions against the Chroma vectorstore persisted in ./vectorstore.
    """
    # Vector store persisted on disk, queried via OpenAI embeddings.
    store = Chroma(
        persist_directory="./vectorstore",
        embedding_function=OpenAIEmbeddings(),
    )

    # MMR retrieval: pull 20 candidates, keep the 6 most diverse
    # (lambda_mult 0.3 leans toward diversity over pure relevance).
    doc_retriever = store.as_retriever(
        search_type="mmr",
        search_kwargs={"k": 6, "fetch_k": 20, "lambda_mult": 0.3},
    )

    # output_key='answer' tells the memory which chain output to record,
    # since the chain also emits source documents alongside the answer.
    history_memory = ConversationBufferMemory(
        memory_key="chat_history",
        return_messages=True,
        output_key='answer',
    )

    # System prompt forcing a fixed, markdown-formatted answer structure
    # for API-endpoint questions.
    answer_prompt = PromptTemplate.from_template("""You are an expert technical writer specializing in API documentation.
When describing API endpoints, structure your response in this exact format:

1. Start with the HTTP method and base URI structure
2. List all key parameters with:
   - Parameter name in bold (**parameter**)
   - Type and requirement status
   - Clear description
   - Example values where applicable
3. Show complete example requests with:
   - Basic example
   - Full example with all parameters
   - Headers included
4. Include any relevant response information

Use markdown formatting for:
- Code blocks with syntax highlighting
- Bold text for important terms
- Clear section separation

Context: {context}

Question: {question}

Technical answer (following the exact structure above):""")

    # Low temperature keeps answers close to the retrieved documentation.
    llm = ChatOpenAI(
        temperature=0.1,
        model_name="gpt-4-turbo-preview",
    )

    # Wire everything together: the retriever feeds documents into the
    # prompt, the memory carries the running conversation.
    return ConversationalRetrievalChain.from_llm(
        llm=llm,
        retriever=doc_retriever,
        memory=history_memory,
        return_source_documents=True,
        combine_docs_chain_kwargs={"prompt": answer_prompt},
        verbose=False,
    )
|
|
|
|
|
def _format_sources(source_documents, max_sources=3):
    """
    Build a "Sources" section from retrieved documents.

    Deduplicates on the (component, title) metadata pair and shows at
    most *max_sources* entries, each with a 300-character content
    preview. Returns an empty string when there is nothing to show, so
    the caller can append the result unconditionally.
    """
    seen_combos = set()
    entries = []
    for doc in source_documents:
        component = doc.metadata.get('component', '')
        title = doc.metadata.get('title', '')
        combo = (component, title)
        if combo in seen_combos:
            continue
        seen_combos.add(combo)
        entries.append(
            f"\nSource {len(entries) + 1}:\n"
            f"Title: {title}\n"
            f"Component: {component}\n"
            f"Content: {doc.page_content[:300]}...\n"
        )
        if len(entries) >= max_sources:
            break
    if not entries:
        return ""
    return "\n\nSources:\n" + "".join(entries)


def chat(message, history):
    """
    Answer one chat message with the QA chain, appending source citations.

    Parameters:
        message: the user's question (str).
        history: Gradio-managed chat history — unused here because the
            chain keeps its own ConversationBufferMemory.

    Returns:
        The chain's answer, followed by a "Sources" section when at
        least one source document was retrieved.
    """
    # Lazily build the chain once and cache it on the function object so
    # the vectorstore/LLM setup cost is paid only on the first message.
    if not hasattr(chat, 'qa_chain'):
        chat.qa_chain = create_qa_chain()

    # Fix: use .invoke() — calling the chain directly (qa_chain({...}))
    # is the deprecated Chain.__call__ entry point in modern LangChain.
    result = chat.qa_chain.invoke({"question": message})

    # Fix: the original appended the "Sources:" header even when no
    # source was shown; _format_sources returns "" in that case.
    sources = _format_sources(result.get("source_documents", []))

    return result["answer"] + sources
|
|
|
|
|
# Example prompts offered beneath the chat box.
_EXAMPLE_QUESTIONS = [
    "How to search for songs on Apple Music API?",
    "What are the required parameters for searching songs?",
    "Show me an example request with all parameters",
]

# Gradio chat UI wired to the chat() handler above.
demo = gr.ChatInterface(
    chat,
    title="Apple Music API Documentation Assistant",
    description="Ask questions about the Apple Music API documentation.",
    examples=_EXAMPLE_QUESTIONS,
)
|
|
|
|
|
# Launch the Gradio app only when run as a script (not on import).
if __name__ == "__main__":
    demo.launch()
|
|
|