| import json |
| import gradio as gr |
| import environs |
| import httpx |
| from typing import List, Tuple, Optional, Any |
| from dataclasses import dataclass |
| from enum import Enum |
| import os |
| from loguru import logger |
| import plotly.graph_objects as go |
|
|
| |
# Load variables from a local .env file (when present) into the environment.
env = environs.Env()
env.read_env()

# Hugging Face Spaces sets SPACE_ID inside a running Space; its presence tells
# us whether we are deployed there or running locally.
IS_HF_SPACE = os.environ.get("SPACE_ID") is not None

# Public URL of this app. On Spaces, derive it from SPACE_HOST (also provided
# by the platform) so duplicated/forked Spaces call their own deployment
# instead of the original one; fall back to the original Space URL when
# SPACE_HOST is not available.
if IS_HF_SPACE:
    SPACE_URL = (
        f"https://{os.environ['SPACE_HOST']}"
        if os.environ.get("SPACE_HOST")
        else "https://lokumai-openai-openapi-template.hf.space"
    )
else:
    SPACE_URL = "http://localhost:7860"

# API connection settings; BASE_URL defaults to this app's own URL.
BASE_URL = env.str("BASE_URL", SPACE_URL)
API_KEY = env.str("API_KEY", "sk-test-xxx")
# rstrip avoids a "//v1/..." path when BASE_URL is configured with a trailing "/".
CHAT_API_ENDPOINT = f"{BASE_URL.rstrip('/')}/v1/chat/completions"

# Static assets are looked up in a "static" directory next to this file.
STATIC_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "static")
AVATAR_DIR = os.path.join(STATIC_DIR, "avatars")

# Avatar images passed to the chatbot component for user / bot messages.
USER_AVATAR = os.path.join(AVATAR_DIR, "user.png")
BOT_AVATAR = os.path.join(AVATAR_DIR, "bot.png")

# Credentials checked by app_auth for the Gradio login screen.
AUTH_USERNAME = env.str("AUTH_USERNAME", "admin")
AUTH_PASSWORD = env.str("AUTH_PASSWORD", "admin")
if AUTH_USERNAME == "admin" and AUTH_PASSWORD == "admin":
    # The built-in defaults are trivially guessable; make that visible at startup.
    logger.warning(
        "Using default 'admin'/'admin' login credentials; "
        "set AUTH_USERNAME and AUTH_PASSWORD to secure the app."
    )
|
|
|
|
def app_auth(username: str, password: str) -> bool:
    """
    Check Gradio login credentials against the configured ones.

    Args:
        username (str): Username submitted on the login form.
        password (str): Password submitted on the login form.

    Returns:
        bool: True only when both values equal AUTH_USERNAME / AUTH_PASSWORD.
    """
    import hmac  # local import keeps this fix self-contained

    # Log the attempt only; never log the configured credentials.
    logger.debug(f"Entering app_auth: Username: {username}")
    if not isinstance(username, str) or not isinstance(password, str):
        return False
    # compare_digest takes time independent of where the inputs differ, and
    # both comparisons are always evaluated (no short-circuit), so response
    # timing does not reveal whether the username alone was correct.
    # Compare bytes because compare_digest rejects non-ASCII str arguments.
    username_ok = hmac.compare_digest(
        username.encode("utf-8"), AUTH_USERNAME.encode("utf-8")
    )
    password_ok = hmac.compare_digest(
        password.encode("utf-8"), AUTH_PASSWORD.encode("utf-8")
    )
    return username_ok and password_ok
