sumitrwk commited on
Commit
b534a53
·
verified ·
1 Parent(s): 4f530ea

Upload 33 files

Browse files
Dockerfile ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 1. The Foundation
2
+ # We start with a lightweight, official Linux image with Python 3.12 pre-installed.
3
+ FROM python:3.12-slim
4
+
5
+ # 2. Environment variables
6
+ # Prevent python from writing messy .pyc files
7
+ ENV PYTHONDONTWRITEBYTECODE=1
8
+ # Ensure our terminal print() statements show up immediately in cloud logs
9
+ ENV PYTHONUNBUFFERED=1
10
+ # Tell HuggingFaceexactly where to save its 100MB math model
11
+ ENV HF_HOME=/app/.cache/huggingface
12
+
13
+ # 3. The Workspace
14
+ # Create a folder inside the container called /app and move inside it
15
+ WORKDIR /app
16
+
17
+ # 4. Cache optimization (the Architect's Trick)
18
+ # We ONLY copy the requirements file first.
19
+ # Docker caches steps. If you change your Python code later, Docker won't
20
+ # force you to sit through a 5-minute re-installation of Pandas and LangChain!
21
+ COPY requirements.txt .
22
+
23
+ # Install the Python packages
24
+ RUN pip install --no-cache-dir -r requirements.txt
25
+
26
+ # 5. COPY the payload
27
+ # Now we copy the rest of your actual code into the container
28
+ COPY . .
29
+
30
+ # 6. OPEN the gate
31
+ # Tell the container to allow trafic on port 8000
32
+ EXPOSE 8000
33
+
34
+ # 7. The ignition switch
35
+ # The exact terminal command the container runs when it wakes up in the cloud.
36
+ # Notice we use 0.0.0.0 so the cloud provider's router can find it.
37
+ CMD ["uvicorn", "src.api.server:app", "--host", "0.0.0.0", "--port", "7860"]
requirements.txt ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ pydantic
2
+ openai
3
+ anthropic
4
+ langchain-openai
5
+
6
+ langchain
7
+ chromadb
8
+ tiktoken
9
+ langchain-text-splitters
10
+ langchain-core
11
+ langchain-community
12
+ langgraph
13
+
14
+ setuptools
15
+
16
+ # Free tier...
17
+ langchain-groq
18
+ langchain-huggingface
19
+ sentence-transformers
20
+
21
+ python-dotenv
22
+
23
+ fastapi
24
+ uvicorn
25
+
26
+ requests
seed_db.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from src.rag.vector_store import build_vector_store
2
+ from langchain_core.documents import Document
3
+ import os
4
+
5
+ api_key = os.getenv("HF_TOKEN")
6
+
7
+ def seed_database():
8
+ print("Seeding new HuggingFace database...")
9
+
10
+ # 1. Our dummy text
11
+ sample_text = (
12
+ "OmniRouter is an enterprise-grade AI architecture combining high-concurrency "
13
+ "LLM routing and local Vector Database retrieval. If the primary API fails, "
14
+ "it seamlessly switches to a fallback model. It uses LangGraph for agentic reasoning."
15
+ )
16
+
17
+ # 2. Package it as a chunk
18
+ doc = Document(page_content=sample_text, metadata={"source": "manual.pdf"})
19
+
20
+ # 3. Build and save the DB
21
+ build_vector_store([doc], api_key=api_key)
22
+
23
+ if __name__ == "__main__":
24
+ seed_database()
src/__init__.py ADDED
File without changes
src/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (126 Bytes). View file
 
src/__pycache__/router.cpython-312.pyc ADDED
Binary file (4.15 kB). View file
 
src/__pycache__/schemas.cpython-312.pyc ADDED
Binary file (2.75 kB). View file
 
src/agent/__pycache__/graph.cpython-312.pyc ADDED
Binary file (3.42 kB). View file
 
src/agent/__pycache__/tools.cpython-312.pyc ADDED
Binary file (2.56 kB). View file
 
