Spaces:
Sleeping
Sleeping
Upload 33 files
Browse files- Dockerfile +37 -0
- requirements.txt +26 -0
- seed_db.py +24 -0
- src/__init__.py +0 -0
- src/__pycache__/__init__.cpython-312.pyc +0 -0
- src/__pycache__/router.cpython-312.pyc +0 -0
- src/__pycache__/schemas.cpython-312.pyc +0 -0
- src/agent/__pycache__/graph.cpython-312.pyc +0 -0
- src/agent/__pycache__/tools.cpython-312.pyc +0 -0
- src/agent/graph.py +125 -0
- src/agent/tools.py +59 -0
- src/api/__pycache__/cache.cpython-312.pyc +0 -0
- src/api/__pycache__/server.cpython-312.pyc +0 -0
- src/api/cache.py +41 -0
- src/api/server.py +77 -0
- src/evaluation/judge.py +77 -0
- src/evaluation/run_evals.py +53 -0
- src/providers/__init__.py +0 -0
- src/providers/__pycache__/__init__.cpython-312.pyc +0 -0
- src/providers/__pycache__/anthropic_client.cpython-312.pyc +0 -0
- src/providers/__pycache__/base.cpython-312.pyc +0 -0
- src/providers/__pycache__/openai_client.cpython-312.pyc +0 -0
- src/providers/anthropic_client.py +49 -0
- src/providers/base.py +43 -0
- src/providers/openai_client.py +75 -0
- src/rag/__pycache__/chatbot.cpython-312.pyc +0 -0
- src/rag/__pycache__/ingestion.cpython-312.pyc +0 -0
- src/rag/__pycache__/vector_store.cpython-312.pyc +0 -0
- src/rag/chatbot.py +63 -0
- src/rag/ingestion.py +41 -0
- src/rag/vector_store.py +61 -0
- src/router.py +87 -0
- src/schemas.py +61 -0
Dockerfile
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 1. The Foundation
|
| 2 |
+
# We start with a lightweight, official Linux image with Python 3.12 pre-installed.
|
| 3 |
+
FROM python:3.12-slim
|
| 4 |
+
|
| 5 |
+
# 2. Environment variables
|
| 6 |
+
# Prevent python from writing messy .pyc files
|
| 7 |
+
ENV PYTHONDONTWRITEBYTECODE=1
|
| 8 |
+
# Ensure our terminal print() statements show up immediately in cloud logs
|
| 9 |
+
ENV PYTHONUNBUFFERED=1
|
| 10 |
+
# Tell HuggingFaceexactly where to save its 100MB math model
|
| 11 |
+
ENV HF_HOME=/app/.cache/huggingface
|
| 12 |
+
|
| 13 |
+
# 3. The Workspace
|
| 14 |
+
# Create a folder inside the container called /app and move inside it
|
| 15 |
+
WORKDIR /app
|
| 16 |
+
|
| 17 |
+
# 4. Cache optimization (the Architect's Trick)
|
| 18 |
+
# We ONLY copy the requirements file first.
|
| 19 |
+
# Docker caches steps. If you change your Python code later, Docker won't
|
| 20 |
+
# force you to sit through a 5-minute re-installation of Pandas and LangChain!
|
| 21 |
+
COPY requirements.txt .
|
| 22 |
+
|
| 23 |
+
# Install the Python packages
|
| 24 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
| 25 |
+
|
| 26 |
+
# 5. COPY the payload
|
| 27 |
+
# Now we copy the rest of your actual code into the container
|
| 28 |
+
COPY . .
|
| 29 |
+
|
| 30 |
+
# 6. OPEN the gate
|
| 31 |
+
# Tell the container to allow trafic on port 8000
|
| 32 |
+
EXPOSE 8000
|
| 33 |
+
|
| 34 |
+
# 7. The ignition switch
|
| 35 |
+
# The exact terminal command the container runs when it wakes up in the cloud.
|
| 36 |
+
# Notice we use 0.0.0.0 so the cloud provider's router can find it.
|
| 37 |
+
CMD ["uvicorn", "src.api.server:app", "--host", "0.0.0.0", "--port", "7860"]
|
requirements.txt
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
pydantic
|
| 2 |
+
openai
|
| 3 |
+
anthropic
|
| 4 |
+
langchain-openai
|
| 5 |
+
|
| 6 |
+
langchain
|
| 7 |
+
chromadb
|
| 8 |
+
tiktoken
|
| 9 |
+
langchain-text-splitters
|
| 10 |
+
langchain-core
|
| 11 |
+
langchain-community
|
| 12 |
+
langgraph
|
| 13 |
+
|
| 14 |
+
setuptools
|
| 15 |
+
|
| 16 |
+
# Free tier...
|
| 17 |
+
langchain-groq
|
| 18 |
+
langchain-huggingface
|
| 19 |
+
sentence-transformers
|
| 20 |
+
|
| 21 |
+
python-dotenv
|
| 22 |
+
|
| 23 |
+
fastapi
|
| 24 |
+
uvicorn
|
| 25 |
+
|
| 26 |
+
requests
|
seed_db.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from src.rag.vector_store import build_vector_store
|
| 2 |
+
from langchain_core.documents import Document
|
| 3 |
+
import os
|
| 4 |
+
|
| 5 |
+
api_key = os.getenv("HF_TOKEN")
|
| 6 |
+
|
| 7 |
+
def seed_database():
|
| 8 |
+
print("Seeding new HuggingFace database...")
|
| 9 |
+
|
| 10 |
+
# 1. Our dummy text
|
| 11 |
+
sample_text = (
|
| 12 |
+
"OmniRouter is an enterprise-grade AI architecture combining high-concurrency "
|
| 13 |
+
"LLM routing and local Vector Database retrieval. If the primary API fails, "
|
| 14 |
+
"it seamlessly switches to a fallback model. It uses LangGraph for agentic reasoning."
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
# 2. Package it as a chunk
|
| 18 |
+
doc = Document(page_content=sample_text, metadata={"source": "manual.pdf"})
|
| 19 |
+
|
| 20 |
+
# 3. Build and save the DB
|
| 21 |
+
build_vector_store([doc], api_key=api_key)
|
| 22 |
+
|
| 23 |
+
if __name__ == "__main__":
|
| 24 |
+
seed_database()
|
src/__init__.py
ADDED
|
File without changes
|
src/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (126 Bytes). View file
|
|
|
src/__pycache__/router.cpython-312.pyc
ADDED
|
Binary file (4.15 kB). View file
|
|
|
src/__pycache__/schemas.cpython-312.pyc
ADDED
|
Binary file (2.75 kB). View file
|
|
|
src/agent/__pycache__/graph.cpython-312.pyc
ADDED
|
Binary file (3.42 kB). View file
|
|
|
src/agent/__pycache__/tools.cpython-312.pyc
ADDED
|
Binary file (2.56 kB). View file
|
|
|
src/agent/graph.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from typing import Annotated, TypedDict
|
| 3 |
+
# # OpenAI...
|
| 4 |
+
# from langchain_openai import ChatOpenAI
|
| 5 |
+
# Groq LLM...
|
| 6 |
+
from langchain_groq import ChatGroq
|
| 7 |
+
from langchain_core.messages import BaseMessage
|
| 8 |
+
from langgraph.graph import StateGraph, START, END
|
| 9 |
+
from langgraph.graph.message import add_messages
|
| 10 |
+
from langgraph.prebuilt import ToolNode, tools_condition
|
| 11 |
+
from src.agent.tools import search_documentation
|
| 12 |
+
# Import our custom tool
|
| 13 |
+
from src.agent.tools import search_documentation
|
| 14 |
+
# Fix the infinite loop using system prompt(Guardrail)
|
| 15 |
+
from langchain_core.messages import SystemMessage
|
| 16 |
+
# Human in the loop using LangGraphs checkpointers...
|
| 17 |
+
from langgraph.checkpoint.memory import MemorySaver
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
from dotenv import load_dotenv
|
| 21 |
+
load_dotenv()
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
# 1. UPGRADED STATE
|
| 25 |
+
class AgentState(TypedDict):
|
| 26 |
+
# 'add_messages' ensures we append to the history, never overwrite it.
|
| 27 |
+
messages: Annotated[list[BaseMessage], add_messages]
|
| 28 |
+
|
| 29 |
+
# 2. INITIALIZE THE BRAIN
|
| 30 |
+
# We instantiate the LLM and "bind" our tool to it.
|
| 31 |
+
|
| 32 |
+
# # Make sure you export GROQ_API_KEY in your terminal before running!
|
| 33 |
+
# os.environ["GROQ_API_KEY"] = "gsk_jd"
|
| 34 |
+
# We use Meta's Llama 3 8B model hosted on Groq for incredible speed
|
| 35 |
+
llm = ChatGroq(model="llama-3.1-8b-instant", temperature=0)
|
| 36 |
+
|
| 37 |
+
# # This sends the JSON schema we looked at yesterday directly to OpenAI.
|
| 38 |
+
# llm = ChatOpenAI(model="gpt-4-turbo", temperature=0)
|
| 39 |
+
|
| 40 |
+
tools = [search_documentation]
|
| 41 |
+
llm_with_tools = llm.bind_tools(tools)
|
| 42 |
+
|
| 43 |
+
# 3. THE NODES
|
| 44 |
+
"""
|
| 45 |
+
Chat node for seamless interaction with the LLM model...
|
| 46 |
+
"""
|
| 47 |
+
def chatbot_node(state: AgentState):
|
| 48 |
+
"""
|
| 49 |
+
This node intercepts the history, adds strict behavioral rules,
|
| 50 |
+
and then passes it to the LLM.
|
| 51 |
+
|
| 52 |
+
The LLM will either return a standard text message, OR a special "ToolCall" message.
|
| 53 |
+
"""
|
| 54 |
+
print("\n--- [NODE: Chatbot] Thinking... ---")
|
| 55 |
+
|
| 56 |
+
# 1. The Circuit Breaker System Prompt --> Prompt based Flow control
|
| 57 |
+
system_message = SystemMessage(content=(
|
| 58 |
+
"You are an elite AI Engineering assistant. "
|
| 59 |
+
"You have access to a search_documentation tool. "
|
| 60 |
+
"CRITICAL RULE: If you use the tool and it returns 'No relevant information found', "
|
| 61 |
+
"you MUST NOT use the tool again. Immediately stop and tell the user "
|
| 62 |
+
"'I do not have enough information to answer that based on the documentation.' "
|
| 63 |
+
"Do not guess. Do not hallucinate."
|
| 64 |
+
))
|
| 65 |
+
|
| 66 |
+
# 2. Prepend the system rules to the chat history
|
| 67 |
+
messages_to_send = [system_message] + state["messages"]
|
| 68 |
+
|
| 69 |
+
# 3. Invoke the LLM with the strict rules applied
|
| 70 |
+
response = llm_with_tools.invoke(messages_to_send)
|
| 71 |
+
|
| 72 |
+
# We return the message wrapped in a list to trigger the 'add_messages' append behavior
|
| 73 |
+
return {"messages": [response]}
|
| 74 |
+
|
| 75 |
+
# LangGraph has a built-in node specifically for executing tools!
|
| 76 |
+
# It reads the "ToolCall" message, runs our Python function, and returns a "ToolMessage".
|
| 77 |
+
tool_node = ToolNode(tools=tools)
|
| 78 |
+
|
| 79 |
+
# =============================================
|
| 80 |
+
# 4. COMPILE THE GRAPH with human in the loop...
|
| 81 |
+
# =============================================
|
| 82 |
+
workflow = StateGraph(AgentState)
|
| 83 |
+
|
| 84 |
+
# Add our two worker nodes
|
| 85 |
+
workflow.add_node("chatbot", chatbot_node)
|
| 86 |
+
workflow.add_node("tools", tool_node)
|
| 87 |
+
|
| 88 |
+
# Set the entry point
|
| 89 |
+
workflow.add_edge(START, "chatbot")
|
| 90 |
+
# 5. THE MAGIC ROUTING
|
| 91 |
+
# 'tools_condition' is a built-in LangGraph edge.
|
| 92 |
+
# It looks at the last message from the chatbot.
|
| 93 |
+
# If it has a tool call, it routes to "tools". If it's just text, it routes to END.
|
| 94 |
+
workflow.add_conditional_edges("chatbot", tools_condition)
|
| 95 |
+
# After a tool finishes running, ALWAYS loop back to the chatbot
|
| 96 |
+
# so it can read the database results and formulate a final answer!
|
| 97 |
+
workflow.add_edge("tools", "chatbot")
|
| 98 |
+
|
| 99 |
+
# Initialize the short-term memory vault
|
| 100 |
+
memory = MemorySaver()
|
| 101 |
+
|
| 102 |
+
# Compile with the memory and the breakpoint!
|
| 103 |
+
app = workflow.compile(
|
| 104 |
+
checkpointer=memory,
|
| 105 |
+
interrupt_before=["tools"] # tell the graph to pause before executing this node.
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
from langchain_core.messages import HumanMessage
|
| 110 |
+
|
| 111 |
+
if __name__ == "__main__":
|
| 112 |
+
# Ensure your API key is available
|
| 113 |
+
# os.environ["OPENAI_API_KEY"] = "YOUR_REAL_OPENAI_API_KEY"
|
| 114 |
+
|
| 115 |
+
print("========== AGENT TEST ==========")
|
| 116 |
+
initial_state = {
|
| 117 |
+
"messages": [HumanMessage(content="What does OmniRouter do?")]
|
| 118 |
+
}
|
| 119 |
+
|
| 120 |
+
# stream() allows us to see the exact output of each node as it executes!
|
| 121 |
+
for event in app.stream(initial_state):
|
| 122 |
+
for node_name, node_state in event.items():
|
| 123 |
+
print(f"Update from node '{node_name}':")
|
| 124 |
+
# Print the content of the very last message added to the state
|
| 125 |
+
print(f" -> {node_state['messages'][-1].content}\n")
|
src/agent/tools.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Let's officially give your agent its brain.
|
| 3 |
+
|
| 4 |
+
We are going to use LangChain's @tool decorator.
|
| 5 |
+
This magical little wrapper takes your standard Python function,
|
| 6 |
+
reads the type hints (like query: str), reads the docstring, and automatically
|
| 7 |
+
translates the entire thing into a strict JSON schema that OpenAI and Anthropic natively understand .
|
| 8 |
+
"""
|
| 9 |
+
import os
|
| 10 |
+
from langchain_core.tools import tool
|
| 11 |
+
from src.rag.vector_store import get_vector_store
|
| 12 |
+
|
| 13 |
+
# The @tool decorator converts this Python function into an LLM-readable JSON schema
|
| 14 |
+
@tool
|
| 15 |
+
def search_documentation(query: str) -> str:
|
| 16 |
+
"""
|
| 17 |
+
Searches the internal engineering documentation for information about the OmniRouter,
|
| 18 |
+
LangChain, LangGraph, or general AI engineering concepts.
|
| 19 |
+
|
| 20 |
+
Use this tool WHENEVER the user asks a technical question about how the system works,
|
| 21 |
+
fallback protocols, or specific coding architecture. Do NOT use this for general greetings.
|
| 22 |
+
|
| 23 |
+
Args:
|
| 24 |
+
query: The specific search term to look up in the database.
|
| 25 |
+
It should be a standalone, highly descriptive phrase.
|
| 26 |
+
"""
|
| 27 |
+
print(f"\n--- [TOOL EXECUTION] Searching Vector DB for: '{query}' ---")
|
| 28 |
+
|
| 29 |
+
# In production, ensure your API key is loaded securely
|
| 30 |
+
api_key = os.getenv("HF_TOKEN")
|
| 31 |
+
|
| 32 |
+
try:
|
| 33 |
+
db = get_vector_store(api_key)
|
| 34 |
+
|
| 35 |
+
# Retrieve the top 2 most relevant chunks
|
| 36 |
+
results = db.similarity_search(query, k=2)
|
| 37 |
+
|
| 38 |
+
if not results:
|
| 39 |
+
return "No relevant information found in the documentation."
|
| 40 |
+
|
| 41 |
+
# We must return a STRING, not a list of objects, so the LLM can read it easily
|
| 42 |
+
combined_text = "\n\n".join([doc.page_content for doc in results])
|
| 43 |
+
return combined_text
|
| 44 |
+
|
| 45 |
+
except Exception as e:
|
| 46 |
+
# FIXED: Print the error to the server terminal so we can see it!
|
| 47 |
+
print(f"\n🚨 [TOOL CRASHED]: {str(e)}")
|
| 48 |
+
return f"Error executing search: {str(e)}"
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
if __name__ == "__main__":
|
| 52 |
+
# Print the name the LLM sees
|
| 53 |
+
print(f"Tool Name: {search_documentation.name}")
|
| 54 |
+
|
| 55 |
+
# Print the description the LLM reads to make its decision
|
| 56 |
+
print(f"\nTool Description: \n{search_documentation.description}")
|
| 57 |
+
|
| 58 |
+
# Print the strict JSON schema the LLM must follow to use the tool
|
| 59 |
+
print(f"\nTool Arguments Schema: \n{search_documentation.args}")
|
src/api/__pycache__/cache.cpython-312.pyc
ADDED
|
Binary file (2.05 kB). View file
|
|
|
src/api/__pycache__/server.cpython-312.pyc
ADDED
|
Binary file (2.72 kB). View file
|
|
|
src/api/cache.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from langchain_huggingface import HuggingFaceEmbeddings
|
| 2 |
+
from langchain_community.vectorstores import Chroma
|
| 3 |
+
from langchain_core.documents import Document
|
| 4 |
+
|
| 5 |
+
CACHE_DIR = "./semantic_cache_db"
|
| 6 |
+
|
| 7 |
+
def get_cache_db():
|
| 8 |
+
"""Initializes the Cache Database using free local embeddings."""
|
| 9 |
+
embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
|
| 10 |
+
return Chroma(persist_directory=CACHE_DIR, embedding_function=embeddings)
|
| 11 |
+
|
| 12 |
+
def check_cache(query: str, threshold: float = 0.5) -> str | None:
|
| 13 |
+
"""
|
| 14 |
+
Embeds the user's question and mathematically checks if anyone
|
| 15 |
+
has asked a highly similar question before.
|
| 16 |
+
"""
|
| 17 |
+
db = get_cache_db()
|
| 18 |
+
|
| 19 |
+
# We search the cache and ask for the 'distance score'
|
| 20 |
+
results = db.similarity_search_with_score(query, k=1)
|
| 21 |
+
|
| 22 |
+
if results:
|
| 23 |
+
doc, score = results[0]
|
| 24 |
+
# In ChromaDB's default math (L2 distance), a LOWER score means it's MORE similar.
|
| 25 |
+
# 0.0 is an exact match. 0.5 means "very similar meaning".
|
| 26 |
+
if score < threshold:
|
| 27 |
+
print(f"\n🟢 [CACHE HIT] Similar question found! (Score: {score:.3f})")
|
| 28 |
+
return doc.metadata.get("answer")
|
| 29 |
+
|
| 30 |
+
print("\n🔴 [CACHE MISS] Question is new.")
|
| 31 |
+
return None
|
| 32 |
+
|
| 33 |
+
def save_to_cache(query: str, answer: str):
|
| 34 |
+
"""Saves a brand new question and its answer into the database."""
|
| 35 |
+
db = get_cache_db()
|
| 36 |
+
|
| 37 |
+
# The 'page_content' is the question. The 'metadata' holds the answer.
|
| 38 |
+
doc = Document(page_content=query, metadata={"answer": answer})
|
| 39 |
+
db.add_documents([doc])
|
| 40 |
+
|
| 41 |
+
print("\n💾 [CACHE SAVED] New interaction stored for future users.")
|
src/api/server.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import asyncio
|
| 3 |
+
from fastapi import FastAPI #, HTTPException
|
| 4 |
+
from fastapi.responses import StreamingResponse
|
| 5 |
+
from pydantic import BaseModel
|
| 6 |
+
from typing import List
|
| 7 |
+
from langchain_core.messages import HumanMessage, AIMessage
|
| 8 |
+
|
| 9 |
+
# Import our compiled LangGraph agent
|
| 10 |
+
from src.agent.graph import app as agent_app
|
| 11 |
+
from src.api.cache import check_cache, save_to_cache
|
| 12 |
+
|
| 13 |
+
# 1. Initialize the FastAPI Server
|
| 14 |
+
app = FastAPI(
|
| 15 |
+
title="OmniRouter Streaming API Agent",
|
| 16 |
+
description="Enterprise RAG Agent powered by LangGraph and FastAPI",
|
| 17 |
+
version="1.0.1"
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
# 2. Define our Request Schema using Pydantic
|
| 21 |
+
class ChatRequest(BaseModel):
|
| 22 |
+
query: str
|
| 23 |
+
|
| 24 |
+
async def stream_generator(query: str):
|
| 25 |
+
# ==========================================
|
| 26 |
+
# 1. THE CACHE LAYER (Lightning Fast)
|
| 27 |
+
# ==========================================
|
| 28 |
+
cached_answer = check_cache(query)
|
| 29 |
+
|
| 30 |
+
if cached_answer:
|
| 31 |
+
# We chop the cached string into words and stream them instantly
|
| 32 |
+
for word in cached_answer.split(" "):
|
| 33 |
+
yield f"data: {json.dumps({'token': word + ' '})}\n\n"
|
| 34 |
+
# We add a tiny 20ms sleep just to preserve the "typewriter" feel for the user
|
| 35 |
+
await asyncio.sleep(0.02)
|
| 36 |
+
return # EXIT EARLY! The LLM is never triggered.
|
| 37 |
+
|
| 38 |
+
"""
|
| 39 |
+
An async generator that yields tokens from the LangGraph agent
|
| 40 |
+
in a format compatible with Server-Sent Events (SSE).
|
| 41 |
+
"""
|
| 42 |
+
# ==========================================
|
| 43 |
+
# 2. THE AGENT LAYER (Heavy Compute)
|
| 44 |
+
# ==========================================
|
| 45 |
+
initial_state = {"messages": [HumanMessage(content=query)]}
|
| 46 |
+
full_answer = "" # We need to collect the tokens to save them later
|
| 47 |
+
|
| 48 |
+
# .astream_events is the key to deep-access streaming in LangChain/LangGraph
|
| 49 |
+
async for event in agent_app.astream_events(initial_state, version="v1"):
|
| 50 |
+
kind = event["event"]
|
| 51 |
+
|
| 52 |
+
# We are looking for the 'on_chat_model_stream' event
|
| 53 |
+
# This triggers every time a new token is generated by the LLM
|
| 54 |
+
if kind == "on_chat_model_stream":
|
| 55 |
+
content = event["data"]["chunk"].content
|
| 56 |
+
if content:
|
| 57 |
+
full_answer += content
|
| 58 |
+
# SSE format requires the "data: " prefix
|
| 59 |
+
yield f"data: {json.dumps({'token': content})}\n\n"
|
| 60 |
+
|
| 61 |
+
# ==========================================
|
| 62 |
+
# 3. SAVE FOR THE FUTURE
|
| 63 |
+
# ==========================================
|
| 64 |
+
# Only cache if we got an answer, AND the answer isn't our fallback failure phrase
|
| 65 |
+
failure_phrase = "I do not have enough information"
|
| 66 |
+
|
| 67 |
+
if full_answer and failure_phrase not in full_answer:
|
| 68 |
+
save_to_cache(query, full_answer)
|
| 69 |
+
else:
|
| 70 |
+
print("\n⚠️ [CACHE SKIP] Agent failed to answer. Did not poison the cache.")
|
| 71 |
+
|
| 72 |
+
@app.post("/chat/stream")
|
| 73 |
+
async def chat_streaming_endpoint(request: ChatRequest):
|
| 74 |
+
return StreamingResponse(
|
| 75 |
+
stream_generator(request.query),
|
| 76 |
+
media_type="text/event-stream"
|
| 77 |
+
)
|
src/evaluation/judge.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
We are going to create a strict grading script using LangChain's
|
| 3 |
+
**with_structured_output**. This forces our Judge LLM to return a strict JSON
|
| 4 |
+
object containing an integer score (1 for Pass, 0 for Fail) and a reasoning string.
|
| 5 |
+
"""
|
| 6 |
+
from pydantic import BaseModel, Field
|
| 7 |
+
from langchain_groq import ChatGroq
|
| 8 |
+
from langchain_core.prompts import ChatPromptTemplate
|
| 9 |
+
|
| 10 |
+
from dotenv import load_dotenv
|
| 11 |
+
load_dotenv()
|
| 12 |
+
|
| 13 |
+
# ==========================================
|
| 14 |
+
# 1. The Strict Grading Schema
|
| 15 |
+
# ==========================================
|
| 16 |
+
class HallucinationScore(BaseModel):
|
| 17 |
+
score: int = Field(description="Return 1 if perfectly grounded. Return 0 if hallucinated.")
|
| 18 |
+
reasoning: str = Field(description="A 1-sentence explanation of why you gave this score.")
|
| 19 |
+
|
| 20 |
+
# ==========================================
|
| 21 |
+
# 2. Initialize the Impartial Judge
|
| 22 |
+
# ==========================================
|
| 23 |
+
# We use temperature=0 because we want strict, deterministic grading, not creativity!
|
| 24 |
+
model_name_1 = "llama-3.1-70b-versatile"
|
| 25 |
+
model_name_2 = "llama-3.1-8b-instant"
|
| 26 |
+
|
| 27 |
+
judge_llm = ChatGroq(model=model_name_2, temperature=0)
|
| 28 |
+
structured_judge = judge_llm.with_structured_output(HallucinationScore)
|
| 29 |
+
|
| 30 |
+
# ==========================================
|
| 31 |
+
# 3. The Grading Rubric (System Prompt)
|
| 32 |
+
# ==========================================
|
| 33 |
+
system_prompt = """You are an impartial AI Compliance Judge evaluating an Agent's response.
|
| 34 |
+
You will be given the 'Retrieved Context' from the database, and the 'Agent Answer'.
|
| 35 |
+
Your ONLY job is to check for HALLUCINATIONS.
|
| 36 |
+
|
| 37 |
+
RULE:
|
| 38 |
+
- If the Agent's answer contains ANY factual information, names, or numbers that are NOT present in the Retrieved Context, score it a 0.
|
| 39 |
+
- If the Agent's answer is strictly based ONLY on the context, score it a 1.
|
| 40 |
+
- Do not grade grammar or tone. Only grade factual grounding.
|
| 41 |
+
"""
|
| 42 |
+
|
| 43 |
+
prompt = ChatPromptTemplate.from_messages([
|
| 44 |
+
("system", system_prompt),
|
| 45 |
+
("human", "Retrieved Context: \n\n {context} \n\n Agent Answer: \n\n {answer}")
|
| 46 |
+
])
|
| 47 |
+
|
| 48 |
+
evaluator = prompt | structured_judge
|
| 49 |
+
|
| 50 |
+
def check_hallucination(context: str, answer: str):
|
| 51 |
+
print("\n⚖️ [JUDGE] Evaluating answer for hallucinations...")
|
| 52 |
+
try:
|
| 53 |
+
result = evaluator.invoke({"context": context, "answer": answer})
|
| 54 |
+
return result
|
| 55 |
+
except Exception as e:
|
| 56 |
+
print(f"Judge Error: {e}")
|
| 57 |
+
return None
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
if __name__ == "__main__":
|
| 61 |
+
# The reality: What our Vector DB actually found.
|
| 62 |
+
simulated_context = (
|
| 63 |
+
"OmniRouter is an AI architecture that routes LLM requests. "
|
| 64 |
+
"It supports OpenAI and Anthropic APIs."
|
| 65 |
+
)
|
| 66 |
+
|
| 67 |
+
print("\n========== TEST 1: The Good Agent ==========")
|
| 68 |
+
good_answer = "OmniRouter routes requests and works with Anthropic and OpenAI."
|
| 69 |
+
good_result = check_hallucination(simulated_context, good_answer)
|
| 70 |
+
print(f"Score: {good_result.score}/1")
|
| 71 |
+
print(f"Reasoning: {good_result.reasoning}")
|
| 72 |
+
|
| 73 |
+
print("\n========== TEST 2: The Hallucinating Agent ==========")
|
| 74 |
+
bad_answer = "OmniRouter routes requests and works with OpenAI, Anthropic, and Google Gemini."
|
| 75 |
+
bad_result = check_hallucination(simulated_context, bad_answer)
|
| 76 |
+
print(f"Score: {bad_result.score}/1")
|
| 77 |
+
print(f"Reasoning: {bad_result.reasoning}")
|
src/evaluation/run_evals.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from langchain_core.messages import HumanMessage, ToolMessage, AIMessage
|
| 2 |
+
|
| 3 |
+
# Import your real compiled agent and your real judge
|
| 4 |
+
from src.agent.graph import app
|
| 5 |
+
from src.evaluation.judge import check_hallucination
|
| 6 |
+
|
| 7 |
+
def evaluate_real_agent(query: str):
|
| 8 |
+
print(f"\n==================================================")
|
| 9 |
+
print(f"🚀 RUNNING REAL AGENT EVALUATION")
|
| 10 |
+
print(f"Query: '{query}'")
|
| 11 |
+
print(f"==================================================")
|
| 12 |
+
|
| 13 |
+
# 1. Trigger the Real Agent
|
| 14 |
+
initial_state = {"messages": [HumanMessage(content=query)]}
|
| 15 |
+
config = {"configurable": {"thread_id": "automated_eval_run_1"}}
|
| 16 |
+
|
| 17 |
+
print("\n🤖 Agent is thinking and searching...")
|
| 18 |
+
# We use .invoke() here because we don't need streaming for a backend test
|
| 19 |
+
final_state = app.invoke(initial_state, config)
|
| 20 |
+
|
| 21 |
+
# 2. Extract the Dynamic Data from the State Machine's Memory
|
| 22 |
+
retrieved_context = ""
|
| 23 |
+
final_answer = ""
|
| 24 |
+
|
| 25 |
+
for msg in final_state["messages"]:
|
| 26 |
+
# Find the exact text the ChromaDB tool returned
|
| 27 |
+
if isinstance(msg, ToolMessage):
|
| 28 |
+
retrieved_context += msg.content + "\n"
|
| 29 |
+
# Find the final answer the Agent generated
|
| 30 |
+
elif isinstance(msg, AIMessage) and msg.content:
|
| 31 |
+
final_answer = msg.content
|
| 32 |
+
|
| 33 |
+
if not retrieved_context:
|
| 34 |
+
print("⚠️ Agent did not use the database. Cannot run Hallucination check.")
|
| 35 |
+
return
|
| 36 |
+
|
| 37 |
+
# 3. Pass the dynamic data to the Judge
|
| 38 |
+
result = check_hallucination(context=retrieved_context, answer=final_answer)
|
| 39 |
+
|
| 40 |
+
# 4. Print the final Evaluation Report
|
| 41 |
+
print(f"\n📊 EVALUATION REPORT")
|
| 42 |
+
print(f"Score: {result.score} / 1")
|
| 43 |
+
if result.score == 1:
|
| 44 |
+
print("✅ PASS: Answer is completely grounded in the database.")
|
| 45 |
+
else:
|
| 46 |
+
print("❌ FAIL: Hallucination detected!")
|
| 47 |
+
|
| 48 |
+
print(f"Judge's Reasoning: {result.reasoning}")
|
| 49 |
+
print(f"==================================================\n")
|
| 50 |
+
|
| 51 |
+
if __name__ == "__main__":
|
| 52 |
+
# Test our agent with a real query!
|
| 53 |
+
evaluate_real_agent("What is OmniRouter and what does it do?")
|
src/providers/__init__.py
ADDED
|
File without changes
|
src/providers/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (136 Bytes). View file
|
|
|
src/providers/__pycache__/anthropic_client.cpython-312.pyc
ADDED
|
Binary file (2.92 kB). View file
|
|
|
src/providers/__pycache__/base.cpython-312.pyc
ADDED
|
Binary file (1.79 kB). View file
|
|
|
src/providers/__pycache__/openai_client.cpython-312.pyc
ADDED
|
Binary file (3.23 kB). View file
|
|
|
src/providers/anthropic_client.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from anthropic import AsyncAnthropic
|
| 2 |
+
import logging
|
| 3 |
+
|
| 4 |
+
from src.schemas import RouterConfig, LLMResponse
|
| 5 |
+
from src.providers.base import BaseLLMProvider
|
| 6 |
+
|
| 7 |
+
logger = logging.getLogger(__name__)
|
| 8 |
+
|
| 9 |
+
class AnthropicProvider(BaseLLMProvider):
|
| 10 |
+
def __init__(self, api_key: str):
|
| 11 |
+
super().__init__(api_key)
|
| 12 |
+
self.client = AsyncAnthropic(api_key=self.api_key)
|
| 13 |
+
|
| 14 |
+
async def async_generate(self, prompt: str, config: RouterConfig) -> LLMResponse:
|
| 15 |
+
logger.info(f"Routing request to Anthropic using model: {config.model}")
|
| 16 |
+
|
| 17 |
+
# Anthropic's API structure is slightly different from OpenAI's
|
| 18 |
+
response = await self.client.messages.create(
|
| 19 |
+
model=config.model,
|
| 20 |
+
max_tokens=1024, # Anthropic requires max_tokens to be explicitly set
|
| 21 |
+
messages=[{"role": "user", "content": prompt}],
|
| 22 |
+
temperature=config.temperature,
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
content = response.content[0].text
|
| 26 |
+
prompt_tokens = response.usage.input_tokens
|
| 27 |
+
completion_tokens = response.usage.output_tokens
|
| 28 |
+
|
| 29 |
+
cost = self.calculate_cost(prompt_tokens, completion_tokens, config.model)
|
| 30 |
+
|
| 31 |
+
return LLMResponse(
|
| 32 |
+
content=content,
|
| 33 |
+
provider_used="anthropic",
|
| 34 |
+
model_used=config.model,
|
| 35 |
+
prompt_tokens=prompt_tokens,
|
| 36 |
+
completion_tokens=completion_tokens,
|
| 37 |
+
cost_estimate=cost
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
def calculate_cost(self, prompt_tokens: int, completion_tokens: int, model_name: str) -> float:
|
| 41 |
+
pricing = {
|
| 42 |
+
"claude-3-opus-20240229": {"prompt": 15.0, "completion": 75.0},
|
| 43 |
+
"claude-3-5-sonnet-20240620": {"prompt": 3.0, "completion": 15.0},
|
| 44 |
+
"claude-3-haiku-20240307": {"prompt": 0.25, "completion": 1.25}
|
| 45 |
+
}
|
| 46 |
+
rates = pricing.get(model_name, {"prompt": 0.0, "completion": 0.0})
|
| 47 |
+
cost = (prompt_tokens / 1_000_000) * rates["prompt"] + \
|
| 48 |
+
(completion_tokens / 1_000_000) * rates["completion"]
|
| 49 |
+
return round(cost, 6)
|
src/providers/base.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
The abstract blueprint every provider must follow
|
| 3 |
+
"""
|
| 4 |
+
from abc import ABC, abstractmethod
|
| 5 |
+
from typing import Optional
|
| 6 |
+
from src.schemas import RouterConfig, LLMResponse
|
| 7 |
+
|
| 8 |
+
class BaseLLMProvider(ABC):
|
| 9 |
+
"""
|
| 10 |
+
The strict blueprint that ALL LLM providers must follow.
|
| 11 |
+
If a developer tries to create a provider without an 'async_generate' method,
|
| 12 |
+
Python will throw a TypeError upon instantiation.
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
def __init__(self, api_key: str):
|
| 16 |
+
# Every provider needs an API key (or a dummy key for local models)
|
| 17 |
+
self.api_key = api_key
|
| 18 |
+
|
| 19 |
+
@abstractmethod
|
| 20 |
+
async def async_generate(
|
| 21 |
+
self,
|
| 22 |
+
prompt: str,
|
| 23 |
+
config: RouterConfig
|
| 24 |
+
) -> LLMResponse:
|
| 25 |
+
"""
|
| 26 |
+
The core engine method.
|
| 27 |
+
Takes a string prompt and our strict RouterConfig.
|
| 28 |
+
MUST return our strictly typed LLMResponse.
|
| 29 |
+
"""
|
| 30 |
+
pass
|
| 31 |
+
|
| 32 |
+
@abstractmethod
|
| 33 |
+
def calculate_cost(
|
| 34 |
+
self,
|
| 35 |
+
prompt_tokens: int,
|
| 36 |
+
completion_tokens: int,
|
| 37 |
+
model_name: str
|
| 38 |
+
) -> float:
|
| 39 |
+
"""
|
| 40 |
+
Calculates the estimated cost of the API call.
|
| 41 |
+
Essential for production monitoring.
|
| 42 |
+
"""
|
| 43 |
+
pass
|
src/providers/openai_client.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from openai import AsyncOpenAI
|
| 2 |
+
import logging
|
| 3 |
+
|
| 4 |
+
# Import our strict schemas and base blueprint
|
| 5 |
+
from src.schemas import RouterConfig, LLMResponse
|
| 6 |
+
from src.providers.base import BaseLLMProvider
|
| 7 |
+
|
| 8 |
+
# Set up logging for professional debugging
|
| 9 |
+
logger = logging.getLogger(__name__)
|
| 10 |
+
|
| 11 |
+
class OpenAIProvider(BaseLLMProvider):
|
| 12 |
+
"""
|
| 13 |
+
The concrete implementation for OpenAI's API.
|
| 14 |
+
How it strictly fulfills the contract defined in BaseLLMProvider.
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
def __init__(self, api_key: str):
|
| 18 |
+
# Call the parent class initialization
|
| 19 |
+
super().__init__(api_key)
|
| 20 |
+
|
| 21 |
+
# CRITICAL: We initialize the ASYNC client, not the standard synchronous one.
|
| 22 |
+
# This is what allows our router to handle hundreds of concurrent requests.
|
| 23 |
+
self.client = AsyncOpenAI(api_key=self.api_key)
|
| 24 |
+
|
| 25 |
+
async def async_generate(self, prompt: str, config: RouterConfig) -> LLMResponse:
|
| 26 |
+
logger.info(f"Routing request to OpenAI using model: {config.model}")
|
| 27 |
+
|
| 28 |
+
# 1. Execute the Async API Call
|
| 29 |
+
response = await self.client.chat.completions.create(
|
| 30 |
+
model=config.model,
|
| 31 |
+
messages=[{"role": "user", "content": prompt}],
|
| 32 |
+
temperature=config.temperature,
|
| 33 |
+
# We will add top_p, frequency_penalty, etc. later as needed
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
# 2. Extract Data from OpenAI's specific object structure
|
| 37 |
+
content = response.choices[0].message.content
|
| 38 |
+
prompt_tokens = response.usage.prompt_tokens
|
| 39 |
+
completion_tokens = response.usage.completion_tokens
|
| 40 |
+
|
| 41 |
+
# 3. Calculate Cost dynamically
|
| 42 |
+
cost = self.calculate_cost(prompt_tokens, completion_tokens, config.model)
|
| 43 |
+
|
| 44 |
+
# 4. Standardize the Output
|
| 45 |
+
# We transform OpenAI's proprietary response into our universal LLMResponse schema.
|
| 46 |
+
# Now, the rest of our application doesn't need to know anything about OpenAI's specific formatting.
|
| 47 |
+
return LLMResponse(
|
| 48 |
+
content=content,
|
| 49 |
+
provider_used="openai",
|
| 50 |
+
model_used=config.model,
|
| 51 |
+
prompt_tokens=prompt_tokens,
|
| 52 |
+
completion_tokens=completion_tokens,
|
| 53 |
+
cost_estimate=cost
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
def calculate_cost(self, prompt_tokens: int, completion_tokens: int, model_name: str) -> float:
|
| 57 |
+
"""
|
| 58 |
+
Calculates the exact cost of the API call based on OpenAI's pricing (per 1M tokens).
|
| 59 |
+
This is a massive value-add for your open-source repository.
|
| 60 |
+
"""
|
| 61 |
+
# A dictionary acting as a simple pricing database
|
| 62 |
+
pricing = {
|
| 63 |
+
"gpt-4-turbo": {"prompt": 10.0, "completion": 30.0},
|
| 64 |
+
"gpt-4o": {"prompt": 5.0, "completion": 15.0},
|
| 65 |
+
"gpt-3.5-turbo": {"prompt": 0.5, "completion": 1.5}
|
| 66 |
+
}
|
| 67 |
+
|
| 68 |
+
# If they use a model not in our dict, default to 0.0 to prevent crashes
|
| 69 |
+
rates = pricing.get(model_name, {"prompt": 0.0, "completion": 0.0})
|
| 70 |
+
|
| 71 |
+
# Math: (tokens / 1,000,000) * rate
|
| 72 |
+
cost = (prompt_tokens / 1_000_000) * rates["prompt"] + \
|
| 73 |
+
(completion_tokens / 1_000_000) * rates["completion"]
|
| 74 |
+
|
| 75 |
+
return round(cost, 6)
|
src/rag/__pycache__/chatbot.cpython-312.pyc
ADDED
|
Binary file (2.14 kB). View file
|
|
|
src/rag/__pycache__/ingestion.cpython-312.pyc
ADDED
|
Binary file (1.66 kB). View file
|
|
|
src/rag/__pycache__/vector_store.cpython-312.pyc
ADDED
|
Binary file (1.89 kB). View file
|
|
|
src/rag/chatbot.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from langchain_openai import ChatOpenAI
|
| 2 |
+
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
|
| 3 |
+
from langchain.chains.combine_documents import create_stuff_documents_chain
|
| 4 |
+
# from langchain.chains.history_aware_retriever import create_history_aware_retriever
|
| 5 |
+
# from langchain.chains.retrieval import create_retrieval_chain
|
| 6 |
+
# from langchain.chains.combine_documents import create_stuff_documents_chain
|
| 7 |
+
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
| 8 |
+
|
| 9 |
+
from src.rag.vector_store import get_vector_store
|
| 10 |
+
|
| 11 |
+
def build_doc_assistant(api_key: str):
|
| 12 |
+
"""
|
| 13 |
+
Constructs the conversational RAG pipeline.
|
| 14 |
+
"""
|
| 15 |
+
# 1. Initialize our LLM (temperature=0 because we want factual answers, not creative ones)
|
| 16 |
+
llm = ChatOpenAI(api_key=api_key, model="gpt-3.5-turbo", temperature=0)
|
| 17 |
+
|
| 18 |
+
# 2. Connect to our Vector DB (k=2 means return the top 2 most relevant chunks)
|
| 19 |
+
retriever = get_vector_store(api_key).as_retriever(search_kwargs={"k": 2})
|
| 20 |
+
|
| 21 |
+
# ==========================================
|
| 22 |
+
# STEP 1: The "Question Reformulation" Prompt
|
| 23 |
+
# ==========================================
|
| 24 |
+
contextualize_q_system_prompt = (
|
| 25 |
+
"Given a chat history and the latest user question "
|
| 26 |
+
"which might reference context in the chat history, "
|
| 27 |
+
"formulate a standalone question which can be understood "
|
| 28 |
+
"without the chat history. Do NOT answer the question, "
|
| 29 |
+
"just reformulate it if needed and otherwise return it as is."
|
| 30 |
+
)
|
| 31 |
+
contextualize_q_prompt = ChatPromptTemplate.from_messages([
|
| 32 |
+
("system", contextualize_q_system_prompt),
|
| 33 |
+
MessagesPlaceholder("chat_history"), # Injects our memory here
|
| 34 |
+
("human", "{input}"),
|
| 35 |
+
])
|
| 36 |
+
|
| 37 |
+
# This chain automatically handles rewriting the query before searching
|
| 38 |
+
history_aware_retriever = create_history_aware_retriever(llm, retriever, contextualize_q_prompt)
|
| 39 |
+
|
| 40 |
+
# ==========================================
|
| 41 |
+
# STEP 2: The "Final Answer" Prompt
|
| 42 |
+
# ==========================================
|
| 43 |
+
system_prompt = (
|
| 44 |
+
"You are an elite AI Engineering Assistant. "
|
| 45 |
+
"Use the following pieces of retrieved context to answer the question. "
|
| 46 |
+
"If the answer is not contained in the context, say 'I don't know based on the documentation.' "
|
| 47 |
+
"Do not make up an answer. Keep it concise.\n\n"
|
| 48 |
+
"Context: {context}"
|
| 49 |
+
)
|
| 50 |
+
qa_prompt = ChatPromptTemplate.from_messages([
|
| 51 |
+
("system", system_prompt),
|
| 52 |
+
MessagesPlaceholder("chat_history"),
|
| 53 |
+
("human", "{input}"),
|
| 54 |
+
])
|
| 55 |
+
|
| 56 |
+
# This chain handles injecting the retrieved chunks into the {context} variable
|
| 57 |
+
question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
|
| 58 |
+
|
| 59 |
+
# ==========================================
|
| 60 |
+
# STEP 3: Tie it all together
|
| 61 |
+
# ==========================================
|
| 62 |
+
rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
|
| 63 |
+
return rag_chain
|
src/rag/ingestion.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
| 2 |
+
from langchain_core.documents import Document
|
| 3 |
+
|
| 4 |
+
def chunk_document_text(raw_text: str):
|
| 5 |
+
"""
|
| 6 |
+
Simulates taking a massive document and chunking it for a Vector Store.
|
| 7 |
+
"""
|
| 8 |
+
print(f"Original Document Length: {len(raw_text)} characters")
|
| 9 |
+
|
| 10 |
+
# THE CHUNKER CONFIGURATION
|
| 11 |
+
text_splitter = RecursiveCharacterTextSplitter(
|
| 12 |
+
chunk_size=100, # The maximum size of each chunk
|
| 13 |
+
chunk_overlap=20, # How much the chunks should overlap
|
| 14 |
+
length_function=len,
|
| 15 |
+
separators=["\n\n", "\n", " ", ""] # Tries to split at paragraphs first, then sentences
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
# Create a LangChain Document object
|
| 19 |
+
doc = Document(page_content=raw_text, metadata={"source": "engineering_manual.pdf"})
|
| 20 |
+
|
| 21 |
+
# Execute the split
|
| 22 |
+
chunks = text_splitter.split_documents([doc])
|
| 23 |
+
|
| 24 |
+
print(f"\nCreated {len(chunks)} chunks.")
|
| 25 |
+
|
| 26 |
+
# Let's inspect the exact output to understand the data structure
|
| 27 |
+
for i, chunk in enumerate(chunks):
|
| 28 |
+
print(f"\n--- Chunk {i+1} ---")
|
| 29 |
+
print(chunk.page_content)
|
| 30 |
+
|
| 31 |
+
return chunks
|
| 32 |
+
# Let's test it with a sample "manual"
|
| 33 |
+
if __name__ == "__main__":
|
| 34 |
+
sample_manual = (
|
| 35 |
+
"OmniRouter is an advanced asynchronous LLM routing engine. "
|
| 36 |
+
"It is designed to handle multiple providers gracefully. "
|
| 37 |
+
"If the primary provider fails, the system initiates a failover protocol. "
|
| 38 |
+
"This ensures maximum uptime for production systems."
|
| 39 |
+
)
|
| 40 |
+
|
| 41 |
+
chunk_document_text(sample_manual)
|
src/rag/vector_store.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from typing import List
|
| 3 |
+
from langchain_core.documents import Document
|
| 4 |
+
from langchain_community.vectorstores import Chroma
|
| 5 |
+
|
| 6 |
+
# OpenAI embedding
|
| 7 |
+
# from langchain_openai import OpenAIEmbeddings
|
| 8 |
+
|
| 9 |
+
# Free local embedding
|
| 10 |
+
from langchain_huggingface import HuggingFaceEmbeddings
|
| 11 |
+
|
| 12 |
+
from dotenv import load_dotenv
|
| 13 |
+
load_dotenv()
|
| 14 |
+
|
| 15 |
+
# # Huggingface api key...
|
| 16 |
+
# os.environ["HF_TOKEN"] = "hf_PWDT"
|
| 17 |
+
|
| 18 |
+
# This is where our local database will be saved on your hard drive
|
| 19 |
+
DB_DIRECTORY = "./chroma_db"
|
| 20 |
+
|
| 21 |
+
def get_embeddings_model():
|
| 22 |
+
"""Returns the active embedding model."""
|
| 23 |
+
# --- FREE PIPELINE ---
|
| 24 |
+
# This downloads a small, highly efficient open-source model to your machine.
|
| 25 |
+
print("Loading HuggingFace Embeddings...")
|
| 26 |
+
# api_key = os.getenv("HUGGINGFACE_API_KEY")
|
| 27 |
+
return HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
|
| 28 |
+
|
| 29 |
+
# --- PAID PIPELINE (Uncomment when you have credits) ---
|
| 30 |
+
# We use OpenAI's embedding model here. It converts text to 1536-dimensional vectors.
|
| 31 |
+
# api_key = os.getenv("OPENAI_API_KEY")
|
| 32 |
+
# return OpenAIEmbeddings(api_key=api_key, model="text-embedding-3-small")
|
| 33 |
+
|
| 34 |
+
def build_vector_store(chunks: List[Document], api_key: str):
|
| 35 |
+
"""
|
| 36 |
+
Takes a list of chunked documents, embeds them, and saves them to a local Chroma database.
|
| 37 |
+
"""
|
| 38 |
+
embeddings = get_embeddings_model()
|
| 39 |
+
|
| 40 |
+
print(f"Embedding {len(chunks)} chunks and saving to {DB_DIRECTORY}...")
|
| 41 |
+
|
| 42 |
+
# 1. Create the database
|
| 43 |
+
# 2. Embed all the chunks
|
| 44 |
+
# 3. Save it to the DB_DIRECTORY
|
| 45 |
+
vector_store = Chroma.from_documents(
|
| 46 |
+
documents=chunks,
|
| 47 |
+
embedding=embeddings,
|
| 48 |
+
persist_directory=DB_DIRECTORY
|
| 49 |
+
)
|
| 50 |
+
|
| 51 |
+
# Force the database to save to disk
|
| 52 |
+
vector_store.persist()
|
| 53 |
+
print("Database successfully built and saved to disk!")
|
| 54 |
+
return vector_store
|
| 55 |
+
|
| 56 |
+
def get_vector_store(api_key: str):
|
| 57 |
+
"""
|
| 58 |
+
Retrieves the existing database from the hard drive so we don't have to rebuild it every time.
|
| 59 |
+
"""
|
| 60 |
+
embeddings = get_embeddings_model()
|
| 61 |
+
return Chroma(persist_directory=DB_DIRECTORY, embedding_function=embeddings)
|
src/router.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
import logging
|
| 3 |
+
from typing import Dict, Any
|
| 4 |
+
|
| 5 |
+
# Import our schemas and providers
|
| 6 |
+
from src.schemas import RouterConfig, LLMResponse
|
| 7 |
+
from src.providers.base import BaseLLMProvider
|
| 8 |
+
from src.providers.openai_client import OpenAIProvider
|
| 9 |
+
from src.providers.anthropic_client import AnthropicProvider
|
| 10 |
+
|
| 11 |
+
# Set up logging so we can see the retries happening in the terminal
|
| 12 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
| 13 |
+
logger = logging.getLogger(__name__)
|
| 14 |
+
|
| 15 |
+
class OmniRouter:
|
| 16 |
+
"""
|
| 17 |
+
The central routing engine.
|
| 18 |
+
Handles provider selection, retries, and error management.
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
def __init__(self, api_keys: Dict[str, str]):
|
| 22 |
+
"""
|
| 23 |
+
We initialize the router with a dictionary of API keys.
|
| 24 |
+
We then map string names (like 'openai') to their concrete class instances.
|
| 25 |
+
"""
|
| 26 |
+
self.providers: Dict[str, BaseLLMProvider] = {}
|
| 27 |
+
|
| 28 |
+
# If the user passed an OpenAI key, activate the OpenAI provider
|
| 29 |
+
if "openai" in api_keys:
|
| 30 |
+
self.providers["openai"] = OpenAIProvider(api_key=api_keys["openai"])
|
| 31 |
+
|
| 32 |
+
#---Register Anthropic
|
| 33 |
+
if "anthropic" in api_keys:
|
| 34 |
+
self.providers["anthropic"] = AnthropicProvider(api_key=api_keys["anthropic"])
|
| 35 |
+
# We will add others here later
|
| 36 |
+
|
| 37 |
+
async def generate(self, prompt: str, config: RouterConfig) -> LLMResponse:
|
| 38 |
+
"""
|
| 39 |
+
The main entry point. Routes the prompt to the correct provider with retries.
|
| 40 |
+
"""
|
| 41 |
+
# 1. Check if the requested provider actually exists in our dictionary
|
| 42 |
+
provider = self.providers.get(config.provider)
|
| 43 |
+
if not provider:
|
| 44 |
+
raise ValueError(f"Provider '{config.provider}' is not configured.")
|
| 45 |
+
|
| 46 |
+
last_exception = None
|
| 47 |
+
|
| 48 |
+
# 2. PRIMARY RETRY LOOP
|
| 49 |
+
for attempt in range(config.max_retries):
|
| 50 |
+
try:
|
| 51 |
+
# If this is a retry, log it
|
| 52 |
+
if attempt > 0:
|
| 53 |
+
logger.info(f"[{config.provider}] Retrying... Attempt {attempt + 1} of {config.max_retries}")
|
| 54 |
+
|
| 55 |
+
# 3. The actual API call to whatever provider is currently selected
|
| 56 |
+
response = await provider.async_generate(prompt, config)
|
| 57 |
+
return response
|
| 58 |
+
|
| 59 |
+
except Exception as e:
|
| 60 |
+
# If the API crashes, we catch it here instead of crashing the app
|
| 61 |
+
logger.warning(f"[{config.provider}] Attempt {attempt + 1} failed with error: {str(e)}")
|
| 62 |
+
last_exception = e
|
| 63 |
+
|
| 64 |
+
# 4. EXPONENTIAL BACKOFF
|
| 65 |
+
# Wait 2^attempt seconds (1s, 2s, 4s, 8s...) before trying again
|
| 66 |
+
wait_time = 2 ** attempt
|
| 67 |
+
logger.info(f"Waiting {wait_time} seconds before next attempt...")
|
| 68 |
+
await asyncio.sleep(wait_time)
|
| 69 |
+
|
| 70 |
+
# 2. FAILOVER LOGIC (The Holy Grail)
|
| 71 |
+
if config.fallback_provider:
|
| 72 |
+
logger.error(f"🚨 Primary provider '{config.provider}' exhausted all retries. Initiating FAILOVER to '{config.fallback_provider}'...")
|
| 73 |
+
|
| 74 |
+
# Create a new config for the fallback provider
|
| 75 |
+
fallback_config = RouterConfig(
|
| 76 |
+
provider=config.fallback_provider,
|
| 77 |
+
model=config.fallback_model or config.model, # Use specific fallback model if provided
|
| 78 |
+
temperature=config.temperature,
|
| 79 |
+
max_retries=config.max_retries
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
# Recursively call generate with the new config!
|
| 83 |
+
return await self.generate(prompt, fallback_config)
|
| 84 |
+
|
| 85 |
+
# 5. If we loop through all max_retries and still fail, crash gracefully
|
| 86 |
+
logger.error(f"All {config.max_retries} attempts failed and no fallback configured.")
|
| 87 |
+
raise last_exception
|
src/schemas.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
In LLMs, you are querying a probabilistic text engine.
|
| 3 |
+
If you ask it for an age, it might give you 25, or it might give you "Twenty-five", or it might say "Based on the data, the user is 25 years old."
|
| 4 |
+
|
| 5 |
+
If your system expects 25 but gets a whole sentence, your code crashes in production.
|
| 6 |
+
|
| 7 |
+
**Schemas act as the bouncers at the door of your application. We use Pydantic to define the exact shape of the data we expect.**
|
| 8 |
+
"""
|
| 9 |
+
from pydantic import BaseModel, Field
|
| 10 |
+
from typing import Dict, Any, Optional
|
| 11 |
+
|
| 12 |
+
# 1. THE CONFIGURATION SCHEMA
|
| 13 |
+
# This dictates how our router behaves. Notice how we set smart defaults.
|
| 14 |
+
class RouterConfig(BaseModel):
|
| 15 |
+
"""Configuration for how the OmniRouter should route the request."""
|
| 16 |
+
provider: str = Field(
|
| 17 |
+
default="openai",
|
| 18 |
+
description="The LLM provider to use (e.g., 'openai', 'anthropic', 'local')"
|
| 19 |
+
)
|
| 20 |
+
model: str = Field(
|
| 21 |
+
default="gpt-4-turbo",
|
| 22 |
+
description="The specific model string to use"
|
| 23 |
+
)
|
| 24 |
+
# Defensive Engineering: An LLM temperature cannot be less than 0 or greater than 2.
|
| 25 |
+
# ge = greater than or equal to, le = less than or equal to.
|
| 26 |
+
temperature: float = Field(
|
| 27 |
+
default=0.7,
|
| 28 |
+
ge=0.0,
|
| 29 |
+
le=2.0,
|
| 30 |
+
description="Creativity score for the model"
|
| 31 |
+
)
|
| 32 |
+
max_retries: int = Field(
|
| 33 |
+
default=3,
|
| 34 |
+
description="How many times to retry on API failure or rate limit"
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
# --- NEW CAPABILITY --- Added
|
| 38 |
+
fallback_provider: Optional[str] = Field(
|
| 39 |
+
default=None,
|
| 40 |
+
description="If the primary provider completely fails, switch to this one"
|
| 41 |
+
)
|
| 42 |
+
fallback_model: Optional[str] = Field(
|
| 43 |
+
default=None,
|
| 44 |
+
description="The model to use for the fallback provider"
|
| 45 |
+
)
|
| 46 |
+
|
| 47 |
+
# 2. THE STANDARDIZED OUTPUT SCHEMA
|
| 48 |
+
# This solves the main pain point from our README.
|
| 49 |
+
# Whether OpenAI or Anthropic answers, the rest of our app gets THIS exact object.
|
| 50 |
+
class LLMResponse(BaseModel):
|
| 51 |
+
"""The standardized output format returned from ANY provider."""
|
| 52 |
+
content: str = Field(description="The actual text response generated by the LLM")
|
| 53 |
+
provider_used: str = Field(description="Which provider actually generated this response")
|
| 54 |
+
model_used: str = Field(description="The specific model used")
|
| 55 |
+
|
| 56 |
+
# We track tokens heavily for cost optimization (Week 14 concept)
|
| 57 |
+
prompt_tokens: int = Field(default=0, description="Tokens used in the prompt")
|
| 58 |
+
completion_tokens: int = Field(default=0, description="Tokens used in the completion")
|
| 59 |
+
|
| 60 |
+
# We will calculate this automatically later
|
| 61 |
+
cost_estimate: float = Field(default=0.0, description="Estimated cost of this call in USD")
|