import gradio as gr
from transformers import pipeline
import torch

# Custom CSS for a ChatGPT-like appearance
custom_css = """
body, .gradio-container {
    background-color: #0d0d0d !important;
    color: #e5e5e5 !important;
    font-family: 'Inter', -apple-system, BlinkMacSystemFont, sans-serif !important;
}
#chatbot {
    border: none !important;
    background: transparent !important;
}
.message.user {
    background-color: #2f2f2f !important;
    border-radius: 18px !important;
    padding: 12px 16px !important;
    margin: 8px 0 !important;
    max-width: 85% !important;
    align-self: flex-end !important;
}
.message.bot {
    background-color: transparent !important;
    padding: 12px 0 !important;
    margin: 8px 0 !important;
    max-width: 90% !important;
}
#input-container {
    background: #1a1a1a !important;
    border: 1px solid #333 !important;
    border-radius: 12px !important;
    padding: 8px !important;
    margin-top: 20px !important;
}
#send-button {
    background-color: #ffffff !important;
    color: #000000 !important;
    border-radius: 8px !important;
    font-weight: 600 !important;
}
#sidebar {
    background-color: #000000 !important;
    border-right: 1px solid #222 !important;
    padding: 20px !important;
}
.gr-button-secondary {
    background-color: #222 !important;
    color: white !important;
    border: 1px solid #333 !important;
}
footer {
    display: none !important;
}
"""

# Global cache so each model is loaded at most once per process
models_cache = {}


def get_pipeline(model_id):
    """Return a cached text-generation pipeline for model_id, loading it on first use."""
    if model_id not in models_cache:
        print(f"Loading model {model_id}...")
        try:
            pipe = pipeline(
                "text-generation",
                model=model_id,
                device_map="auto",
                trust_remote_code=True,
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            )
            models_cache[model_id] = pipe
        except Exception as e:
            raise gr.Error(f"Failed to load model {model_id} locally: {str(e)}")
    return models_cache[model_id]


def respond(
    message,
    history,
    model_id,
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    pipe = get_pipeline(model_id)

    # Convert the (user, assistant) history pairs into the chat-message format
    # expected by the tokenizer's chat template.
    messages = [{"role": "system", "content": system_message}]
    for user_msg, bot_msg in history:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if bot_msg:
            messages.append({"role": "assistant", "content": bot_msg})
    messages.append({"role": "user", "content": message})

    try:
        # Let the model's own chat template build the prompt string, since the
        # expected format varies between models.
        prompt = pipe.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

        # Note: this generates the full reply and yields it once, which satisfies
        # Gradio's generator interface but does not stream token by token. A
        # token-level streaming variant is sketched below.
        outputs = pipe(
            prompt,
            max_new_tokens=max_tokens,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            pad_token_id=pipe.tokenizer.eos_token_id,
        )
        full_response = outputs[0]["generated_text"]

        # The pipeline returns prompt + completion, so strip the prompt prefix
        # to keep only the newly generated text.
        response = full_response[len(prompt):]
        yield response
    except Exception as e:
        yield f"Error during generation: {str(e)}"
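
# --- Optional: token-level streaming (sketch) --------------------------------
# respond() above yields the whole reply at once. If incremental streaming is
# wanted, transformers' TextIteratorStreamer can be consumed while
# model.generate runs on a worker thread. This is an untested sketch, not wired
# into the UI: `respond_streaming` is a hypothetical name, and it assumes the
# cached pipeline exposes .model and .tokenizer (true for transformers
# pipelines) and that device_map="auto" placed inputs correctly via
# pipe.model.device. To try it, swap it in for respond() in bot_response.
def respond_streaming(message, history, model_id, system_message, max_tokens, temperature, top_p):
    from threading import Thread
    from transformers import TextIteratorStreamer

    pipe = get_pipeline(model_id)
    messages = [{"role": "system", "content": system_message}]
    for user_msg, bot_msg in history:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if bot_msg:
            messages.append({"role": "assistant", "content": bot_msg})
    messages.append({"role": "user", "content": message})

    prompt = pipe.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = pipe.tokenizer(prompt, return_tensors="pt").to(pipe.model.device)

    # skip_prompt drops the echoed input tokens; skip_special_tokens is passed
    # through to the decoder so EOS markers are not shown to the user.
    streamer = TextIteratorStreamer(pipe.tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        pad_token_id=pipe.tokenizer.eos_token_id,
    )

    # generate() blocks, so run it on a worker thread and consume tokens here.
    thread = Thread(target=pipe.model.generate, kwargs=generation_kwargs)
    thread.start()
    partial = ""
    for new_text in streamer:
        partial += new_text
        yield partial
    thread.join()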
label="System Prompt", lines=3 ) with gr.Accordion("Advanced Parameters", open=False): max_tokens = gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max Tokens") temperature = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature") top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top-p") gr.Markdown("---") gr.Markdown("Models run locally on Space CPU/GPU.") # Main Chat Area with gr.Column(scale=4): gr.Markdown("# 💬 OneAI Chat") chatbot = gr.Chatbot( height=650, elem_id="chatbot", show_label=False, bubble_full_width=False, type="messages" ) with gr.Row(elem_id="input-container"): msg = gr.Textbox( placeholder="Ask OneAI anything...", show_label=False, scale=9, container=False ) submit_btn = gr.Button("↑", scale=1, variant="primary", elem_id="send-button") gr.ClearButton([msg, chatbot], variant="secondary") # Linking components def chat_echo(message, history): history.append({"role": "user", "content": message}) return "", history def bot_response(history, model_id, system_message, max_tokens, temperature, top_p): user_message = history[-1]["content"] legacy_history = [] for i in range(0, len(history) - 1, 2): if i + 1 < len(history): legacy_history.append([history[i]["content"], history[i+1]["content"]]) history.append({"role": "assistant", "content": ""}) response_gen = respond( user_message, legacy_history, model_id, system_message, max_tokens, temperature, top_p ) for partial_response in response_gen: history[-1]["content"] = partial_response yield history msg.submit(chat_echo, [msg, chatbot], [msg, chatbot], queue=False, api_name=False).then( bot_response, [chatbot, model_id, system_message, max_tokens, temperature, top_p], chatbot, api_name=False ) submit_btn.click(chat_echo, [msg, chatbot], [msg, chatbot], queue=False, api_name=False).then( bot_response, [chatbot, model_id, system_message, max_tokens, temperature, top_p], chatbot, api_name=False ) if __name__ == "__main__": demo.launch()