# onechat / app.py
# Note: api_name=False is set on the event handlers to bypass a Gradio 5
# info-schema crash, and the Chatbot uses type="messages".
import gradio as gr
import torch
from transformers import pipeline
# Custom CSS for ChatGPT-like appearance
custom_css = """
body, .gradio-container {
background-color: #0d0d0d !important;
color: #e5e5e5 !important;
font-family: 'Inter', -apple-system, BlinkMacSystemFont, sans-serif !important;
}
#chatbot {
border: none !important;
background: transparent !important;
}
.message.user {
background-color: #2f2f2f !important;
border-radius: 18px !important;
padding: 12px 16px !important;
margin: 8px 0 !important;
max-width: 85% !important;
align-self: flex-end !important;
}
.message.bot {
background-color: transparent !important;
padding: 12px 0 !important;
margin: 8px 0 !important;
max-width: 90% !important;
}
#input-container {
background: #1a1a1a !important;
border: 1px solid #333 !important;
border-radius: 12px !important;
padding: 8px !important;
margin-top: 20px !important;
}
#send-button {
background-color: #ffffff !important;
color: #000000 !important;
border-radius: 8px !important;
font-weight: 600 !important;
}
#sidebar {
background-color: #000000 !important;
border-right: 1px solid #222 !important;
padding: 20px !important;
}
.gr-button-secondary {
background-color: #222 !important;
color: white !important;
border: 1px solid #333 !important;
}
footer {
display: none !important;
}
"""
# Global cache for models
models_cache = {}
def get_pipeline(model_id):
if model_id not in models_cache:
print(f"Loading model {model_id}...")
try:
pipe = pipeline(
"text-generation",
model=model_id,
device_map="auto",
trust_remote_code=True,
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
)
models_cache[model_id] = pipe
except Exception as e:
raise gr.Error(f"Failed to load model {model_id} locally: {str(e)}")
return models_cache[model_id]
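# Memory note: every model selected in the UI stays resident in models_cache
# for the life of the process. On a small Space that can exhaust RAM/VRAM; a
# hypothetical single-slot variant (not wired into the UI) would evict the
# previously loaded model first:
import gc

def get_pipeline_single_slot(model_id):
    if model_id not in models_cache:
        models_cache.clear()  # drop the reference to the old pipeline
        gc.collect()          # let Python actually free it...
        if torch.cuda.is_available():
            torch.cuda.empty_cache()  # ...and return its GPU memory to the pool
    return get_pipeline(model_id)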
def respond(
message,
history,
model_id,
system_message,
max_tokens,
temperature,
top_p,
):
pipe = get_pipeline(model_id)
# Convert history to chat format for tokenizer
messages = [{"role": "system", "content": system_message}]
    for user_msg, bot_msg in history:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if bot_msg:
            messages.append({"role": "assistant", "content": bot_msg})
messages.append({"role": "user", "content": message})
    # Generate with the pipeline.
    try:
        # Build the prompt with the model's chat template so the roles are
        # formatted the way the model was trained to expect.
        prompt = pipe.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        # Generate the full reply in one pass and yield it once; Gradio
        # accepts a generator here, so this behaves as a single-chunk stream.
        # True token-by-token streaming needs a TextIteratorStreamer (see
        # respond_stream below).
        outputs = pipe(
            prompt,
            max_new_tokens=max_tokens,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            pad_token_id=pipe.tokenizer.eos_token_id,
            return_full_text=False,  # return only the newly generated text
        )
        yield outputs[0]["generated_text"]
except Exception as e:
yield f"Error during generation: {str(e)}"
with gr.Blocks(theme=gr.themes.Soft(primary_hue="gray"), css=custom_css) as demo:
with gr.Row():
# Sidebar for settings
with gr.Column(scale=1, elem_id="sidebar"):
gr.Markdown("## 🛠️ Settings")
model_id = gr.Dropdown(
choices=[
"onedevelopment/oneai-1.2-38m",
"onedevelopment/oneai-1-35m"
],
value="onedevelopment/oneai-1.2-38m",
label="Select Model",
interactive=True
)
system_message = gr.Textbox(
value="You are a helpful and advanced AI assistant named OneAI.",
label="System Prompt",
lines=3
)
with gr.Accordion("Advanced Parameters", open=False):
max_tokens = gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max Tokens")
temperature = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature")
top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top-p")
gr.Markdown("---")
gr.Markdown("Models run locally on Space CPU/GPU.")
# Main Chat Area
with gr.Column(scale=4):
gr.Markdown("# 💬 OneAI Chat")
chatbot = gr.Chatbot(
height=650,
elem_id="chatbot",
show_label=False,
                type="messages",  # dict-based history; bubble_full_width was dropped (deprecated in Gradio 5)
)
with gr.Row(elem_id="input-container"):
msg = gr.Textbox(
placeholder="Ask OneAI anything...",
show_label=False,
scale=9,
container=False
)
submit_btn = gr.Button("↑", scale=1, variant="primary", elem_id="send-button")
gr.ClearButton([msg, chatbot], variant="secondary")
    # Wire the components: chat_echo appends the user turn and clears the
    # textbox, then bot_response fills in the assistant turn on the same
    # message-dict history.
def chat_echo(message, history):
history.append({"role": "user", "content": message})
return "", history
def bot_response(history, model_id, system_message, max_tokens, temperature, top_p):
user_message = history[-1]["content"]
legacy_history = []
for i in range(0, len(history) - 1, 2):
if i + 1 < len(history):
legacy_history.append([history[i]["content"], history[i+1]["content"]])
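            # Worked example of the conversion:
            #   [{"role": "user", "content": "hi"},
            #    {"role": "assistant", "content": "hello!"},
            #    {"role": "user", "content": "how are you?"}]
            #   -> legacy_history == [["hi", "hello!"]]; the trailing user
            #      turn is passed separately as user_message.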
history.append({"role": "assistant", "content": ""})
response_gen = respond(
user_message,
legacy_history,
model_id,
system_message,
max_tokens,
temperature,
top_p
)
for partial_response in response_gen:
history[-1]["content"] = partial_response
yield history
msg.submit(chat_echo, [msg, chatbot], [msg, chatbot], queue=False, api_name=False).then(
bot_response, [chatbot, model_id, system_message, max_tokens, temperature, top_p], chatbot, api_name=False
)
submit_btn.click(chat_echo, [msg, chatbot], [msg, chatbot], queue=False, api_name=False).then(
bot_response, [chatbot, model_id, system_message, max_tokens, temperature, top_p], chatbot, api_name=False
)
if __name__ == "__main__":
demo.launch()
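# To run outside a Space: `pip install gradio transformers torch accelerate`
# and `python app.py`. accelerate is needed because get_pipeline passes
# device_map="auto"; without a GPU, loading falls back to CPU float32.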