Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import httpx | |
| import json | |
| import os | |
| import numpy as np | |
| import torch | |
| import asyncio | |
| import logging | |
# ===========================
# Logging Configuration
# ===========================
# Root logging: INFO and above, timestamped lines.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Module-scoped logger used by every helper in this file.
logger = logging.getLogger(__name__)
# ===========================
# Configuration and Constants
# ===========================
# Base URL of the LM Studio HTTP API; override with the LMSTUDIO_API_BASE_URL env var.
BASE_URL = os.getenv("LMSTUDIO_API_BASE_URL", "http://localhost:1234/v1")

# Torch device for tensor math: CUDA when available, otherwise CPU.
USE_GPU = torch.cuda.is_available()
DEVICE = torch.device("cuda") if USE_GPU else torch.device("cpu")
logger.info(f"GPU Available: {USE_GPU}, Device: {DEVICE}")

# Token-budget heuristics used as defaults by calculate_max_tokens().
MODEL_MAX_TOKENS = 32768        # total budget: estimated prompt tokens + completion
AVERAGE_CHARS_PER_TOKEN = 4     # rough characters-per-token estimate for prompt size
BUFFER_TOKENS = 1500            # headroom subtracted from the budget
MIN_OUTPUT_TOKENS = 500         # floor on the completion token budget

# Cap on the embeddings / messages retained per chat session.
MAX_EMBEDDINGS = 100

# httpx timeouts are expressed in seconds, so this allows up to 50 minutes per request.
HTTPX_TIMEOUT = 3000
# One shared async client so connections are pooled across requests.
client = httpx.AsyncClient(timeout=httpx.Timeout(HTTPX_TIMEOUT))
# ===========================
# Utility Functions
# ===========================
def calculate_max_tokens(message_history, model_max_tokens=MODEL_MAX_TOKENS,
                         buffer=BUFFER_TOKENS, avg_chars_per_token=AVERAGE_CHARS_PER_TOKEN,
                         min_tokens=MIN_OUTPUT_TOKENS):
    """Estimate the completion-token budget left for a conversation.

    The prompt size is approximated from character counts: the total length of
    every message's ``"content"`` divided by ``avg_chars_per_token``, truncated
    to an int. The budget is ``model_max_tokens - prompt_estimate - buffer``,
    never lower than ``min_tokens``.

    Args:
        message_history: iterable of message dicts, each with a ``"content"`` string.
        model_max_tokens: total token budget for prompt plus completion.
        buffer: headroom tokens reserved out of the budget.
        avg_chars_per_token: characters assumed per token.
        min_tokens: lower bound on the returned value.

    Returns:
        The number of tokens to request as ``max_tokens``.

    Note:
        When the prompt alone exceeds the budget, ``min_tokens`` is still
        returned, so the request may exceed the model's window.
    """
    char_count = 0
    for msg in message_history:
        char_count += len(msg["content"])

    estimated_prompt_tokens = int(char_count / avg_chars_per_token)
    budget = model_max_tokens - estimated_prompt_tokens - buffer
    calculated_max = budget if budget > min_tokens else min_tokens

    logger.info(f"Calculated max tokens: {calculated_max}")
    return calculated_max
async def get_embeddings(text):
    """Fetch an embedding vector for ``text`` from the LM Studio embeddings endpoint.

    Posts ``{"model": ..., "input": text}`` to ``{BASE_URL}/embeddings`` and reads
    ``data[0]["embedding"]`` from the JSON reply.

    Returns:
        The embedding as a plain ``list`` of floats. It is always a list, whatever
        the device, so callers can test it with ``is None`` or truthiness and it
        stores cleanly in session state. Returns ``None`` on any transport, HTTP,
        decoding or payload-shape failure; the failure is logged.
    """
    url = f"{BASE_URL}/embeddings"
    payload = {"model": "nomic_embed_text_v1_5_f16.gguf", "input": text}
    try:
        response = await client.post(url, json=payload, headers={"Content-Type": "application/json"})
        response.raise_for_status()
        data = response.json()
        items = data.get("data") if isinstance(data, dict) else None
        if not items:
            logger.error("Embeddings response contained no data.")
            return None
        # Plain Python floats. The old GPU branch only round-tripped through a
        # CUDA tensor to produce this same list, and the CPU branch leaked a
        # NumPy array whose truth value is ambiguous to callers.
        embedding = [float(value) for value in items[0]["embedding"]]
    except (httpx.RequestError, httpx.HTTPStatusError, json.JSONDecodeError) as e:
        logger.error(f"Error occurred while getting embeddings: {e}")
        return None
    except (KeyError, IndexError, TypeError, ValueError) as e:
        # Reply was valid JSON but not shaped like {"data": [{"embedding": [...]}]}.
        logger.error(f"Malformed embeddings response: {e}")
        return None
    logger.info("Successfully retrieved embeddings.")
    return embedding