src/agent/graph.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Annotated, TypedDict
3
+ # # OpenAI...
4
+ # from langchain_openai import ChatOpenAI
5
+ # Groq LLM...
6
+ from langchain_groq import ChatGroq
7
+ from langchain_core.messages import BaseMessage
8
+ from langgraph.graph import StateGraph, START, END
9
+ from langgraph.graph.message import add_messages
10
+ from langgraph.prebuilt import ToolNode, tools_condition
11
+ from src.agent.tools import search_documentation
12
+ # Import our custom tool
13
+ from src.agent.tools import search_documentation
14
+ # Fix the infinite loop using system prompt(Guardrail)
15
+ from langchain_core.messages import SystemMessage
16
+ # Human in the loop using LangGraphs checkpointers...
17
+ from langgraph.checkpoint.memory import MemorySaver
18
+
19
+
20
+ from dotenv import load_dotenv
21
+ load_dotenv()
22
+
23
+
24
+ # 1. UPGRADED STATE
25
+ class AgentState(TypedDict):
26
+ # 'add_messages' ensures we append to the history, never overwrite it.
27
+ messages: Annotated[list[BaseMessage], add_messages]
28
+
29
+ # 2. INITIALIZE THE BRAIN
30
+ # We instantiate the LLM and "bind" our tool to it.
31
+
32
+ # # Make sure you export GROQ_API_KEY in your terminal before running!
33
+ # os.environ["GROQ_API_KEY"] = "gsk_jd"
34
+ # We use Meta's Llama 3 8B model hosted on Groq for incredible speed
35
+ llm = ChatGroq(model="llama-3.1-8b-instant", temperature=0)
36
+
37
+ # # This sends the JSON schema we looked at yesterday directly to OpenAI.
38
+ # llm = ChatOpenAI(model="gpt-4-turbo", temperature=0)
39
+
40
+ tools = [search_documentation]
41
+ llm_with_tools = llm.bind_tools(tools)
42
+
43
+ # 3. THE NODES
44
+ """
45
+ Chat node for seamless interaction with the LLM model...
46
+ """
47
+ def chatbot_node(state: AgentState):
48
+ """
49
+ This node intercepts the history, adds strict behavioral rules,
50
+ and then passes it to the LLM.
51
+
52
+ The LLM will either return a standard text message, OR a special "ToolCall" message.
53
+ """
54
+ print("\n--- [NODE: Chatbot] Thinking... ---")
55
+
56
+ # 1. The Circuit Breaker System Prompt --> Prompt based Flow control
57
+ system_message = SystemMessage(content=(
58
+ "You are an elite AI Engineering assistant. "
59
+ "You have access to a search_documentation tool. "
60
+ "CRITICAL RULE: If you use the tool and it returns 'No relevant information found', "
61
+ "you MUST NOT use the tool again. Immediately stop and tell the user "
62
+ "'I do not have enough information to answer that based on the documentation.' "
63
+ "Do not guess. Do not hallucinate."
64
+ ))
65
+
66
+ # 2. Prepend the system rules to the chat history
67
+ messages_to_send = [system_message] + state["messages"]
68
+
69
+ # 3. Invoke the LLM with the strict rules applied
70
+ response = llm_with_tools.invoke(messages_to_send)
71
+
72
+ # We return the message wrapped in a list to trigger the 'add_messages' append behavior
73
+ return {"messages": [response]}
74
+
75
+ # LangGraph has a built-in node specifically for executing tools!
76
+ # It reads the "ToolCall" message, runs our Python function, and returns a "ToolMessage".
77
+ tool_node = ToolNode(tools=tools)
78
+
79
+ # =============================================
80
+ # 4. COMPILE THE GRAPH with human in the loop...
81
+ # =============================================
82
+ workflow = StateGraph(AgentState)
83
+
84
+ # Add our two worker nodes
85
+ workflow.add_node("chatbot", chatbot_node)
86
+ workflow.add_node("tools", tool_node)
87
+
88
+ # Set the entry point
89
+ workflow.add_edge(START, "chatbot")
90
+ # 5. THE MAGIC ROUTING
91
+ # 'tools_condition' is a built-in LangGraph edge.
92
+ # It looks at the last message from the chatbot.
93
+ # If it has a tool call, it routes to "tools". If it's just text, it routes to END.
94
+ workflow.add_conditional_edges("chatbot", tools_condition)
95
+ # After a tool finishes running, ALWAYS loop back to the chatbot
96
+ # so it can read the database results and formulate a final answer!
97
+ workflow.add_edge("tools", "chatbot")
98
+
99
+ # Initialize the short-term memory vault
100
+ memory = MemorySaver()
101
+
102
+ # Compile with the memory and the breakpoint!
103
+ app = workflow.compile(
104
+ checkpointer=memory,
105
+ interrupt_before=["tools"] # tell the graph to pause before executing this node.
106
+ )
107
+
108
+
109
+ from langchain_core.messages import HumanMessage
110
+
111
+ if __name__ == "__main__":
112
+ # Ensure your API key is available
113
+ # os.environ["OPENAI_API_KEY"] = "YOUR_REAL_OPENAI_API_KEY"
114
+
115
+ print("========== AGENT TEST ==========")
116
+ initial_state = {
117
+ "messages": [HumanMessage(content="What does OmniRouter do?")]
118
+ }
119
+
120
+ # stream() allows us to see the exact output of each node as it executes!
121
+ for event in app.stream(initial_state):
122
+ for node_name, node_state in event.items():
123
+ print(f"Update from node '{node_name}':")
124
+ # Print the content of the very last message added to the state
125
+ print(f" -> {node_state['messages'][-1].content}\n")
src/agent/tools.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Let's officially give your agent its brain.
3
+
4
+ We are going to use LangChain's @tool decorator.
5
+ This magical little wrapper takes your standard Python function,
6
+ reads the type hints (like query: str), reads the docstring, and automatically
7
+ translates the entire thing into a strict JSON schema that OpenAI and Anthropic natively understand .
8
+ """
9
+ import os
10
+ from langchain_core.tools import tool
11
+ from src.rag.vector_store import get_vector_store
12
+
13
+ # The @tool decorator converts this Python function into an LLM-readable JSON schema
14
+ @tool
15
+ def search_documentation(query: str) -> str:
16
+ """
17
+ Searches the internal engineering documentation for information about the OmniRouter,
18
+ LangChain, LangGraph, or general AI engineering concepts.
19
+
20
+ Use this tool WHENEVER the user asks a technical question about how the system works,
21
+ fallback protocols, or specific coding architecture. Do NOT use this for general greetings.
22
+
23
+ Args:
24
+ query: The specific search term to look up in the database.
25
+ It should be a standalone, highly descriptive phrase.
26
+ """
27
+ print(f"\n--- [TOOL EXECUTION] Searching Vector DB for: '{query}' ---")
28
+
29
+ # In production, ensure your API key is loaded securely
30
+ api_key = os.getenv("HF_TOKEN")
31
+
32
+ try:
33
+ db = get_vector_store(api_key)
34
+
35
+ # Retrieve the top 2 most relevant chunks
36
+ results = db.similarity_search(query, k=2)
37
+
38
+ if not results:
39
+ return "No relevant information found in the documentation."
40
+
41
+ # We must return a STRING, not a list of objects, so the LLM can read it easily
42
+ combined_text = "\n\n".join([doc.page_content for doc in results])
43
+ return combined_text
44
+
45
+ except Exception as e:
46
+ # FIXED: Print the error to the server terminal so we can see it!
47
+ print(f"\n🚨 [TOOL CRASHED]: {str(e)}")
48
+ return f"Error executing search: {str(e)}"
49
+
50
+
51
+ if __name__ == "__main__":
52
+ # Print the name the LLM sees
53
+ print(f"Tool Name: {search_documentation.name}")
54
+
55
+ # Print the description the LLM reads to make its decision
56
+ print(f"\nTool Description: \n{search_documentation.description}")
57
+
58
+ # Print the strict JSON schema the LLM must follow to use the tool
59
+ print(f"\nTool Arguments Schema: \n{search_documentation.args}")
src/api/__pycache__/cache.cpython-312.pyc ADDED
Binary file (2.05 kB). View file
 
src/api/__pycache__/server.cpython-312.pyc ADDED
Binary file (2.72 kB). View file
 
src/api/cache.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_huggingface import HuggingFaceEmbeddings
2
+ from langchain_community.vectorstores import Chroma
3
+ from langchain_core.documents import Document
4
+
5
+ CACHE_DIR = "./semantic_cache_db"
6
+
7
+ def get_cache_db():
8
+ """Initializes the Cache Database using free local embeddings."""
9
+ embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
10
+ return Chroma(persist_directory=CACHE_DIR, embedding_function=embeddings)
11
+
12
+ def check_cache(query: str, threshold: float = 0.5) -> str | None:
13
+ """
14
+ Embeds the user's question and mathematically checks if anyone
15
+ has asked a highly similar question before.
16
+ """
17
+ db = get_cache_db()
18
+
19
+ # We search the cache and ask for the 'distance score'
20
+ results = db.similarity_search_with_score(query, k=1)
21
+
22
+ if results:
23
+ doc, score = results[0]
24
+ # In ChromaDB's default math (L2 distance), a LOWER score means it's MORE similar.
25
+ # 0.0 is an exact match. 0.5 means "very similar meaning".
26
+ if score < threshold:
27
+ print(f"\n🟢 [CACHE HIT] Similar question found! (Score: {score:.3f})")
28
+ return doc.metadata.get("answer")
29
+
30
+ print("\n🔴 [CACHE MISS] Question is new.")
31
+ return None
32
+
33
+ def save_to_cache(query: str, answer: str):
34
+ """Saves a brand new question and its answer into the database."""
35
+ db = get_cache_db()
36
+
37
+ # The 'page_content' is the question. The 'metadata' holds the answer.
38
+ doc = Document(page_content=query, metadata={"answer": answer})
39
+ db.add_documents([doc])
40
+
41
+ print("\n💾 [CACHE SAVED] New interaction stored for future users.")
src/api/server.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import asyncio
3
+ from fastapi import FastAPI #, HTTPException
4
+ from fastapi.responses import StreamingResponse
5
+ from pydantic import BaseModel
6
+ from typing import List
7
+ from langchain_core.messages import HumanMessage, AIMessage
8
+
9
+ # Import our compiled LangGraph agent
10
+ from src.agent.graph import app as agent_app
11
+ from src.api.cache import check_cache, save_to_cache
12
+
13
+ # 1. Initialize the FastAPI Server
14
+ app = FastAPI(
15
+ title="OmniRouter Streaming API Agent",
16
+ description="Enterprise RAG Agent powered by LangGraph and FastAPI",
17
+ version="1.0.1"
18
+ )
19
+
20
+ # 2. Define our Request Schema using Pydantic
21
+ class ChatRequest(BaseModel):
22
+ query: str
23
+
24
+ async def stream_generator(query: str):
25
+ # ==========================================
26
+ # 1. THE CACHE LAYER (Lightning Fast)
27
+ # ==========================================
28
+ cached_answer = check_cache(query)
29
+
30
+ if cached_answer:
31
+ # We chop the cached string into words and stream them instantly
32
+ for word in cached_answer.split(" "):
33
+ yield f"data: {json.dumps({'token': word + ' '})}\n\n"
34
+ # We add a tiny 20ms sleep just to preserve the "typewriter" feel for the user
35
+ await asyncio.sleep(0.02)
36
+ return # EXIT EARLY! The LLM is never triggered.
37
+
38
+ """
39
+ An async generator that yields tokens from the LangGraph agent
40
+ in a format compatible with Server-Sent Events (SSE).
41
+ """
42
+ # ==========================================
43
+ # 2. THE AGENT LAYER (Heavy Compute)
44
+ # ==========================================
45
+ initial_state = {"messages": [HumanMessage(content=query)]}
46
+ full_answer = "" # We need to collect the tokens to save them later
47
+
48
+ # .astream_events is the key to deep-access streaming in LangChain/LangGraph
49
+ async for event in agent_app.astream_events(initial_state, version="v1"):
50
+ kind = event["event"]
51
+
52
+ # We are looking for the 'on_chat_model_stream' event
53
+ # This triggers every time a new token is generated by the LLM
54
+ if kind == "on_chat_model_stream":
55
+ content = event["data"]["chunk"].content
56
+ if content:
57
+ full_answer += content
58
+ # SSE format requires the "data: " prefix
59
+ yield f"data: {json.dumps({'token': content})}\n\n"
60
+
61
+ # ==========================================
62
+ # 3. SAVE FOR THE FUTURE
63
+ # ==========================================
64
+ # Only cache if we got an answer, AND the answer isn't our fallback failure phrase
65
+ failure_phrase = "I do not have enough information"
66
+
67
+ if full_answer and failure_phrase not in full_answer:
68
+ save_to_cache(query, full_answer)
69
+ else:
70
+ print("\n⚠️ [CACHE SKIP] Agent failed to answer. Did not poison the cache.")
71
+
72
+ @app.post("/chat/stream")
73
+ async def chat_streaming_endpoint(request: ChatRequest):
74
+ return StreamingResponse(
75
+ stream_generator(request.query),
76
+ media_type="text/event-stream"
77
+ )
src/evaluation/judge.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ We are going to create a strict grading script using LangChain's
3
+ **with_structured_output**. This forces our Judge LLM to return a strict JSON
4
+ object containing an integer score (1 for Pass, 0 for Fail) and a reasoning string.
5
+ """
6
+ from pydantic import BaseModel, Field
7
+ from langchain_groq import ChatGroq
8
+ from langchain_core.prompts import ChatPromptTemplate
9
+
10
+ from dotenv import load_dotenv
11
+ load_dotenv()
12
+
13
+ # ==========================================
14
+ # 1. The Strict Grading Schema
15
+ # ==========================================
16
+ class HallucinationScore(BaseModel):
17
+ score: int = Field(description="Return 1 if perfectly grounded. Return 0 if hallucinated.")
18
+ reasoning: str = Field(description="A 1-sentence explanation of why you gave this score.")
19
+
20
+ # ==========================================
21
+ # 2. Initialize the Impartial Judge
22
+ # ==========================================
23
+ # We use temperature=0 because we want strict, deterministic grading, not creativity!
24
+ model_name_1 = "llama-3.1-70b-versatile"
25
+ model_name_2 = "llama-3.1-8b-instant"
26
+
27
+ judge_llm = ChatGroq(model=model_name_2, temperature=0)
28
+ structured_judge = judge_llm.with_structured_output(HallucinationScore)
29
+
30
+ # ==========================================
31
+ # 3. The Grading Rubric (System Prompt)
32
+ # ==========================================
33
+ system_prompt = """You are an impartial AI Compliance Judge evaluating an Agent's response.
34
+ You will be given the 'Retrieved Context' from the database, and the 'Agent Answer'.
35
+ Your ONLY job is to check for HALLUCINATIONS.
36
+
37
+ RULE:
38
+ - If the Agent's answer contains ANY factual information, names, or numbers that are NOT present in the Retrieved Context, score it a 0.
39
+ - If the Agent's answer is strictly based ONLY on the context, score it a 1.
40
+ - Do not grade grammar or tone. Only grade factual grounding.
41
+ """
42
+
43
+ prompt = ChatPromptTemplate.from_messages([
44
+ ("system", system_prompt),
45
+ ("human", "Retrieved Context: \n\n {context} \n\n Agent Answer: \n\n {answer}")
46
+ ])
47
+
48
+ evaluator = prompt | structured_judge
49
+
50
+ def check_hallucination(context: str, answer: str):
51
+ print("\n⚖️ [JUDGE] Evaluating answer for hallucinations...")
52
+ try:
53
+ result = evaluator.invoke({"context": context, "answer": answer})
54
+ return result
55
+ except Exception as e:
56
+ print(f"Judge Error: {e}")
57
+ return None
58
+
59
+
60
+ if __name__ == "__main__":
61
+ # The reality: What our Vector DB actually found.
62
+ simulated_context = (
63
+ "OmniRouter is an AI architecture that routes LLM requests. "
64
+ "It supports OpenAI and Anthropic APIs."
65
+ )
66
+
67
+ print("\n========== TEST 1: The Good Agent ==========")
68
+ good_answer = "OmniRouter routes requests and works with Anthropic and OpenAI."
69
+ good_result = check_hallucination(simulated_context, good_answer)
70
+ print(f"Score: {good_result.score}/1")
71
+ print(f"Reasoning: {good_result.reasoning}")
72
+
73
+ print("\n========== TEST 2: The Hallucinating Agent ==========")
74
+ bad_answer = "OmniRouter routes requests and works with OpenAI, Anthropic, and Google Gemini."
75
+ bad_result = check_hallucination(simulated_context, bad_answer)
76
+ print(f"Score: {bad_result.score}/1")
77
+ print(f"Reasoning: {bad_result.reasoning}")
src/evaluation/run_evals.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_core.messages import HumanMessage, ToolMessage, AIMessage
2
+
3
+ # Import your real compiled agent and your real judge
4
+ from src.agent.graph import app
5
+ from src.evaluation.judge import check_hallucination
6
+
7
+ def evaluate_real_agent(query: str):
8
+ print(f"\n==================================================")
9
+ print(f"🚀 RUNNING REAL AGENT EVALUATION")
10
+ print(f"Query: '{query}'")
11
+ print(f"==================================================")
12
+
13
+ # 1. Trigger the Real Agent
14
+ initial_state = {"messages": [HumanMessage(content=query)]}
15
+ config = {"configurable": {"thread_id": "automated_eval_run_1"}}
16
+
17
+ print("\n🤖 Agent is thinking and searching...")
18
+ # We use .invoke() here because we don't need streaming for a backend test
19
+ final_state = app.invoke(initial_state, config)
20
+
21
+ # 2. Extract the Dynamic Data from the State Machine's Memory
22
+ retrieved_context = ""
23
+ final_answer = ""
24
+
25
+ for msg in final_state["messages"]:
26
+ # Find the exact text the ChromaDB tool returned
27
+ if isinstance(msg, ToolMessage):
28
+ retrieved_context += msg.content + "\n"
29
+ # Find the final answer the Agent generated
30
+ elif isinstance(msg, AIMessage) and msg.content:
31
+ final_answer = msg.content
32
+
33
+ if not retrieved_context:
34
+ print("⚠️ Agent did not use the database. Cannot run Hallucination check.")
35
+ return
36
+
37
+ # 3. Pass the dynamic data to the Judge
38
+ result = check_hallucination(context=retrieved_context, answer=final_answer)
39
+
40
+ # 4. Print the final Evaluation Report
41
+ print(f"\n📊 EVALUATION REPORT")
42
+ print(f"Score: {result.score} / 1")
43
+ if result.score == 1:
44
+ print("✅ PASS: Answer is completely grounded in the database.")
45
+ else:
46
+ print("❌ FAIL: Hallucination detected!")
47
+
48
+ print(f"Judge's Reasoning: {result.reasoning}")
49
+ print(f"==================================================\n")
50
+
51
+ if __name__ == "__main__":
52
+ # Test our agent with a real query!
53
+ evaluate_real_agent("What is OmniRouter and what does it do?")
src/providers/__init__.py ADDED
File without changes
src/providers/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (136 Bytes). View file
 