|
|
|
|
| |
# Extra stylesheet passed to gr.Blocks(css=...). This is a plain string (not an
# f-string and never .format()-ed), so CSS braces must be single: doubled
# "{{"/"}}" would reach the browser literally and invalidate every rule.
CUSTOM_CSS = """
@font-face {
    font-family: 'UI Sans Serif';
    src: url('/static/fonts/ui-sans-serif/ui-sans-serif-Regular.woff2') format('woff2');
    font-weight: normal;
    font-style: normal;
}

@font-face {
    font-family: 'UI Sans Serif';
    src: url('/static/fonts/ui-sans-serif/ui-sans-serif-Bold.woff2') format('woff2');
    font-weight: bold;
    font-style: normal;
}

@font-face {
    font-family: 'System UI';
    src: url('/static/fonts/system-ui/system-ui-Regular.woff2') format('woff2');
    font-weight: normal;
    font-style: normal;
}

@font-face {
    font-family: 'System UI';
    src: url('/static/fonts/system-ui/system-ui-Bold.woff2') format('woff2');
    font-weight: bold;
    font-style: normal;
}

.gradio-container {
    font-family: 'UI Sans Serif', 'System UI', sans-serif;
}

/* Improve chat interface */
.chat-message {
    padding: 1rem;
    border-radius: 0.5rem;
    margin-bottom: 1rem;
    display: flex;
    align-items: flex-start;
}

.chat-message.user {
    background-color: #f3f4f6;
}

.chat-message.bot {
    background-color: #eef2ff;
}

/* Improve button styles */
button {
    transition: all 0.2s ease-in-out;
}

button:hover {
    transform: translateY(-1px);
    box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
}

/* Improve input area */
textarea {
    border-radius: 0.5rem;
    padding: 0.75rem;
    border: 1px solid #e5e7eb;
    transition: border-color 0.2s ease-in-out;
}

textarea:focus {
    border-color: #4f46e5;
    outline: none;
    box-shadow: 0 0 0 2px rgba(79, 70, 229, 0.1);
}
"""
|
|
|
|
class MessageStatus(Enum):
    """
    Outcome of one chat API round-trip.

    Member values are the human-readable labels.
    """

    # The request completed and a reply was extracted.
    SUCCESS = "Success"
    # The request failed or the reply could not be used.
    ERROR = "Error"
|
|
|
|
@dataclass
class ChatMessageResponse:
    """
    Result object returned by ChatAPI.send_message.

    On success ``content`` carries the reply text and ``figure`` the optional
    figure payload; on failure ``error`` describes what went wrong.
    """

    status: MessageStatus  # SUCCESS or ERROR
    content: str  # reply text (send_message uses "" for errors)
    figure: Optional[dict] = None  # figure payload from the API message, if any
    error: Optional[str] = None  # human-readable error text, if any