def calculate_similarity(vec1, vec2):
    """Return the cosine similarity of two 1-D vectors as a Python float.

    Each argument may be a ``torch.Tensor`` (moved to ``DEVICE``) or anything
    ``torch.tensor`` accepts (built directly on ``DEVICE``). If either argument
    is ``None``, a warning is logged and ``0.0`` is returned.
    """
    if vec1 is None or vec2 is None:
        logger.warning("One or both vectors are None. Returning similarity as 0.0.")
        return 0.0

    def _on_device(vec):
        # Reuse existing tensors, only relocating them; build new ones in place.
        if isinstance(vec, torch.Tensor):
            return vec.to(DEVICE)
        return torch.tensor(vec, device=DEVICE)

    first = _on_device(vec1)
    second = _on_device(vec2)
    # cosine_similarity reduces over dim=1, so lift each vector to shape (1, D).
    batched_first, batched_second = first.unsqueeze(0), second.unsqueeze(0)
    similarity = torch.nn.functional.cosine_similarity(batched_first, batched_second).item()

    logger.info(f"Calculated similarity: {similarity}")
    return similarity
async def chat_with_lmstudio(messages, max_tokens):
    """Stream assistant text from LM Studio's ``/chat/completions`` endpoint.

    Sends ``messages`` with ``"stream": True`` and parses the server-sent-event
    lines of the reply. Each ``data:`` line is decoded as JSON, and non-empty
    ``choices[0].delta.content`` fragments are yielded as they arrive.

    Args:
        messages: list of ``{"role": ..., "content": ...}`` dicts.
        max_tokens: completion token limit passed through to the API.

    Yields:
        Text fragments of the assistant reply. On a transport or HTTP-status
        error, yields a single fixed error string instead; it may follow any
        fragments already produced.
    """
    url = f"{BASE_URL}/chat/completions"
    payload = {
        "model": "Qwen2.5-Coder-32B-Instruct-IQ2_M.gguf",
        "messages": messages,
        "temperature": 1,
        "max_tokens": max_tokens,
        "stream": True,
    }
    try:
        logger.info("Sending chat completion request to LM Studio API.")
        async with client.stream("POST", url, json=payload, headers={"Content-Type": "application/json"}) as response:
            response.raise_for_status()
            async for line in response.aiter_lines():
                decoded_line = line.strip()
                # SSE allows "data:" with or without a following space.
                if not decoded_line.startswith("data:"):
                    continue
                data_str = decoded_line[5:].strip()
                # OpenAI-style streams end with an explicit sentinel rather than JSON.
                if data_str == "[DONE]":
                    break
                try:
                    data = json.loads(data_str)
                except json.JSONDecodeError:
                    continue
                # Guard each level: keep-alive or usage chunks can carry an empty
                # "choices" list, which previously raised IndexError mid-stream.
                choices = data.get("choices") if isinstance(data, dict) else None
                choice = choices[0] if isinstance(choices, list) and choices else None
                delta = choice.get("delta") if isinstance(choice, dict) else None
                content = delta.get("content") if isinstance(delta, dict) else None
                if content:
                    yield content
    except (httpx.RequestError, httpx.HTTPStatusError) as e:
        logger.error(f"Error occurred while streaming chat completion: {e}")
        yield "An error occurred while generating a response."
