import os
import uuid
import time
import json

from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr
import modelscope_studio.components.antd as antd
import modelscope_studio.components.antdx as antdx
import modelscope_studio.components.base as ms
import modelscope_studio.components.pro as pro

# Define model paths
MODEL_PATHS = {
    "LeCarnet-3M": "MaxLSB/LeCarnet-3M",
    "LeCarnet-8M": "MaxLSB/LeCarnet-8M",
    "LeCarnet-21M": "MaxLSB/LeCarnet-21M",
}

# Set HF token
hf_token = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
if not hf_token:
    raise ValueError("HUGGINGFACEHUB_API_TOKEN environment variable not set.")

# Tokenizer and model live in module-level globals, filled in by load_model()
tokenizer = None
model = None

def load_model(model_name: str):
    global tokenizer, model
    if model_name not in MODEL_PATHS:
        raise ValueError(f"Unknown model: {model_name}")
    print(f"Loading {model_name}...")
    repo = MODEL_PATHS[model_name]
    tokenizer = AutoTokenizer.from_pretrained(repo, token=hf_token)
    model = AutoModelForCausalLM.from_pretrained(repo, token=hf_token)
    model.eval()
    print(f"{model_name} loaded.")
def generate_response(prompt, max_new_tokens=200):
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return response[len(prompt):].strip()
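
# Note: generate() above runs with the checkpoint's default decoding settings
# (greedy decoding unless its generation_config enables sampling); passing
# e.g. do_sample=True or temperature=... would produce more varied stories.
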
# CSS for styling chatbot header with avatar
css = """
.chatbot-chat-messages .ant-pro-chat-message .ant-pro-chat-message-header {
    display: flex;
    align-items: center;
}
.chatbot-chat-messages .ant-pro-chat-message .ant-pro-chat-message-header img {
    width: 20px;
    height: 20px;
    margin-right: 8px;
    vertical-align: middle;
}
"""

# Default settings
DEFAULT_SETTINGS = {
    "model": "LeCarnet-3M",
    "sys_prompt": "",
}
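# "sys_prompt" travels with the conversation settings but is not yet used when
# the prompt is assembled in add_message() below.
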
# Welcome message (optional)
def welcome_config():
    return {
        "title": "LeCarnet Chatbot",
        "description": "Start chatting below!",
        "promptSuggestions": ["Hello", "Tell me a story", "How are you?"],
    }

with gr.Blocks(css=css) as demo:
    # Initial state with one fixed conversation; the gr.State component is
    # created inside the Blocks context so it is registered with this app
    state = gr.State({
        "conversation_id": "default",
        "conversation_contexts": {
            "default": {
                "history": [],
                "settings": DEFAULT_SETTINGS,
            }
        },
    })

    with ms.Application(), antd.Row(gutter=[20, 20], wrap=False, elem_id="chatbot"):
        # Right Column - Chat Interface
        with antd.Col(flex=1, elem_style=dict(height="100%")):
            with antd.Flex(vertical=True, gap="small", elem_classes="chatbot-chat"):
                chatbot = pro.Chatbot(
                    elem_classes="chatbot-chat-messages",
                    height=0,
                    welcome_config=welcome_config(),
                )
                with antdx.Suggestion(items=["Hello", "How are you?", "Tell me something"]) as suggestion:
                    with ms.Slot("children"):
                        input = antdx.Sender(placeholder="Type your message here...")

    current_state = state
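
    # add_message is a generator: each `yield {chatbot: gr.update(...)}` maps the
    # Chatbot component to its new value, so the UI refreshes incrementally
    # (user turn, loading placeholder, then the final answer or an error).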
    def add_message(user_input, state_value):
        history = state_value["conversation_contexts"]["default"]["history"]
        settings = state_value["conversation_contexts"]["default"]["settings"]
        selected_model = settings["model"]

        # Add user message
        history.append({"role": "user", "content": user_input, "key": str(uuid.uuid4())})
        yield {chatbot: gr.update(value=history)}

        # Start assistant response with a loading placeholder
        history.append({
            "role": "assistant",
            "content": [],
            "key": str(uuid.uuid4()),
            "header": f'<img src="/file=media/le-carnet.png" style="width:20px;height:20px;margin-right:8px;"> <span>{selected_model}</span>',
            "loading": True,
        })
        yield {chatbot: gr.update(value=history)}

        try:
            # Generate model response from the concatenated user turns
            prompt = "\n".join([msg["content"] for msg in history if msg["role"] == "user"])
            response = generate_response(prompt)
            # Update assistant message
            history[-1]["content"] = [{"type": "text", "content": response}]
            history[-1]["loading"] = False
            yield {chatbot: gr.update(value=history)}
        except Exception as e:
            history[-1]["content"] = [{
                "type": "text",
                "content": f'<span style="color: red;">{str(e)}</span>',
            }]
            history[-1]["loading"] = False
            yield {chatbot: gr.update(value=history)}

    input.submit(fn=add_message, inputs=[input, state], outputs=[chatbot])

# Load default model on startup
load_model(DEFAULT_SETTINGS["model"])

if __name__ == "__main__":
    demo.queue(default_concurrency_limit=10).launch()
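
# To run locally (assuming this file is saved as app.py and the token is available):
#   HUGGINGFACEHUB_API_TOKEN=hf_xxx python app.py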