|
|
|
|
class ChatAPI:
    """Async client for the OpenAI-style ``/v1/chat/completions`` endpoint."""

    def __init__(
        self,
        base_url: str,
        api_key: str,
        model: str = "gpt-3.5-turbo",
        timeout: float = 30.0,
    ):
        """
        Args:
            base_url (str): Root URL of the API server; a trailing "/" is
                ignored when building the endpoint.
            api_key (str): Token sent as ``Authorization: Bearer <api_key>``.
            model (str): Value of the ``model`` field in each request.
            timeout (float): Request timeout in seconds passed to httpx.
        """
        self.base_url = base_url
        self.api_key = api_key
        self.model = model
        self.timeout = timeout
        # rstrip avoids a "//v1/..." path when base_url ends with "/".
        self.endpoint = f"{base_url.rstrip('/')}/v1/chat/completions"

    async def send_message(self, prompt: str) -> ChatMessageResponse:
        """
        Send a single-turn message to the chat API.

        Failures (non-200 status, timeout, network or decoding errors,
        unexpected payload shape) are returned as an ERROR response rather
        than raised, so the UI handler can display them.

        Args:
            prompt (str): The user message to send.

        Returns:
            ChatMessageResponse: SUCCESS with the reply content (and optional
            figure dict), or ERROR with a description in ``error``.
        """
        logger.trace(f"Calling chat API with prompt: {prompt}")
        try:
            async with httpx.AsyncClient() as client:
                response = await client.post(
                    self.endpoint,
                    headers={"Authorization": f"Bearer {self.api_key}"},
                    json={
                        "messages": [{"role": "user", "content": prompt}],
                        "model": self.model,
                        "completion_id": "new_chat",
                        # The body is parsed below as one JSON document with
                        # choices[0].message (the non-streamed shape). With
                        # stream=True an OpenAI-compatible server replies with
                        # SSE chunks that response.json() cannot decode.
                        "stream": False,
                    },
                    timeout=self.timeout,
                )

            if response.status_code != 200:
                logger.error(f"API Error: {response.text}")
                return self._error(f"API Error: {response.text}")

            result = response.json()
            logger.trace("######################## BEGIN API response #########################")
            logger.trace(json.dumps(result, indent=4, ensure_ascii=False))
            logger.trace("######################## END API response #########################")
            return self._parse_result(result)

        except httpx.TimeoutException:
            logger.error("API request timed out")
            return self._error("Request timed out. Please try again.")
        except Exception as e:
            # UI boundary: log and surface any other failure as an ERROR
            # response instead of raising into the Gradio event handler.
            logger.error(f"Error: {str(e)}")
            return self._error(f"Error: {str(e)}")

    @staticmethod
    def _error(message: str) -> ChatMessageResponse:
        """Build an ERROR response with empty content and the given message."""
        return ChatMessageResponse(
            status=MessageStatus.ERROR,
            content="",
            figure=None,
            error=message,
        )

    @classmethod
    def _parse_result(cls, result: Any) -> ChatMessageResponse:
        """
        Extract reply content and figure from a decoded response body.

        Expects ``{"choices": [{"message": {"content": ..., "figure": ...}}]}``.
        A body without a non-empty ``choices`` list whose first entry is an
        object yields an "Invalid API response" ERROR.
        """
        choices = result.get("choices") if isinstance(result, dict) else None
        first = choices[0] if isinstance(choices, list) and choices else None
        if not isinstance(first, dict):
            logger.error("Invalid API response")
            return cls._error("Invalid API response")

        # "message" may be missing or explicitly null; treat both as empty.
        message = first.get("message") or {}
        figure = message.get("figure", None)
        logger.trace(f"Figure: {figure}")
        content = message.get("content", "Content not found")
        logger.trace(f"Last message: {content}")
        return ChatMessageResponse(
            status=MessageStatus.SUCCESS,
            content=content,
            figure=figure,
        )