src/providers/__pycache__/anthropic_client.cpython-312.pyc ADDED
Binary file (2.92 kB). View file
 
src/providers/__pycache__/base.cpython-312.pyc ADDED
Binary file (1.79 kB). View file
 
src/providers/__pycache__/openai_client.cpython-312.pyc ADDED
Binary file (3.23 kB). View file
 
src/providers/anthropic_client.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from anthropic import AsyncAnthropic
2
+ import logging
3
+
4
+ from src.schemas import RouterConfig, LLMResponse
5
+ from src.providers.base import BaseLLMProvider
6
+
7
+ logger = logging.getLogger(__name__)
8
+
9
+ class AnthropicProvider(BaseLLMProvider):
10
+ def __init__(self, api_key: str):
11
+ super().__init__(api_key)
12
+ self.client = AsyncAnthropic(api_key=self.api_key)
13
+
14
+ async def async_generate(self, prompt: str, config: RouterConfig) -> LLMResponse:
15
+ logger.info(f"Routing request to Anthropic using model: {config.model}")
16
+
17
+ # Anthropic's API structure is slightly different from OpenAI's
18
+ response = await self.client.messages.create(
19
+ model=config.model,
20
+ max_tokens=1024, # Anthropic requires max_tokens to be explicitly set
21
+ messages=[{"role": "user", "content": prompt}],
22
+ temperature=config.temperature,
23
+ )
24
+
25
+ content = response.content[0].text
26
+ prompt_tokens = response.usage.input_tokens
27
+ completion_tokens = response.usage.output_tokens
28
+
29
+ cost = self.calculate_cost(prompt_tokens, completion_tokens, config.model)
30
+
31
+ return LLMResponse(
32
+ content=content,
33
+ provider_used="anthropic",
34
+ model_used=config.model,
35
+ prompt_tokens=prompt_tokens,
36
+ completion_tokens=completion_tokens,
37
+ cost_estimate=cost
38
+ )
39
+
40
+ def calculate_cost(self, prompt_tokens: int, completion_tokens: int, model_name: str) -> float:
41
+ pricing = {
42
+ "claude-3-opus-20240229": {"prompt": 15.0, "completion": 75.0},
43
+ "claude-3-5-sonnet-20240620": {"prompt": 3.0, "completion": 15.0},
44
+ "claude-3-haiku-20240307": {"prompt": 0.25, "completion": 1.25}
45
+ }
46
+ rates = pricing.get(model_name, {"prompt": 0.0, "completion": 0.0})
47
+ cost = (prompt_tokens / 1_000_000) * rates["prompt"] + \
48
+ (completion_tokens / 1_000_000) * rates["completion"]
49
+ return round(cost, 6)
src/providers/base.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ The abstract blueprint every provider must follow
3
+ """
4
+ from abc import ABC, abstractmethod
5
+ from typing import Optional
6
+ from src.schemas import RouterConfig, LLMResponse
7
+
8
+ class BaseLLMProvider(ABC):
9
+ """
10
+ The strict blueprint that ALL LLM providers must follow.
11
+ If a developer tries to create a provider without an 'async_generate' method,
12
+ Python will throw a TypeError upon instantiation.
13
+ """
14
+
15
+ def __init__(self, api_key: str):
16
+ # Every provider needs an API key (or a dummy key for local models)
17
+ self.api_key = api_key
18
+
19
+ @abstractmethod
20
+ async def async_generate(
21
+ self,
22
+ prompt: str,
23
+ config: RouterConfig
24
+ ) -> LLMResponse:
25
+ """
26
+ The core engine method.
27
+ Takes a string prompt and our strict RouterConfig.
28
+ MUST return our strictly typed LLMResponse.
29
+ """
30
+ pass
31
+
32
+ @abstractmethod
33
+ def calculate_cost(
34
+ self,
35
+ prompt_tokens: int,
36
+ completion_tokens: int,
37
+ model_name: str
38
+ ) -> float:
39
+ """
40
+ Calculates the estimated cost of the API call.
41
+ Essential for production monitoring.
42
+ """
43
+ pass
src/providers/openai_client.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from openai import AsyncOpenAI
2
+ import logging
3
+
4
+ # Import our strict schemas and base blueprint
5
+ from src.schemas import RouterConfig, LLMResponse
6
+ from src.providers.base import BaseLLMProvider
7
+
8
+ # Set up logging for professional debugging
9
+ logger = logging.getLogger(__name__)
10
+
11
+ class OpenAIProvider(BaseLLMProvider):
12
+ """
13
+ The concrete implementation for OpenAI's API.
14
+ How it strictly fulfills the contract defined in BaseLLMProvider.
15
+ """
16
+
17
+ def __init__(self, api_key: str):
18
+ # Call the parent class initialization
19
+ super().__init__(api_key)
20
+
21
+ # CRITICAL: We initialize the ASYNC client, not the standard synchronous one.
22
+ # This is what allows our router to handle hundreds of concurrent requests.
23
+ self.client = AsyncOpenAI(api_key=self.api_key)
24
+
25
+ async def async_generate(self, prompt: str, config: RouterConfig) -> LLMResponse:
26
+ logger.info(f"Routing request to OpenAI using model: {config.model}")
27
+
28
+ # 1. Execute the Async API Call
29
+ response = await self.client.chat.completions.create(
30
+ model=config.model,
31
+ messages=[{"role": "user", "content": prompt}],
32
+ temperature=config.temperature,
33
+ # We will add top_p, frequency_penalty, etc. later as needed
34
+ )
35
+
36
+ # 2. Extract Data from OpenAI's specific object structure
37
+ content = response.choices[0].message.content
38
+ prompt_tokens = response.usage.prompt_tokens
39
+ completion_tokens = response.usage.completion_tokens
40
+
41
+ # 3. Calculate Cost dynamically
42
+ cost = self.calculate_cost(prompt_tokens, completion_tokens, config.model)
43
+
44
+ # 4. Standardize the Output
45
+ # We transform OpenAI's proprietary response into our universal LLMResponse schema.
46
+ # Now, the rest of our application doesn't need to know anything about OpenAI's specific formatting.
47
+ return LLMResponse(
48
+ content=content,
49
+ provider_used="openai",
50
+ model_used=config.model,
51
+ prompt_tokens=prompt_tokens,
52
+ completion_tokens=completion_tokens,
53
+ cost_estimate=cost
54
+ )
55
+
56
+ def calculate_cost(self, prompt_tokens: int, completion_tokens: int, model_name: str) -> float:
57
+ """
58
+ Calculates the exact cost of the API call based on OpenAI's pricing (per 1M tokens).
59
+ This is a massive value-add for your open-source repository.
60
+ """
61
+ # A dictionary acting as a simple pricing database
62
+ pricing = {
63
+ "gpt-4-turbo": {"prompt": 10.0, "completion": 30.0},
64
+ "gpt-4o": {"prompt": 5.0, "completion": 15.0},
65
+ "gpt-3.5-turbo": {"prompt": 0.5, "completion": 1.5}
66
+ }
67
+
68
+ # If they use a model not in our dict, default to 0.0 to prevent crashes
69
+ rates = pricing.get(model_name, {"prompt": 0.0, "completion": 0.0})
70
+
71
+ # Math: (tokens / 1,000,000) * rate
72
+ cost = (prompt_tokens / 1_000_000) * rates["prompt"] + \
73
+ (completion_tokens / 1_000_000) * rates["completion"]
74
+
75
+ return round(cost, 6)
src/rag/__pycache__/chatbot.cpython-312.pyc ADDED
Binary file (2.14 kB). View file
 
src/rag/__pycache__/ingestion.cpython-312.pyc ADDED
Binary file (1.66 kB). View file
 
src/rag/__pycache__/vector_store.cpython-312.pyc ADDED
Binary file (1.89 kB). View file
 
src/rag/chatbot.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_openai import ChatOpenAI
2
+ from langchain.chains import create_history_aware_retriever, create_retrieval_chain
3
+ from langchain.chains.combine_documents import create_stuff_documents_chain
4
+ # from langchain.chains.history_aware_retriever import create_history_aware_retriever
5
+ # from langchain.chains.retrieval import create_retrieval_chain
6
+ # from langchain.chains.combine_documents import create_stuff_documents_chain
7
+ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
8
+
9
+ from src.rag.vector_store import get_vector_store
10
+
11
+ def build_doc_assistant(api_key: str):
12
+ """
13
+ Constructs the conversational RAG pipeline.
14
+ """
15
+ # 1. Initialize our LLM (temperature=0 because we want factual answers, not creative ones)
16
+ llm = ChatOpenAI(api_key=api_key, model="gpt-3.5-turbo", temperature=0)
17
+
18
+ # 2. Connect to our Vector DB (k=2 means return the top 2 most relevant chunks)
19
+ retriever = get_vector_store(api_key).as_retriever(search_kwargs={"k": 2})
20
+
21
+ # ==========================================
22
+ # STEP 1: The "Question Reformulation" Prompt
23
+ # ==========================================
24
+ contextualize_q_system_prompt = (
25
+ "Given a chat history and the latest user question "
26
+ "which might reference context in the chat history, "
27
+ "formulate a standalone question which can be understood "
28
+ "without the chat history. Do NOT answer the question, "
29
+ "just reformulate it if needed and otherwise return it as is."
30
+ )
31
+ contextualize_q_prompt = ChatPromptTemplate.from_messages([
32
+ ("system", contextualize_q_system_prompt),
33
+ MessagesPlaceholder("chat_history"), # Injects our memory here
34
+ ("human", "{input}"),
35
+ ])
36
+
37
+ # This chain automatically handles rewriting the query before searching
38
+ history_aware_retriever = create_history_aware_retriever(llm, retriever, contextualize_q_prompt)
39
+
40
+ # ==========================================
41
+ # STEP 2: The "Final Answer" Prompt
42
+ # ==========================================
43
+ system_prompt = (
44
+ "You are an elite AI Engineering Assistant. "
45
+ "Use the following pieces of retrieved context to answer the question. "
46
+ "If the answer is not contained in the context, say 'I don't know based on the documentation.' "
47
+ "Do not make up an answer. Keep it concise.\n\n"
48
+ "Context: {context}"
49
+ )
50
+ qa_prompt = ChatPromptTemplate.from_messages([
51
+ ("system", system_prompt),
52
+ MessagesPlaceholder("chat_history"),
53
+ ("human", "{input}"),
54
+ ])
55
+
56
+ # This chain handles injecting the retrieved chunks into the {context} variable
57
+ question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
58
+
59
+ # ==========================================
60
+ # STEP 3: Tie it all together
61
+ # ==========================================
62
+ rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
63
+ return rag_chain
src/rag/ingestion.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
2
+ from langchain_core.documents import Document
3
+
4
+ def chunk_document_text(raw_text: str):
5
+ """
6
+ Simulates taking a massive document and chunking it for a Vector Store.
7
+ """
8
+ print(f"Original Document Length: {len(raw_text)} characters")
9
+
10
+ # THE CHUNKER CONFIGURATION
11
+ text_splitter = RecursiveCharacterTextSplitter(
12
+ chunk_size=100, # The maximum size of each chunk
13
+ chunk_overlap=20, # How much the chunks should overlap
14
+ length_function=len,
15
+ separators=["\n\n", "\n", " ", ""] # Tries to split at paragraphs first, then sentences
16
+ )
17
+
18
+ # Create a LangChain Document object
19
+ doc = Document(page_content=raw_text, metadata={"source": "engineering_manual.pdf"})
20
+
21
+ # Execute the split
22
+ chunks = text_splitter.split_documents([doc])
23
+
24
+ print(f"\nCreated {len(chunks)} chunks.")
25
+
26
+ # Let's inspect the exact output to understand the data structure
27
+ for i, chunk in enumerate(chunks):
28
+ print(f"\n--- Chunk {i+1} ---")
29
+ print(chunk.page_content)
30
+
31
+ return chunks
32
+ # Let's test it with a sample "manual"
33
+ if __name__ == "__main__":
34
+ sample_manual = (
35
+ "OmniRouter is an advanced asynchronous LLM routing engine. "
36
+ "It is designed to handle multiple providers gracefully. "
37
+ "If the primary provider fails, the system initiates a failover protocol. "
38
+ "This ensures maximum uptime for production systems."
39
+ )
40
+
41
+ chunk_document_text(sample_manual)
src/rag/vector_store.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import List
3
+ from langchain_core.documents import Document
4
+ from langchain_community.vectorstores import Chroma
5
+
6
+ # OpenAI embedding
7
+ # from langchain_openai import OpenAIEmbeddings
8
+
9
+ # Free local embedding
10
+ from langchain_huggingface import HuggingFaceEmbeddings
11
+
12
+ from dotenv import load_dotenv
13
+ load_dotenv()
14
+
15
+ # # Huggingface api key...
16
+ # os.environ["HF_TOKEN"] = "hf_PWDT"
17
+
18
+ # This is where our local database will be saved on your hard drive
19
+ DB_DIRECTORY = "./chroma_db"
20
+
21
+ def get_embeddings_model():
22
+ """Returns the active embedding model."""
23
+ # --- FREE PIPELINE ---
24
+ # This downloads a small, highly efficient open-source model to your machine.
25
+ print("Loading HuggingFace Embeddings...")
26
+ # api_key = os.getenv("HUGGINGFACE_API_KEY")
27
+ return HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
28
+
29
+ # --- PAID PIPELINE (Uncomment when you have credits) ---
30
+ # We use OpenAI's embedding model here. It converts text to 1536-dimensional vectors.
31
+ # api_key = os.getenv("OPENAI_API_KEY")
32
+ # return OpenAIEmbeddings(api_key=api_key, model="text-embedding-3-small")
33
+
34
+ def build_vector_store(chunks: List[Document], api_key: str):
35
+ """
36
+ Takes a list of chunked documents, embeds them, and saves them to a local Chroma database.
37
+ """
38
+ embeddings = get_embeddings_model()
39
+
40
+ print(f"Embedding {len(chunks)} chunks and saving to {DB_DIRECTORY}...")
41
+
42
+ # 1. Create the database
43
+ # 2. Embed all the chunks
44
+ # 3. Save it to the DB_DIRECTORY
45
+ vector_store = Chroma.from_documents(
46
+ documents=chunks,
47
+ embedding=embeddings,
48
+ persist_directory=DB_DIRECTORY
49
+ )
50
+
51
+ # Force the database to save to disk
52
+ vector_store.persist()
53
+ print("Database successfully built and saved to disk!")
54
+ return vector_store
55
+
56
+ def get_vector_store(api_key: str):
57
+ """
58
+ Retrieves the existing database from the hard drive so we don't have to rebuild it every time.
59
+ """
60
+ embeddings = get_embeddings_model()
61
+ return Chroma(persist_directory=DB_DIRECTORY, embedding_function=embeddings)
src/router.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import logging
3
+ from typing import Dict, Any
4
+
5
+ # Import our schemas and providers
6
+ from src.schemas import RouterConfig, LLMResponse
7
+ from src.providers.base import BaseLLMProvider
8
+ from src.providers.openai_client import OpenAIProvider
9
+ from src.providers.anthropic_client import AnthropicProvider
10
+
11
+ # Set up logging so we can see the retries happening in the terminal
12
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
13
+ logger = logging.getLogger(__name__)
14
+
15
+ class OmniRouter:
16
+ """
17
+ The central routing engine.
18
+ Handles provider selection, retries, and error management.
19
+ """
20
+
21
+ def __init__(self, api_keys: Dict[str, str]):
22
+ """
23
+ We initialize the router with a dictionary of API keys.
24
+ We then map string names (like 'openai') to their concrete class instances.
25
+ """
26
+ self.providers: Dict[str, BaseLLMProvider] = {}
27
+
28
+ # If the user passed an OpenAI key, activate the OpenAI provider
29
+ if "openai" in api_keys:
30
+ self.providers["openai"] = OpenAIProvider(api_key=api_keys["openai"])
31
+
32
+ #---Register Anthropic
33
+ if "anthropic" in api_keys:
34
+ self.providers["anthropic"] = AnthropicProvider(api_key=api_keys["anthropic"])
35
+ # We will add others here later
36
+
37
+ async def generate(self, prompt: str, config: RouterConfig) -> LLMResponse:
38
+ """
39
+ The main entry point. Routes the prompt to the correct provider with retries.
40
+ """
41
+ # 1. Check if the requested provider actually exists in our dictionary
42
+ provider = self.providers.get(config.provider)
43
+ if not provider:
44
+ raise ValueError(f"Provider '{config.provider}' is not configured.")
45
+
46
+ last_exception = None
47
+
48
+ # 2. PRIMARY RETRY LOOP
49
+ for attempt in range(config.max_retries):
50
+ try:
51
+ # If this is a retry, log it
52
+ if attempt > 0:
53
+ logger.info(f"[{config.provider}] Retrying... Attempt {attempt + 1} of {config.max_retries}")
54
+
55
+ # 3. The actual API call to whatever provider is currently selected
56
+ response = await provider.async_generate(prompt, config)
57
+ return response
58
+
59
+ except Exception as e:
60
+ # If the API crashes, we catch it here instead of crashing the app
61
+ logger.warning(f"[{config.provider}] Attempt {attempt + 1} failed with error: {str(e)}")
62
+ last_exception = e
63
+
64
+ # 4. EXPONENTIAL BACKOFF
65
+ # Wait 2^attempt seconds (1s, 2s, 4s, 8s...) before trying again
66
+ wait_time = 2 ** attempt
67
+ logger.info(f"Waiting {wait_time} seconds before next attempt...")
68
+ await asyncio.sleep(wait_time)
69
+
70
+ # 2. FAILOVER LOGIC (The Holy Grail)
71
+ if config.fallback_provider:
72
+ logger.error(f"🚨 Primary provider '{config.provider}' exhausted all retries. Initiating FAILOVER to '{config.fallback_provider}'...")
73
+
74
+ # Create a new config for the fallback provider
75
+ fallback_config = RouterConfig(
76
+ provider=config.fallback_provider,
77
+ model=config.fallback_model or config.model, # Use specific fallback model if provided
78
+ temperature=config.temperature,
79
+ max_retries=config.max_retries
80
+ )
81
+
82
+ # Recursively call generate with the new config!
83
+ return await self.generate(prompt, fallback_config)
84
+
85
+ # 5. If we loop through all max_retries and still fail, crash gracefully
86
+ logger.error(f"All {config.max_retries} attempts failed and no fallback configured.")
87
+ raise last_exception
src/schemas.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ In LLMs, you are querying a probabilistic text engine.
3
+ If you ask it for an age, it might give you 25, or it might give you "Twenty-five", or it might say "Based on the data, the user is 25 years old."
4
+
5
+ If your system expects 25 but gets a whole sentence, your code crashes in production.
6
+
7
+ **Schemas act as the bouncers at the door of your application. We use Pydantic to define the exact shape of the data we expect.**
8
+ """
9
+ from pydantic import BaseModel, Field
10
+ from typing import Dict, Any, Optional
11
+
12
+ # 1. THE CONFIGURATION SCHEMA
13
+ # This dictates how our router behaves. Notice how we set smart defaults.
14
+ class RouterConfig(BaseModel):
15
+ """Configuration for how the OmniRouter should route the request."""
16
+ provider: str = Field(
17
+ default="openai",
18
+ description="The LLM provider to use (e.g., 'openai', 'anthropic', 'local')"
19
+ )
20
+ model: str = Field(
21
+ default="gpt-4-turbo",
22
+ description="The specific model string to use"
23
+ )
24
+ # Defensive Engineering: An LLM temperature cannot be less than 0 or greater than 2.
25
+ # ge = greater than or equal to, le = less than or equal to.
26
+ temperature: float = Field(
27
+ default=0.7,
28
+ ge=0.0,
29
+ le=2.0,
30
+ description="Creativity score for the model"
31
+ )
32
+ max_retries: int = Field(
33
+ default=3,
34
+ description="How many times to retry on API failure or rate limit"
35
+ )
36
+
37
+ # --- NEW CAPABILITY --- Added
38
+ fallback_provider: Optional[str] = Field(
39
+ default=None,
40
+ description="If the primary provider completely fails, switch to this one"
41
+ )
42
+ fallback_model: Optional[str] = Field(
43
+ default=None,
44
+ description="The model to use for the fallback provider"
45
+ )
46
+
47
+ # 2. THE STANDARDIZED OUTPUT SCHEMA
48
+ # This solves the main pain point from our README.
49
+ # Whether OpenAI or Anthropic answers, the rest of our app gets THIS exact object.
50
+ class LLMResponse(BaseModel):
51
+ """The standardized output format returned from ANY provider."""
52
+ content: str = Field(description="The actual text response generated by the LLM")
53
+ provider_used: str = Field(description="Which provider actually generated this response")
54
+ model_used: str = Field(description="The specific model used")
55
+
56
+ # We track tokens heavily for cost optimization (Week 14 concept)
57
+ prompt_tokens: int = Field(default=0, description="Tokens used in the prompt")
58
+ completion_tokens: int = Field(default=0, description="Tokens used in the completion")
59
+
60
+ # We will calculate this automatically later
61
+ cost_estimate: float = Field(default=0.0, description="Estimated cost of this call in USD")