# ===========================
# Gradio Interface with Dynamic Resizing
# ===========================
def gradio_chat_interface():
    """Build the Gradio chat UI and launch the server.

    Layout: a messages-format chatbot, a one-line input box, a Send button and a
    read-only "Relevant Context" box. Each session's ``gr.State`` holds a dict
    with the user-message embeddings and the ``{"role", "content"}`` history
    that is sent to the model on every turn. Both Send and Enter trigger a turn.
    ``interface.launch`` is called at the end with ``share=True`` on
    ``0.0.0.0:7860``.
    """
    css = """
    .gradio-container {
        background-color: #1e1e1e;
        color: #f0f0f0;
        font-family: 'Arial', sans-serif;
    }
    .gr-button {
        background-color: #6200ea;
        color: white;
        font-weight: bold;
    }
    .gr-textbox {
        border: 2px solid #6200ea;
        resize: both; /* Allow resizing */
    }
    .gr-chat-message {
        border-radius: 8px;
        padding: 10px;
    }
    """
    with gr.Blocks(css=css, theme="default") as interface:
        gr.Markdown("# 🌟 **Enhanced Chat Interface**\nBeautiful and functional AI-powered chat.")
        chatbot = gr.Chatbot(label="Conversation", type="messages")
        user_input = gr.Textbox(
            label="Your Message",
            placeholder="Type your message here...",
            lines=1,  # Start with a smaller number of lines
            interactive=True,
            container=False,  # Avoid additional padding
        )
        send_button = gr.Button("Send", elem_id="send_button")
        context_display = gr.Textbox(
            label="Relevant Context",
            interactive=False,
            elem_id="context_display"
        )
        embeddings_state = gr.State({"embeddings": [], "messages_history": []})

        async def chat_handler(message, state):
            """Run one user turn, streaming the assistant reply into the chatbot.

            Yields ``[chat_messages, new_state, context_text]`` to match the
            ``outputs`` wired below. ``chat_messages`` is the full conversation
            (the chatbot uses ``type="messages"``), built from the session
            history rather than the component's static initial value.
            """
            # Work on copies so a failed turn leaves the stored session state untouched.
            embeddings = list(state.get("embeddings", []))
            messages_history = list(state.get("messages_history", []))

            if not message or not message.strip():
                yield [messages_history, state, "Please enter a message."]
                return

            user_embedding = await get_embeddings(message)
            # Identity test: an embedding may be a NumPy array, whose truth value is ambiguous.
            if user_embedding is None:
                # Keep the existing conversation on screen instead of clearing it.
                yield [messages_history, state, "Failed to generate embeddings."]
                return

            embeddings.append(user_embedding)
            messages_history.append({"role": "user", "content": message})

            # Bound per-session memory and prompt size.
            embeddings = embeddings[-MAX_EMBEDDINGS:]
            messages_history = messages_history[-MAX_EMBEDDINGS:]
            # Trimming can leave an assistant reply at the front; drop leading
            # non-user entries so the conversation opens with a user turn.
            # (The newest entry is the user message, so this always terminates.)
            while messages_history and messages_history[0].get("role") != "user":
                messages_history = messages_history[1:]

            max_tokens = calculate_max_tokens(messages_history)
            # Context display logic is still a placeholder: it echoes the message.
            context_display_text = f"Context: {message}"
            new_state = {"embeddings": embeddings, "messages_history": messages_history}

            response = ""
            async for chunk in chat_with_lmstudio(messages_history, max_tokens):
                response += chunk
                updated_chat = messages_history + [{"role": "assistant", "content": response}]
                yield [updated_chat, new_state, context_display_text]

            # Persist the reply so the model sees its own earlier answers next turn.
            if response:
                messages_history = messages_history + [{"role": "assistant", "content": response}]
                new_state = {"embeddings": embeddings, "messages_history": messages_history}
            yield [messages_history, new_state, context_display_text]

        send_button.click(
            chat_handler,
            inputs=[user_input, embeddings_state],
            outputs=[chatbot, embeddings_state, context_display],
            show_progress="full",
        )
        # Pressing Enter in the message box sends as well.
        user_input.submit(
            chat_handler,
            inputs=[user_input, embeddings_state],
            outputs=[chatbot, embeddings_state, context_display],
            show_progress="full",
        )
    interface.launch(share=True, server_name="0.0.0.0", server_port=7860)
if __name__ == "__main__":
    # gradio_chat_interface() is a plain (non-async) function whose launch()
    # call serves until shutdown. Wrapping it in asyncio.run() handed None to
    # asyncio.run and raised ValueError, so call it directly.
    gradio_chat_interface()