|
|
|
|
class ChatInterface:
    """Gradio Blocks UI that sends user questions to a ChatAPI and shows the replies."""

    def __init__(self, chat_api: ChatAPI):
        """
        Args:
            chat_api (ChatAPI): Client used to send each user message.
        """
        self.chat_api = chat_api
        self.demo = self._build_interface()

    @staticmethod
    def _to_plotly_figure(figure_data: Any) -> Tuple[Optional["go.Figure"], bool]:
        """
        Convert the figure payload of an API reply into a plotly Figure.

        Args:
            figure_data (Any): The ``figure`` value from the reply. Only dicts
                are converted; anything else means "no figure".

        Returns:
            Tuple[Optional[go.Figure], bool]: ``(figure, ok)``. ``figure`` is
            None when there is nothing to plot or conversion failed. ``ok`` is
            False only when a dict payload could not be turned into a Figure.
        """
        if not isinstance(figure_data, dict):
            return None, True
        logger.trace(f"Plotly input: {figure_data}")
        try:
            figure = go.Figure(figure_data)
            logger.trace(f"Plotly figure: {figure.to_dict()}")
        except Exception as e:
            # An invalid figure spec must not break the chat reply itself.
            logger.error(f"Error creating plotly figure: {e}")
            return None, False
        return figure, True

    def _build_interface(self) -> gr.Blocks:
        """
        Build the Gradio interface and wire its event handlers.

        Every handler returns values for the same five outputs, in order:
        (chatbot, msg, status, last_message, plot).

        Returns:
            gr.Blocks: The Gradio interface
        """
        with gr.Blocks(theme=gr.themes.Soft(), css=CUSTOM_CSS) as demo:
            gr.Markdown("""
            # 🤖 Data Chatbot

            This chatbot allows you to chat with your data and visualize it.
            Please enter your question in the text box below and click the "Send" button.

            > 📚 API Documentation: [https://lokumai-openai-openapi-template.hf.space/docs](https://lokumai-openai-openapi-template.hf.space/docs)
            """)

            with gr.Row():
                with gr.Column(scale=4):
                    # Conversation as [user_text, bot_text] pairs.
                    chatbot = gr.Chatbot(
                        label="Chat History",
                        height=400,
                        show_copy_button=True,
                        avatar_images=(USER_AVATAR, BOT_AVATAR),
                        elem_classes=["chat-message"],
                    )

                    # Figure returned with the latest reply, if any.
                    plot = gr.Plot(label="Data Visualization")

                    with gr.Row():
                        msg = gr.Textbox(
                            label="Your Message",
                            placeholder="Enter your question here...",
                            lines=3,
                            scale=4,
                            elem_classes=["message-input"],
                        )
                        submit_btn = gr.Button("Send", variant="primary", scale=1)

                    with gr.Row():
                        clear_btn = gr.Button("Clear Chat", variant="secondary")
                        retry_btn = gr.Button("Retry", variant="secondary")

                    status = gr.Textbox(label="Status", interactive=False)
                    last_message = gr.Textbox(label="Last Message", interactive=False)

            async def user_message(
                message: str, history: List[List[str]]
            ) -> Tuple[List[List[str]], str, str, str, object]:
                """Send the typed message to the API and append the exchange to the chat."""
                history = history or []
                if not message or not message.strip():
                    return history, "", "Please enter a message.", "", None

                logger.debug(f"User message: {message}")

                # Show the user's message right away with an empty reply slot.
                history.append([message, ""])
                response = await self.chat_api.send_message(message)

                if response.status != MessageStatus.SUCCESS:
                    history[-1][1] = f"❌ {response.error}"
                    return history, "", f"Error: {response.error}", "", None

                content = response.content
                logger.trace(f"Figure data: {response.figure}")
                figure, figure_ok = self._to_plotly_figure(response.figure)
                # Set the reply text first and append the warning afterwards,
                # so the warning is not overwritten by the content.
                if figure_ok:
                    history[-1][1] = content
                else:
                    history[-1][1] = (content or "") + "\n\n⚠️ Graph data is not valid, cannot be displayed."
                return (
                    history,
                    "",
                    "Message sent successfully.",
                    content,
                    figure,
                )

            def clear_history() -> tuple[list[Any], str, str, str, None]:
                """Reset the chat, input box, status, last message and plot."""
                return [], "", "Chat cleared.", "", None

            def retry_last_message(
                history: List[List[str]],
            ) -> tuple[list[list[str]], str, str, str, None]:
                """
                Drop the last exchange and put its user text back in the input box.

                Nothing is sent here; the user resends with the Send button or Enter.
                """
                if not history:
                    return history, "", "No message to retry.", "", None

                previous_prompt = history[-1][0]
                return (
                    history[:-1],
                    previous_prompt,
                    "Last message will be retried.",
                    "",
                    None,
                )

            # The Send button and pressing Enter in the textbox do the same thing.
            submit_btn.click(
                fn=user_message,
                inputs=[msg, chatbot],
                outputs=[chatbot, msg, status, last_message, plot],
            )
            msg.submit(
                fn=user_message,
                inputs=[msg, chatbot],
                outputs=[chatbot, msg, status, last_message, plot],
            )
            clear_btn.click(
                fn=clear_history,
                inputs=[],
                outputs=[chatbot, msg, status, last_message, plot],
            )
            retry_btn.click(
                fn=retry_last_message,
                inputs=[chatbot],
                outputs=[chatbot, msg, status, last_message, plot],
            )

        return demo
|
|
|
|
def build_gradio_app() -> gr.Blocks:
    """
    Create the Gradio application wired to the configured chat API.

    The API client is constructed from the module-level BASE_URL and API_KEY.

    Returns:
        gr.Blocks: The assembled Gradio interface.
    """
    return ChatInterface(ChatAPI(BASE_URL, API_KEY)).demo
|
|