LLMVis / utils /openrouter_client.py
cdpearlman's picture
Chatbot improvements: UI cleaning, chat streaming, typing animation, and better markdown formatting
a684a8d
"""
OpenRouter API Client
Wrapper for OpenRouter API providing text generation and embedding capabilities
for the AI chatbot feature.
Uses the OpenAI-compatible API via requests.
"""
import os
import requests
from typing import List, Dict, Optional
# =============================================================================
# GLOBAL MODEL CONFIGURATION
# =============================================================================
# Change these to switch models across the entire application
# Chat model: Gemini 2.5 Flash Lite - $0.10/$0.40 per 1M tokens, 1M context
DEFAULT_CHAT_MODEL = "google/gemini-2.5-flash-lite"
# Embedding model: text-embedding-3-small - $0.02 per 1M tokens, 1536 dimensions
DEFAULT_EMBEDDING_MODEL = "openai/text-embedding-3-small"
# =============================================================================
# OpenRouter API endpoint
OPENROUTER_BASE_URL = "https://openrouter.ai/api/v1"
# System prompt for the chatbot
SYSTEM_PROMPT = """You are a helpful AI assistant integrated into a Transformer Explanation Dashboard.
Your role is to help users understand how transformer models work, explain the experiments
available in the dashboard, and answer questions about machine learning concepts.
You have access to:
1. RAG documents containing information about transformers and the dashboard
2. The current state of the dashboard (selected model, prompt, analysis results)
When answering:
- Be extremely concise. Directly answer the user's question.
- Do not provide exhaustive explanations unless explicitly asked.
- At the end of your concise answer, offer to go into more detail (e.g., "Let me know if you'd like me to explain [TOPIC] in more detail.")
- Be clear and educational
- Use examples when helpful
- Reference specific dashboard features when relevant
- Format code snippets properly with markdown code blocks
- If you don't know something, say so honestly
Dashboard context will be provided in the user's messages when available."""
class OpenRouterClient:
"""Client for interacting with OpenRouter API."""
def __init__(self, api_key: Optional[str] = None):
"""
Initialize the OpenRouter client.
Args:
api_key: OpenRouter API key. If not provided, reads from OPENROUTER_API_KEY env var.
"""
self.api_key = api_key or os.environ.get("OPENROUTER_API_KEY")
self._initialized = False
if self.api_key:
self._initialize()
def _initialize(self):
"""Initialize the OpenRouter API client."""
if not self.api_key:
return
self._headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
"HTTP-Referer": "https://transformer-dashboard.local", # Optional: for rankings
"X-Title": "Transformer Explanation Dashboard" # Optional: for rankings
}
self._initialized = True
@property
def is_available(self) -> bool:
"""Check if the OpenRouter API is available and configured."""
return self._initialized and self.api_key is not None
def generate_response(
self,
user_message: str,
chat_history: Optional[List[Dict[str, str]]] = None,
rag_context: Optional[str] = None,
dashboard_context: Optional[Dict] = None
) -> str:
"""
Generate a response using OpenRouter.
Args:
user_message: The user's message
chat_history: List of previous messages [{"role": "user/assistant", "content": "..."}]
rag_context: Retrieved context from RAG documents
dashboard_context: Current dashboard state (model, prompt, results)
Returns:
Generated response text
"""
if not self.is_available:
return "Sorry, the AI assistant is not available. Please check that the OPENROUTER_API_KEY environment variable is set."
try:
# Build the full prompt with context
full_message = self._build_prompt(user_message, rag_context, dashboard_context)
# Build messages array with system prompt and history
messages = [{"role": "system", "content": SYSTEM_PROMPT}]
# Add chat history
if chat_history:
for msg in chat_history[-10:]: # Keep last 10 messages for context
role = "user" if msg.get("role") == "user" else "assistant"
messages.append({
"role": role,
"content": msg.get("content", "")
})
# Add the current user message
messages.append({"role": "user", "content": full_message})
# Make API request
response = requests.post(
f"{OPENROUTER_BASE_URL}/chat/completions",
headers=self._headers,
json={
"model": DEFAULT_CHAT_MODEL,
"messages": messages
},
timeout=60
)
response.raise_for_status()
data = response.json()
return data["choices"][0]["message"]["content"]
except requests.exceptions.HTTPError as e:
error_msg = str(e)
if e.response is not None:
try:
error_data = e.response.json()
error_msg = error_data.get("error", {}).get("message", str(e))
except:
pass
if "rate" in error_msg.lower() or "429" in error_msg:
return f"The AI service is currently rate limited. Please try again in a moment. {error_msg}"
elif "401" in error_msg or "invalid" in error_msg.lower():
return "Invalid API key. Please check your OPENROUTER_API_KEY configuration."
else:
print(f"OpenRouter API error: {e}")
return f"Sorry, I encountered an error: {error_msg}"
except Exception as e:
print(f"OpenRouter API error: {e}")
return f"Sorry, I encountered an error: {str(e)}"
def generate_stream(
self,
user_message: str,
chat_history: Optional[List[Dict[str, str]]] = None,
rag_context: Optional[str] = None,
dashboard_context: Optional[Dict] = None
):
"""
Generate a streaming response using OpenRouter.
Yields text chunks as they arrive.
"""
if not self.is_available:
yield "Sorry, the AI assistant is not available. Please check that the OPENROUTER_API_KEY environment variable is set."
return
try:
full_message = self._build_prompt(user_message, rag_context, dashboard_context)
messages = [{"role": "system", "content": SYSTEM_PROMPT}]
if chat_history:
for msg in chat_history[-10:]:
role = "user" if msg.get("role") == "user" else "assistant"
messages.append({"role": role, "content": msg.get("content", "")})
messages.append({"role": "user", "content": full_message})
response = requests.post(
f"{OPENROUTER_BASE_URL}/chat/completions",
headers=self._headers,
json={
"model": DEFAULT_CHAT_MODEL,
"messages": messages,
"stream": True
},
timeout=60,
stream=True
)
response.raise_for_status()
import json
for line in response.iter_lines():
if line:
line = line.decode('utf-8')
if line.startswith('data: ') and line != 'data: [DONE]':
try:
data = json.loads(line[6:])
if "choices" in data and len(data["choices"]) > 0:
delta = data["choices"][0].get("delta", {})
if "content" in delta:
yield delta["content"]
except json.JSONDecodeError:
continue
except requests.exceptions.HTTPError as e:
error_msg = str(e)
if e.response is not None:
try:
error_data = e.response.json()
error_msg = error_data.get("error", {}).get("message", str(e))
except:
pass
if "rate" in error_msg.lower() or "429" in error_msg:
yield f"The AI service is currently rate limited. Please try again in a moment. {error_msg}"
elif "401" in error_msg or "invalid" in error_msg.lower():
yield "Invalid API key. Please check your OPENROUTER_API_KEY configuration."
else:
print(f"OpenRouter API stream error: {e}")
yield f"Sorry, I encountered an error: {error_msg}"
except Exception as e:
print(f"OpenRouter API stream error: {e}")
yield f"Sorry, I encountered an error: {str(e)}"
def _build_prompt(
self,
user_message: str,
rag_context: Optional[str] = None,
dashboard_context: Optional[Dict] = None
) -> str:
"""Build the full prompt with context."""
parts = []
# Add dashboard context if available
if dashboard_context:
context_str = self._format_dashboard_context(dashboard_context)
if context_str:
parts.append(f"**Current Dashboard State:**\n{context_str}\n")
# Add RAG context if available
if rag_context:
parts.append(f"**Relevant Documentation:**\n{rag_context}\n")
# Add the user's message
parts.append(f"**User Question:**\n{user_message}")
return "\n".join(parts)
def _format_dashboard_context(self, context: Dict) -> str:
"""Format dashboard context for the prompt."""
lines = []
if context.get("model"):
lines.append(f"- Selected Model: {context['model']}")
if context.get("prompt"):
lines.append(f"- Input Prompt: \"{context['prompt']}\"")
if context.get("predicted_token"):
prob = context.get("predicted_probability", 0)
lines.append(f"- Predicted Next Token: \"{context['predicted_token']}\" (probability: {prob:.1%})")
if context.get("top_predictions"):
top = context["top_predictions"][:5]
tokens_str = ", ".join([f"{t['token']} ({t['probability']:.1%})" for t in top])
lines.append(f"- Top Predictions: {tokens_str}")
if context.get("ablated_heads"):
heads_str = ", ".join([f"L{h['layer']}H{h['head']}" for h in context["ablated_heads"]])
lines.append(f"- Ablated Attention Heads: {heads_str}")
return "\n".join(lines)
def get_embedding(self, text: str) -> Optional[List[float]]:
"""
Get embedding vector for text using OpenRouter Embedding API.
Args:
text: Text to embed
Returns:
Embedding vector as list of floats, or None if failed
"""
if not self.is_available:
return None
try:
response = requests.post(
f"{OPENROUTER_BASE_URL}/embeddings",
headers=self._headers,
json={
"model": DEFAULT_EMBEDDING_MODEL,
"input": text
},
timeout=30
)
response.raise_for_status()
data = response.json()
return data["data"][0]["embedding"]
except Exception as e:
print(f"Embedding error: {e}")
return None
def get_query_embedding(self, query: str) -> Optional[List[float]]:
"""
Get embedding vector for a query.
Note: OpenRouter doesn't have separate task types for embeddings,
so this calls the same endpoint as get_embedding.
Args:
query: Query text to embed
Returns:
Embedding vector as list of floats, or None if failed
"""
return self.get_embedding(query)
# Singleton instance
_client_instance: Optional[OpenRouterClient] = None
def get_openrouter_client() -> OpenRouterClient:
"""Get or create the singleton OpenRouter client instance."""
global _client_instance
if _client_instance is None:
_client_instance = OpenRouterClient()
return _client_instance
def generate_response(
user_message: str,
chat_history: Optional[List[Dict[str, str]]] = None,
rag_context: Optional[str] = None,
dashboard_context: Optional[Dict] = None
) -> str:
"""
Convenience function to generate a response.
Args:
user_message: The user's message
chat_history: Previous chat messages
rag_context: Retrieved RAG context
dashboard_context: Current dashboard state
Returns:
Generated response text
"""
client = get_openrouter_client()
return client.generate_response(user_message, chat_history, rag_context, dashboard_context)
def generate_stream(
user_message: str,
chat_history: Optional[List[Dict[str, str]]] = None,
rag_context: Optional[str] = None,
dashboard_context: Optional[Dict] = None
):
"""
Convenience function to generate a streaming response.
Args:
user_message: The user's message
chat_history: Previous chat messages
rag_context: Retrieved RAG context
dashboard_context: Current dashboard state
Returns:
Generator yielding text chunks
"""
client = get_openrouter_client()
return client.generate_stream(user_message, chat_history, rag_context, dashboard_context)
def get_embedding(text: str) -> Optional[List[float]]:
"""Convenience function to get document embedding."""
client = get_openrouter_client()
return client.get_embedding(text)
def get_query_embedding(query: str) -> Optional[List[float]]:
"""Convenience function to get query embedding."""
client = get_openrouter_client()
return client.get_query_embedding(query)
# Backward compatibility aliases (for gradual migration)
GeminiClient = OpenRouterClient
get_gemini_client = get_openrouter_client