Spaces:
Running
Running
Chatbot improvements: UI cleaning, chat streaming, typing animation, and better markdown formatting
"""
OpenRouter API Client

Wrapper for OpenRouter API providing text generation and embedding capabilities
for the AI chatbot feature.

Uses the OpenAI-compatible API via requests.
"""

import os
import requests
from typing import List, Dict, Optional

# =============================================================================
# GLOBAL MODEL CONFIGURATION
# =============================================================================
# Change these to switch models across the entire application.

# Chat model: Gemini 2.5 Flash Lite - $0.10/$0.40 per 1M tokens, 1M context
DEFAULT_CHAT_MODEL = "google/gemini-2.5-flash-lite"

# Embedding model: text-embedding-3-small - $0.02 per 1M tokens, 1536 dimensions
DEFAULT_EMBEDDING_MODEL = "openai/text-embedding-3-small"

# =============================================================================

# OpenRouter API endpoint (OpenAI-compatible REST base URL)
OPENROUTER_BASE_URL = "https://openrouter.ai/api/v1"
# System prompt sent as the first (role="system") message of every chat
# request. NOTE: this is a runtime string the model sees verbatim — edit the
# assistant's instructions here, not at the call sites.
SYSTEM_PROMPT = """You are a helpful AI assistant integrated into a Transformer Explanation Dashboard.
Your role is to help users understand how transformer models work, explain the experiments
available in the dashboard, and answer questions about machine learning concepts.
You have access to:
1. RAG documents containing information about transformers and the dashboard
2. The current state of the dashboard (selected model, prompt, analysis results)
When answering:
- Be extremely concise. Directly answer the user's question.
- Do not provide exhaustive explanations unless explicitly asked.
- At the end of your concise answer, offer to go into more detail (e.g., "Let me know if you'd like me to explain [TOPIC] in more detail.")
- Be clear and educational
- Use examples when helpful
- Reference specific dashboard features when relevant
- Format code snippets properly with markdown code blocks
- If you don't know something, say so honestly
Dashboard context will be provided in the user's messages when available."""
class OpenRouterClient:
    """Client for the OpenRouter chat-completion and embedding endpoints.

    A client constructed without an API key (argument or OPENROUTER_API_KEY
    env var) stays in a disabled state: chat methods return/yield a friendly
    fallback message and embedding methods return None instead of raising.
    """

    # User-facing message produced when no API key is configured.
    _UNAVAILABLE_MSG = (
        "Sorry, the AI assistant is not available. Please check that the "
        "OPENROUTER_API_KEY environment variable is set."
    )

    def __init__(self, api_key: Optional[str] = None):
        """
        Initialize the OpenRouter client.

        Args:
            api_key: OpenRouter API key. If not provided, reads from the
                OPENROUTER_API_KEY environment variable.
        """
        self.api_key = api_key or os.environ.get("OPENROUTER_API_KEY")
        self._initialized = False
        if self.api_key:
            self._initialize()

    def _initialize(self):
        """Build the shared request headers and mark the client as ready."""
        if not self.api_key:
            return
        self._headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
            # The next two headers are optional; OpenRouter uses them for app rankings.
            "HTTP-Referer": "https://transformer-dashboard.local",
            "X-Title": "Transformer Explanation Dashboard"
        }
        self._initialized = True

    def is_available(self) -> bool:
        """Return True when an API key is configured and headers are built."""
        return self._initialized and self.api_key is not None

    def generate_response(
        self,
        user_message: str,
        chat_history: Optional[List[Dict[str, str]]] = None,
        rag_context: Optional[str] = None,
        dashboard_context: Optional[Dict] = None
    ) -> str:
        """
        Generate a (non-streaming) response using OpenRouter.

        Args:
            user_message: The user's message.
            chat_history: Previous messages [{"role": "user/assistant", "content": "..."}].
            rag_context: Retrieved context from RAG documents.
            dashboard_context: Current dashboard state (model, prompt, results).

        Returns:
            Generated response text, or a human-readable error message.
        """
        # BUG FIX: the original tested `not self.is_available` — a bound method
        # object, which is always truthy — so this guard never fired and an
        # unconfigured client crashed later on the missing `self._headers`.
        if not self.is_available():
            return self._UNAVAILABLE_MSG
        try:
            messages = self._build_messages(
                user_message, chat_history, rag_context, dashboard_context
            )
            response = requests.post(
                f"{OPENROUTER_BASE_URL}/chat/completions",
                headers=self._headers,
                json={
                    "model": DEFAULT_CHAT_MODEL,
                    "messages": messages
                },
                timeout=60
            )
            response.raise_for_status()
            data = response.json()
            return data["choices"][0]["message"]["content"]
        except requests.exceptions.HTTPError as e:
            return self._friendly_http_error(e, "OpenRouter API error")
        except Exception as e:
            print(f"OpenRouter API error: {e}")
            return f"Sorry, I encountered an error: {str(e)}"

    def generate_stream(
        self,
        user_message: str,
        chat_history: Optional[List[Dict[str, str]]] = None,
        rag_context: Optional[str] = None,
        dashboard_context: Optional[Dict] = None
    ):
        """
        Generate a streaming response using OpenRouter.

        Yields text chunks as they arrive from the SSE stream; on failure,
        yields a single human-readable error message instead.
        """
        import json  # local import kept from the original implementation

        # BUG FIX: same always-truthy `self.is_available` guard as above.
        if not self.is_available():
            yield self._UNAVAILABLE_MSG
            return
        try:
            messages = self._build_messages(
                user_message, chat_history, rag_context, dashboard_context
            )
            response = requests.post(
                f"{OPENROUTER_BASE_URL}/chat/completions",
                headers=self._headers,
                json={
                    "model": DEFAULT_CHAT_MODEL,
                    "messages": messages,
                    "stream": True
                },
                timeout=60,
                stream=True
            )
            response.raise_for_status()
            # OpenRouter streams OpenAI-style server-sent events: lines of the
            # form "data: {json}", terminated by a final "data: [DONE]".
            for raw_line in response.iter_lines():
                if not raw_line:
                    continue
                line = raw_line.decode('utf-8')
                if not line.startswith('data: ') or line == 'data: [DONE]':
                    continue
                try:
                    payload = json.loads(line[6:])
                except json.JSONDecodeError:
                    continue  # skip malformed / partial SSE lines
                choices = payload.get("choices") or []
                if choices:
                    delta = choices[0].get("delta", {})
                    if "content" in delta:
                        yield delta["content"]
        except requests.exceptions.HTTPError as e:
            yield self._friendly_http_error(e, "OpenRouter API stream error")
        except Exception as e:
            print(f"OpenRouter API stream error: {e}")
            yield f"Sorry, I encountered an error: {str(e)}"

    def _build_messages(
        self,
        user_message: str,
        chat_history: Optional[List[Dict[str, str]]],
        rag_context: Optional[str],
        dashboard_context: Optional[Dict]
    ) -> List[Dict[str, str]]:
        """Assemble the request messages: system prompt, trimmed history, user turn.

        Shared by generate_response and generate_stream (the original
        duplicated this logic in both methods).
        """
        messages = [{"role": "system", "content": SYSTEM_PROMPT}]
        if chat_history:
            for msg in chat_history[-10:]:  # keep last 10 messages for context
                # Anything not explicitly a user turn is treated as assistant.
                role = "user" if msg.get("role") == "user" else "assistant"
                messages.append({"role": role, "content": msg.get("content", "")})
        full_message = self._build_prompt(user_message, rag_context, dashboard_context)
        messages.append({"role": "user", "content": full_message})
        return messages

    def _friendly_http_error(self, e, label: str) -> str:
        """Translate a requests HTTPError into a user-facing message.

        Args:
            e: The raised requests.exceptions.HTTPError.
            label: Log prefix for unrecognized errors (e.g. "OpenRouter API error").
        """
        error_msg = str(e)
        if e.response is not None:
            try:
                error_data = e.response.json()
                error_msg = error_data.get("error", {}).get("message", str(e))
            except Exception:  # response body was not JSON; keep generic message
                pass
        lowered = error_msg.lower()
        if "rate" in lowered or "429" in error_msg:
            return f"The AI service is currently rate limited. Please try again in a moment. {error_msg}"
        if "401" in error_msg or "invalid" in lowered:
            return "Invalid API key. Please check your OPENROUTER_API_KEY configuration."
        print(f"{label}: {e}")
        return f"Sorry, I encountered an error: {error_msg}"

    def _build_prompt(
        self,
        user_message: str,
        rag_context: Optional[str] = None,
        dashboard_context: Optional[Dict] = None
    ) -> str:
        """Combine dashboard state, RAG context, and the user question into one prompt."""
        parts = []
        # Dashboard state first, if any of its fields rendered to text.
        if dashboard_context:
            context_str = self._format_dashboard_context(dashboard_context)
            if context_str:
                parts.append(f"**Current Dashboard State:**\n{context_str}\n")
        # Then any retrieved documentation.
        if rag_context:
            parts.append(f"**Relevant Documentation:**\n{rag_context}\n")
        # The actual question always comes last.
        parts.append(f"**User Question:**\n{user_message}")
        return "\n".join(parts)

    def _format_dashboard_context(self, context: Dict) -> str:
        """Render the dashboard-state dict as markdown bullet lines (may be empty)."""
        lines = []
        if context.get("model"):
            lines.append(f"- Selected Model: {context['model']}")
        if context.get("prompt"):
            lines.append(f"- Input Prompt: \"{context['prompt']}\"")
        if context.get("predicted_token"):
            prob = context.get("predicted_probability", 0)
            lines.append(f"- Predicted Next Token: \"{context['predicted_token']}\" (probability: {prob:.1%})")
        if context.get("top_predictions"):
            top = context["top_predictions"][:5]  # cap the list at five entries
            tokens_str = ", ".join([f"{t['token']} ({t['probability']:.1%})" for t in top])
            lines.append(f"- Top Predictions: {tokens_str}")
        if context.get("ablated_heads"):
            heads_str = ", ".join([f"L{h['layer']}H{h['head']}" for h in context["ablated_heads"]])
            lines.append(f"- Ablated Attention Heads: {heads_str}")
        return "\n".join(lines)

    def get_embedding(self, text: str) -> Optional[List[float]]:
        """
        Get embedding vector for text using the OpenRouter Embedding API.

        Args:
            text: Text to embed.

        Returns:
            Embedding vector as list of floats, or None if unavailable/failed.
        """
        # BUG FIX: `self.is_available` was checked without calling it.
        if not self.is_available():
            return None
        try:
            response = requests.post(
                f"{OPENROUTER_BASE_URL}/embeddings",
                headers=self._headers,
                json={
                    "model": DEFAULT_EMBEDDING_MODEL,
                    "input": text
                },
                timeout=30
            )
            response.raise_for_status()
            data = response.json()
            return data["data"][0]["embedding"]
        except Exception as e:
            print(f"Embedding error: {e}")
            return None

    def get_query_embedding(self, query: str) -> Optional[List[float]]:
        """
        Get embedding vector for a query.

        OpenRouter has no separate task types for embeddings, so this simply
        delegates to get_embedding.

        Args:
            query: Query text to embed.

        Returns:
            Embedding vector as list of floats, or None if failed.
        """
        return self.get_embedding(query)
