import asyncio
import html
import logging
from typing import Dict, List, Optional

import gradio as gr

from .llm_manager import LLMManager
from .models import ChatMessage, ChatRequest

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class GradioChatInterface:
    """Gradio interface for chat completion."""

    def __init__(self, llm_manager: LLMManager):
        self.llm_manager = llm_manager
        self.chat_history: List[Dict[str, str]] = []

    def create_interface(self):
        """Create the Gradio interface."""
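        # Custom CSS for the overall layout and the chat message bubbles.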
        css = """
        .gradio-container {
            max-width: 1200px !important;
            margin: auto !important;
        }
        .chat-container {
            height: 600px;
            overflow-y: auto;
            border: 1px solid #e0e0e0;
            border-radius: 8px;
            padding: 20px;
            background-color: #fafafa;
        }
        .user-message {
            background-color: #007bff;
            color: white;
            padding: 10px 15px;
            border-radius: 18px;
            margin: 10px 0;
            max-width: 80%;
            margin-left: auto;
            text-align: right;
        }
        .assistant-message {
            background-color: #e9ecef;
            color: #333;
            padding: 10px 15px;
            border-radius: 18px;
            margin: 10px 0;
            max-width: 80%;
            margin-right: auto;
        }
        .system-message {
            background-color: #ffc107;
            color: #333;
            padding: 10px 15px;
            border-radius: 18px;
            margin: 10px 0;
            max-width: 80%;
            margin-right: auto;
            font-style: italic;
        }
        """
        with gr.Blocks(css=css, title="LLM Chat Interface") as interface:
            gr.Markdown("# 🤖 LLM Chat Interface")
            gr.Markdown(
                "Chat with your local LLM model using a beautiful web interface."
            )

            with gr.Row():
                with gr.Column(scale=3):
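                    # Conversation transcript rendered as HTML so the CSS classes above apply.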
                    chat_display = gr.HTML(
                        value="<div class='chat-container'><p>Start a conversation by typing a message below!</p></div>",
                        label="Chat History",
                        elem_classes=["chat-container"],
                    )

                    with gr.Row():
                        message_input = gr.Textbox(
                            placeholder="Type your message here...",
                            label="Message",
                            lines=3,
                            scale=4,
                        )
                        send_btn = gr.Button("Send", variant="primary", scale=1)

                    clear_btn = gr.Button("Clear Chat", variant="secondary")

                with gr.Column(scale=1):
                    gr.Markdown("### ⚙️ Model Settings")
                    model_dropdown = gr.Dropdown(
                        choices=["microsoft/phi-1_5"],
                        value="microsoft/phi-1_5",
                        label="Model",
                        interactive=False,
                    )

                    temperature_slider = gr.Slider(
                        minimum=0.0,
                        maximum=2.0,
                        value=0.7,
                        step=0.1,
                        label="Temperature",
                        info="Controls randomness (0 = deterministic, 2 = very random)",
                    )

                    top_p_slider = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        value=0.9,
                        step=0.1,
                        label="Top-p",
                        info="Controls diversity via nucleus sampling",
                    )

                    max_tokens_slider = gr.Slider(
                        minimum=50,
                        maximum=2048,
                        value=512,
                        step=50,
                        label="Max Tokens",
                        info="Maximum number of tokens to generate",
                    )

                    system_message = gr.Textbox(
                        placeholder="You are a helpful AI assistant.",
                        label="System Message",
                        lines=3,
                        info="Optional system message to set the assistant's behavior",
                    )

                    model_status = gr.Markdown(
                        f"**Model Status:** {'✅ Loaded' if self.llm_manager.is_loaded else '❌ Not Loaded'}\n"
                        f"**Model Type:** {self.llm_manager.model_type}"
                    )
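            # Wire up the events: both the Send button and the textbox's submit event send the message.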
            send_btn.click(
                fn=self.send_message,
                inputs=[
                    message_input,
                    system_message,
                    temperature_slider,
                    top_p_slider,
                    max_tokens_slider,
                    chat_display,
                ],
                outputs=[chat_display, message_input],
            )

            message_input.submit(
                fn=self.send_message,
                inputs=[
                    message_input,
                    system_message,
                    temperature_slider,
                    top_p_slider,
                    max_tokens_slider,
                    chat_display,
                ],
                outputs=[chat_display, message_input],
            )

            clear_btn.click(fn=self.clear_chat, outputs=[chat_display])

            interface.load(fn=self.update_model_status, outputs=[model_status])

        return interface
    def format_chat_html(self, messages: List[Dict[str, str]]) -> str:
        """Format chat messages as HTML."""
        html_parts = ['<div class="chat-container">']

        for msg in messages:
            role = msg.get("role", "user")
            # Escape the content so user or model text cannot inject raw HTML markup.
            content = html.escape(msg.get("content", ""))

            if role == "user":
                html_parts.append(f'<div class="user-message">{content}</div>')
            elif role == "assistant":
                html_parts.append(f'<div class="assistant-message">{content}</div>')
            elif role == "system":
                html_parts.append(
                    f'<div class="system-message">System: {content}</div>'
                )

        html_parts.append("</div>")
        return "".join(html_parts)
    def send_message(
        self,
        message: str,
        system_msg: str,
        temperature: float,
        top_p: float,
        max_tokens: int,
        current_display: str,
    ) -> tuple[str, str]:
        """Send a message and get a response."""
        if not message.strip():
            return current_display, ""

        try:
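            # Record the user's turn before building the model request.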
            self.chat_history.append({"role": "user", "content": message})

            messages = []

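            # The optional system message goes first so it steers the whole conversation.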
            if system_msg.strip():
                messages.append(ChatMessage(role="system", content=system_msg.strip()))

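            # Replay the full history so the model sees the conversation context.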
            for msg in self.chat_history:
                messages.append(ChatMessage(role=msg["role"], content=msg["content"]))

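            # Build the chat completion request from the collected messages and the UI settings.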
            request = ChatRequest(
                messages=messages,
                model="microsoft/phi-1_5",  # match the model shown in the (fixed) dropdown
                max_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                stream=False,
            )

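            # The Gradio callback is synchronous, so run the async generate() call to completion here.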
            response = asyncio.run(self.llm_manager.generate(request))

            if response.get("choices") and len(response["choices"]) > 0:
                assistant_content = response["choices"][0]["message"]["content"]
            else:
                assistant_content = "Sorry, I couldn't generate a response."
            self.chat_history.append(
                {"role": "assistant", "content": assistant_content}
            )

            updated_display = self.format_chat_html(self.chat_history)

            return updated_display, ""

        except Exception as e:
            logger.error(f"Error in send_message: {e}")
            error_msg = f"Error: {str(e)}"
            self.chat_history.append({"role": "assistant", "content": error_msg})
            updated_display = self.format_chat_html(self.chat_history)
            return updated_display, ""
    def clear_chat(self) -> str:
        """Clear the chat history."""
        self.chat_history = []
        return "<div class='chat-container'><p>Chat cleared. Start a new conversation!</p></div>"

    def update_model_status(self) -> str:
        """Update the model status display."""
        return (
            f"**Model Status:** {'✅ Loaded' if self.llm_manager.is_loaded else '❌ Not Loaded'}\n"
            f"**Model Type:** {self.llm_manager.model_type}\n"
            f"**Context Window:** {self.llm_manager.context_window} tokens"
        )


def create_gradio_app(llm_manager: Optional[LLMManager] = None):
    """Create the Gradio app, loading a model if no manager is provided."""
    if llm_manager is None:
        llm_manager = LLMManager()
        asyncio.run(llm_manager.load_model())

    interface = GradioChatInterface(llm_manager)
    gradio_interface = interface.create_interface()

    return gradio_interface


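# Usage sketch (assumption, not wired up here): instead of launching a standalone server
# as the __main__ block below does, the Blocks object returned by create_gradio_app()
# could be mounted onto an existing FastAPI app with gr.mount_gradio_app(app, blocks, path="/ui").
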
if __name__ == "__main__":
    # Standalone entry point: load the model, then launch the web UI.
    llm_manager = LLMManager()
    asyncio.run(llm_manager.load_model())

    app = create_gradio_app(llm_manager)
    app.launch(server_name="0.0.0.0", server_port=7860, share=False, debug=True)