#!/usr/bin/env python
# coding: utf-8

# In[1]:

import os
import json
import requests
import gradio as gr
from typing import Literal, List, Dict, Any
from pydantic import BaseModel, Field
from dotenv import load_dotenv
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain.schema import Document
from langgraph.graph import END, StateGraph
from typing_extensions import TypedDict
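
# Assumed install set for the imports above (a best guess, versions not pinned):
#   pip install requests gradio python-dotenv langchain langchain-community \
#       langgraph chromadb sentence-transformers tiktoken tavily-python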

# Load environment variables
load_dotenv()

# Configuration
BASE_URL = "https://api.llama.com/v1"
LLAMA_API_KEY = os.environ.get('LLAMA_API_KEY')
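
# A .env file next to this script is assumed to provide the keys, e.g.:
#   LLAMA_API_KEY=...
#   TAVILY_API_KEY=...   (read from the environment by TavilySearchResults)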

# Initialize global variables
vectorstore = None
retriever = None
web_search_tool = None
app = None

class RouteQuery(BaseModel):
    """Route a user query to the most relevant datasource."""

    datasource: Literal["vectorstore", "web_search"] = Field(
        ...,
        description="Given a user question, choose to route it to web search or a vectorstore.",
    )

class GraphState(TypedDict):
    """Represents the state of our graph.

    Attributes:
        question: the user's question
        generation: the LLM-generated answer
        web_search: "Yes"/"No" flag indicating whether to add web search results
        documents: the list of retrieved documents
    """

    question: str
    generation: str
    web_search: str
    documents: List[str]

def initialize_system():
    """Initialize the RAG system with vectorstore and workflow."""
    global vectorstore, retriever, web_search_tool, app
    try:
        # Read configuration
        with open('wragby.json', 'r') as file:
            data = json.load(file)
        urls = data["urls"]
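        # wragby.json is assumed to look like this (illustrative values, not the real file):
        #   {"urls": ["https://www.wragbysolutions.com/", "https://www.wragbysolutions.com/products"]}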

        # Build index: load pages, split into chunks, embed into Chroma
        docs = [WebBaseLoader(url).load() for url in urls]
        docs_list = [item for sublist in docs for item in sublist]
        text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
            chunk_size=500, chunk_overlap=0
        )
        doc_splits = text_splitter.split_documents(docs_list)
        vectorstore = Chroma.from_documents(
            documents=doc_splits,
            collection_name="rag-chroma",
            embedding=HuggingFaceEmbeddings(),
        )
        retriever = vectorstore.as_retriever()

        # Initialize web search
        web_search_tool = TavilySearchResults(k=3)

        # Build workflow
        app = build_workflow()

        return "✅ System initialized successfully!"
    except Exception as e:
        return f"❌ Error initializing system: {str(e)}"

def chat_completion(messages, model="Llama-4-Scout-17B-16E-Instruct-FP8", max_tokens=1024):
    """Make API call to Llama."""
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {LLAMA_API_KEY}",
    }
    payload = {
        "messages": messages,
        "model": model,
        "max_tokens": max_tokens,
        "stream": False,
    }
    # Reuse the BASE_URL constant instead of repeating the endpoint
    response = requests.post(f"{BASE_URL}/chat/completions", headers=headers, json=payload)
    return response
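
# Note: the helpers below assume the Llama API returns a JSON body shaped like
#   {"completion_message": {"content": {"text": "..."}}}
# and parse it accordingly.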

def route_query(question: str) -> RouteQuery:
    """Route a user question using the Llama API with structured output."""
    system_message = """You are an expert at routing a user question to a vectorstore or web search.
The vectorstore contains documents related to the business Wragby Solutions, their product information, and customer sales.
Use the vectorstore for questions on these topics. Otherwise, use web-search.
You must respond with a JSON object in this exact format:
{"datasource": "vectorstore"} or {"datasource": "web_search"}
Only respond with the JSON object, no additional text."""
    messages = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": question},
    ]
    try:
        response = chat_completion(messages, max_tokens=50)
        content = response.json()['completion_message']['content']['text'].strip()
        route_data = json.loads(content)
        return RouteQuery(**route_data)
    except Exception as e:
        print(f"Error parsing routing response: {e}")
        # Fall back to web search when the model's reply cannot be parsed
        return RouteQuery(datasource="web_search")
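
# Illustrative expectation (actual routing depends on the model's reply):
#   route_query("What services does Wragby offer?")  -> RouteQuery(datasource="vectorstore")
#   route_query("Who won the match last night?")     -> RouteQuery(datasource="web_search")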

def format_docs(docs):
    """Format a list of documents into a single string."""
    if not docs:
        return ""
    formatted_docs = []
    for doc in docs:
        try:
            if hasattr(doc, 'page_content'):
                formatted_docs.append(doc.page_content)
            elif isinstance(doc, dict) and 'content' in doc:
                formatted_docs.append(doc['content'])
            elif isinstance(doc, dict) and 'page_content' in doc:
                formatted_docs.append(doc['page_content'])
            elif isinstance(doc, str):
                formatted_docs.append(doc)
            else:
                formatted_docs.append(str(doc))
        except Exception as e:
            print(f"Error processing document: {e}")
            formatted_docs.append(str(doc))
    return "\n\n".join(formatted_docs)

def rag_generate_answer(question: str, docs: list) -> str:
    """Generate an answer using RAG."""
    system_message = """You are an assistant for question-answering tasks.
Use the following pieces of retrieved context to answer the question.
If you don't know the answer, just say that you don't know.
Use three sentences maximum and keep the answer concise."""
    context = format_docs(docs)
    user_message = f"""Context: {context}

Question: {question}

Answer:"""
    messages = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": user_message},
    ]
    try:
        response = chat_completion(messages, max_tokens=512)
        answer = response.json()['completion_message']['content']['text'].strip()
        return answer
    except Exception as e:
        print(f"Error generating RAG answer: {e}")
        return "I apologize, but I encountered an error while generating an answer."

def grade_answer_quality(question: str, generation: str) -> dict:
    """Grade whether an LLM generation addresses/resolves the user question."""
    system_message = """You are a grader assessing whether an answer addresses / resolves a question.
Give a binary score 'yes' or 'no'. 'Yes' means that the answer resolves the question.
You must respond with exactly one word:
- yes (if the answer addresses and resolves the question)
- no (if the answer does not address or resolve the question)
Only respond with 'yes' or 'no', no additional text or explanation."""
    user_message = f"User question: \n\n {question} \n\n LLM generation: {generation}"
    messages = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": user_message},
    ]
    try:
        response = chat_completion(messages, max_tokens=10)
        content = response.json()['completion_message']['content']['text'].strip().lower()
        # Default to "no" unless the grader clearly says yes
        score = "yes" if "yes" in content else "no"
        return {"score": score}
    except Exception as e:
        print(f"Error calling Llama API for answer grading: {e}")
        return {"score": "no"}

def grade_hallucinations(documents: list, generation: str) -> dict:
    """Grade whether an LLM generation is grounded in the provided documents."""
    system_message = """You are a grader assessing whether an LLM generation is grounded in / supported by a set of retrieved facts.
Give a binary score 'yes' or 'no'. 'Yes' means that the answer is grounded in / supported by the set of facts.
You must respond with exactly one word:
- yes (if the generation is grounded in the facts)
- no (if the generation contains hallucinations or unsupported claims)
Only respond with 'yes' or 'no', no additional text or explanation."""
    # Normalize the documents into a single text block
    if isinstance(documents, list):
        if documents and hasattr(documents[0], 'page_content'):
            docs_text = format_docs(documents)
        else:
            docs_text = "\n\n".join(documents)
    else:
        docs_text = str(documents)
    user_message = f"Set of facts: \n\n {docs_text} \n\n LLM generation: {generation}"
    messages = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": user_message},
    ]
    try:
        response = chat_completion(messages, max_tokens=10)
        content = response.json()['completion_message']['content']['text'].strip().lower()
        # Default to "no" unless the grader clearly says yes
        score = "yes" if "yes" in content else "no"
        return {"score": score}
    except Exception as e:
        print(f"Error calling Llama API for hallucination grading: {e}")
        return {"score": "no"}

def grade_document_relevance(question: str, document: str) -> dict:
    """Grade the relevance of a retrieved document to a user question."""
    system_message = """You are a grader assessing relevance of a retrieved document to a user question.
If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant.
It does not need to be a stringent test. The goal is to filter out erroneous retrievals.
Give a binary score 'yes' or 'no' to indicate whether the document is relevant to the question.
You must respond with exactly one word:
- yes (if the document is relevant)
- no (if the document is not relevant)
Only respond with 'yes' or 'no', no additional text or explanation."""
    user_message = f"Retrieved document: \n\n {document} \n\n User question: {question}"
    messages = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": user_message},
    ]
    try:
        # A one-word verdict is expected, so cap max_tokens like the other graders
        response = chat_completion(messages, max_tokens=10)
        content = response.json()['completion_message']['content']['text'].strip().lower()
        score = "yes" if "yes" in content else "no"
        return {"score": score}
    except Exception as e:
        print(f"Error calling Llama API for document grading: {e}")
        return {"score": "no"}

# Workflow Nodes
def retrieve(state):
    """Retrieve documents from vectorstore."""
    print("---RETRIEVE---")
    question = state["question"]
    documents = retriever.invoke(question)
    return {"documents": documents, "question": question}

def generate(state):
    """Generate answer using RAG on retrieved documents."""
    print("---GENERATE---")
    question = state["question"]
    documents = state["documents"]
    generation = rag_generate_answer(question, documents)
    return {"documents": documents, "question": question, "generation": generation}

def grade_documents(state):
    """Determines whether the retrieved documents are relevant to the question."""
    print("---CHECK DOCUMENT RELEVANCE TO QUESTION---")
    question = state["question"]
    documents = state["documents"]

    filtered_docs = []
    web_search = "No"
    for d in documents:
        score = grade_document_relevance(question, d.page_content)
        grade = score["score"]
        if grade.lower() == "yes":
            print("---GRADE: DOCUMENT RELEVANT---")
            filtered_docs.append(d)
        else:
            print("---GRADE: DOCUMENT NOT RELEVANT---")
            # Any irrelevant document triggers a supplementary web search
            web_search = "Yes"
    return {"documents": filtered_docs, "question": question, "web_search": web_search}

def web_search(state):
    """Web search based on the question."""
    print("---WEB SEARCH---")
    question = state["question"]
    documents = state.get("documents")

    docs = web_search_tool.invoke({"query": question})
    web_results = Document(page_content="\n".join([d["content"] for d in docs]))
    if documents is not None:
        documents.append(web_results)
    else:
        documents = [web_results]
    return {"documents": documents, "question": question}

def route_question(state):
    """Route question to web search or RAG."""
    print("---ROUTE QUESTION---")
    question = state["question"]
    source = route_query(question)
    if source.datasource == 'web_search':
        print("---ROUTE QUESTION TO WEB SEARCH---")
        return "websearch"
    elif source.datasource == 'vectorstore':
        print("---ROUTE QUESTION TO RAG---")
        return "vectorstore"

def decide_to_generate(state):
    """Determines whether to generate an answer, or add web search."""
    print("---ASSESS GRADED DOCUMENTS---")
    web_search = state["web_search"]
    if web_search == "Yes":
        print("---DECISION: NOT ALL DOCUMENTS ARE RELEVANT TO QUESTION, INCLUDE WEB SEARCH---")
        return "websearch"
    else:
        print("---DECISION: GENERATE---")
        return "generate"

def grade_generation_v_documents_and_question(state):
    """Determines whether the generation is grounded in the documents and answers the question."""
    print("---CHECK HALLUCINATIONS---")
    question = state["question"]
    documents = state["documents"]
    generation = state["generation"]

    score = grade_hallucinations(documents, generation)
    grade = score["score"]
    if grade == "yes":
        print("---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---")
        print("---GRADE GENERATION vs QUESTION---")
        score = grade_answer_quality(question, generation)
        grade = score["score"]
        if grade == "yes":
            print("---DECISION: GENERATION ADDRESSES QUESTION---")
            return "useful"
        else:
            print("---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
            return "not useful"
    else:
        print("---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---")
        return "not supported"

def build_workflow():
    """Build the RAG workflow graph."""
    workflow = StateGraph(GraphState)

    # Define the nodes
    workflow.add_node("websearch", web_search)
    workflow.add_node("retrieve", retrieve)
    workflow.add_node("grade_documents", grade_documents)
    workflow.add_node("generate", generate)

    # Build graph
    workflow.set_conditional_entry_point(
        route_question,
        {
            "websearch": "websearch",
            "vectorstore": "retrieve",
        },
    )
    workflow.add_edge("retrieve", "grade_documents")
    workflow.add_conditional_edges(
        "grade_documents",
        decide_to_generate,
        {
            "websearch": "websearch",
            "generate": "generate",
        },
    )
    workflow.add_edge("websearch", "generate")
    workflow.add_conditional_edges(
        "generate",
        grade_generation_v_documents_and_question,
        {
            "not supported": "generate",
            "useful": END,
            "not useful": "websearch",
        },
    )
    return workflow.compile().with_config({"run_name": "Wragby Solutions Assistant"})
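
# Resulting control flow (sketch, derived from the edges above):
#   entry --route_question--> retrieve ("vectorstore") | websearch
#   retrieve -> grade_documents --decide_to_generate--> generate | websearch
#   websearch -> generate
#   generate --grading--> END ("useful") | websearch ("not useful") | generate ("not supported")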

def process_question(question: str, history: List[List[str]]) -> tuple:
    """Process a question through the RAG system and return the answer with sources."""
    if not question.strip():
        return history, "Please enter a question."
    if app is None:
        return history, "❌ System not initialized. Please click 'Initialize System' first."
    try:
        # Process through the workflow, keeping the last node's state
        inputs = {"question": question}
        final_state = None
        for output in app.stream(inputs):
            for key, value in output.items():
                print(f"Finished running: {key}")
                final_state = value

        if final_state and "generation" in final_state:
            answer = final_state["generation"]

            # Get source information
            sources = []
            if "documents" in final_state and final_state["documents"]:
                for i, doc in enumerate(final_state["documents"][:3]):  # Show top 3 sources
                    if hasattr(doc, 'metadata') and 'source' in doc.metadata:
                        sources.append(f"Source {i+1}: {doc.metadata['source']}")
                    else:
                        sources.append(f"Source {i+1}: Retrieved document")

            # Format response with sources
            if sources:
                full_response = f"{answer}\n\n**Sources:**\n" + "\n".join(sources)
            else:
                full_response = answer

            # Update chat history
            history.append([question, full_response])
            return history, ""
        else:
            history.append([question, "I apologize, but I couldn't generate an answer for your question."])
            return history, ""
    except Exception as e:
        error_msg = f"❌ Error processing question: {str(e)}"
        history.append([question, error_msg])
        return history, ""

def clear_chat():
    """Clear the chat history."""
    return [], ""

# Create Gradio Interface
def create_gradio_app():
    """Create and configure the Gradio interface."""
    # Custom CSS for better styling
    css = """
    .gradio-container {
        max-width: 1200px !important;
        margin: auto !important;
    }
    .chat-container {
        height: 500px !important;
    }
    .title {
        text-align: center;
        color: #2D5AA0;
        margin-bottom: 20px;
    }
    .description {
        text-align: center;
        color: #666;
        margin-bottom: 30px;
    }
    """

    with gr.Blocks(css=css, title="Wragby Solutions Q&A Assistant") as demo:
        gr.HTML("""
        <div class="title">
            <h1>🤖 Wragby Solutions Q&A Assistant</h1>
        </div>
        <div class="description">
            <p>Ask questions about Wragby Solutions products, services, and business information.
            The system will search through company documents and the web to provide accurate answers.</p>
        </div>
        """)

        with gr.Row():
            with gr.Column(scale=3):
                # Chat interface
                chatbot = gr.Chatbot(
                    label="Chat History",
                    height=500,
                    show_label=True,
                    container=True,
                    elem_classes=["chat-container"]
                )
                with gr.Row():
                    question_input = gr.Textbox(
                        placeholder="Ask a question about Wragby Solutions...",
                        label="Your Question",
                        lines=2,
                        scale=4
                    )
                    submit_btn = gr.Button("Submit", variant="primary", scale=1)
                with gr.Row():
                    clear_btn = gr.Button("Clear Chat", variant="secondary")

            with gr.Column(scale=1):
                # System controls
                gr.HTML("<h3>System Controls</h3>")
                init_btn = gr.Button("Initialize System", variant="primary")
                init_status = gr.Textbox(
                    label="System Status",
                    value="Click 'Initialize System' to start",
                    interactive=False
                )
                gr.HTML("""
                <div style="margin-top: 20px; padding: 15px; background-color: #f0f8ff; border-radius: 5px;">
                    <h4>💡 Sample Questions:</h4>
                    <ul>
                        <li>What are the types of solutions offered by Wbizmanager?</li>
                        <li>How can SMBs use Wbizmanager?</li>
                        <li>What SAP solutions are available from Wragby?</li>
                        <li>Tell me about Wragby Solutions services</li>
                    </ul>
                </div>
                """)

        # Event handlers
        init_btn.click(
            fn=initialize_system,
            outputs=[init_status]
        )
        submit_btn.click(
            fn=process_question,
            inputs=[question_input, chatbot],
            outputs=[chatbot, question_input]
        )
        question_input.submit(
            fn=process_question,
            inputs=[question_input, chatbot],
            outputs=[chatbot, question_input]
        )
        clear_btn.click(
            fn=clear_chat,
            outputs=[chatbot, question_input]
        )

    return demo

# In[2]:

# Create and launch the Gradio app
demo = create_gradio_app()
demo.launch(
    server_name="0.0.0.0",
    server_port=7860,
    share=True,  # creates a public share link; set to False to keep it local
    debug=True
)