# Lazily-created module-level singleton shared by every caller.
_client_instance: Optional[OpenRouterClient] = None


def get_openrouter_client() -> OpenRouterClient:
    """Return the process-wide OpenRouterClient, constructing it on first use."""
    global _client_instance
    if _client_instance is not None:
        return _client_instance
    _client_instance = OpenRouterClient()
    return _client_instance
def generate_response(
    user_message: str,
    chat_history: Optional[List[Dict[str, str]]] = None,
    rag_context: Optional[str] = None,
    dashboard_context: Optional[Dict] = None
) -> str:
    """
    Module-level convenience wrapper around the singleton client.

    Args:
        user_message: The user's message
        chat_history: Previous chat messages
        rag_context: Retrieved RAG context
        dashboard_context: Current dashboard state

    Returns:
        Generated response text
    """
    return get_openrouter_client().generate_response(
        user_message, chat_history, rag_context, dashboard_context
    )
def generate_stream(
    user_message: str,
    chat_history: Optional[List[Dict[str, str]]] = None,
    rag_context: Optional[str] = None,
    dashboard_context: Optional[Dict] = None
):
    """
    Module-level convenience wrapper for streaming generation.

    Args:
        user_message: The user's message
        chat_history: Previous chat messages
        rag_context: Retrieved RAG context
        dashboard_context: Current dashboard state

    Returns:
        Generator yielding text chunks
    """
    return get_openrouter_client().generate_stream(
        user_message, chat_history, rag_context, dashboard_context
    )
def get_embedding(text: str) -> Optional[List[float]]:
    """Convenience wrapper: embed a document via the singleton client."""
    return get_openrouter_client().get_embedding(text)
def get_query_embedding(query: str) -> Optional[List[float]]:
    """Convenience wrapper: embed a query via the singleton client."""
    return get_openrouter_client().get_query_embedding(query)
# Backward compatibility aliases (for gradual migration).
# NOTE(review): the names suggest this module previously wrapped the Gemini
# API directly; these aliases keep old import sites working — confirm before
# removing them.
GeminiClient = OpenRouterClient
get_gemini_client = get_openrouter_client