Spaces: Sleeping
import os
import torch
import threading
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Hugging Face token used to download the MaxLSB/* checkpoints.
# NOTE(review): os.environ[...] raises KeyError when the variable is unset —
# presumably an intentional fail-fast at startup; confirm before softening.
hf_token = os.environ["HUGGINGFACEHUB_API_TOKEN"]

# Limit PyTorch intra-op CPU threads for this process.
torch.set_num_threads(4)

# Globals for the currently active checkpoint; reassigned by load_model().
tokenizer = None
model = None
current_model_name = None
def load_model(model_name):
    """Download the checkpoint `MaxLSB/<model_name>` and make it the active model.

    Rebinds the module-level ``tokenizer``, ``model`` and ``current_model_name``
    globals; the model is switched to eval mode for inference.
    """
    global tokenizer, model, current_model_name
    repo_id = f"MaxLSB/{model_name}"
    tokenizer = AutoTokenizer.from_pretrained(repo_id, token=hf_token)
    model = AutoModelForCausalLM.from_pretrained(repo_id, token=hf_token)
    model.eval()  # inference only: disables dropout etc.
    current_model_name = model_name
# Load the default checkpoint at import time so the app starts ready to serve.
load_model("LeCarnet-8M")
def respond(message, max_tokens, temperature, top_p):
    """Stream a sampled continuation of ``message`` as incremental text chunks.

    ``model.generate`` runs on a background thread; tokens are relayed through
    a ``TextIteratorStreamer`` (prompt text included, special tokens stripped)
    and yielded to the caller as soon as they are decoded.
    """
    encoded = tokenizer(message, return_tensors="pt")
    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=False, skip_special_tokens=True
    )

    def _generate():
        # Pure inference: skip autograd bookkeeping.
        with torch.no_grad():
            model.generate(
                **encoded,
                streamer=streamer,
                max_new_tokens=max_tokens,
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
                eos_token_id=tokenizer.eos_token_id,
            )

    threading.Thread(target=_generate).start()

    # Iterating the streamer blocks until the worker produces text.
    yield from streamer
def user(message, chat_history):
    """Append the submitted text as a new user bubble and clear the textbox.

    Returns ``("", chat_history)`` so Gradio empties the input field and
    re-renders the chat; the second slot is ``None`` because the assistant's
    reply has not been generated yet.
    """
    entry = [message, None]
    chat_history.append(entry)
    return "", chat_history
def bot(chat_history, max_tokens, temperature, top_p):
    """Stream the model's reply into the chat history.

    Emits a bubble labeled with the active model's name, then an initially
    empty bubble that is filled chunk by chunk; the updated history is
    yielded after each chunk so Gradio re-renders live.
    """
    # Read the user's message *before* appending the two bot bubbles, instead
    # of the original brittle `chat_history[-3][0]` fixed negative index that
    # silently breaks if the number of appended bubbles ever changes.
    message = chat_history[-1][0]
    # Model-name header bubble plus an empty placeholder for the answer.
    chat_history.append([None, f"**{current_model_name}**"])
    chat_history.append([None, ""])
    # Render the model name immediately, before generation starts.
    yield chat_history
    # Stream generation into the last bubble.
    for chunk in respond(message, max_tokens, temperature, top_p):
        chat_history[-1][1] += chunk
        yield chat_history
def update_model(model_name):
    """Swap in the newly selected checkpoint and reset the chat display."""
    load_model(model_name)
    # An empty list clears the Chatbot component it is wired to.
    return []
# Gradio UI
with gr.Blocks(
    css=".gr-chatbot .message { margin: 2px 0 !important; }",
    # Fix: the original read `title=$1`, an unexpanded sed/regex substitution
    # placeholder that is a Python SyntaxError.
    title="LeCarnet",
) as demo:
    with gr.Row():
        gr.HTML("""
        <div style="text-align: center; width: 100%;">
        <h1 style="margin: 0;">LeCarnet Demo 📊</h1>
        </div>
        """)
    with gr.Row():
        # Options column
        with gr.Column(scale=1, min_width=150):
            model_selector = gr.Dropdown(
                choices=["LeCarnet-3M", "LeCarnet-8M", "LeCarnet-21M"],
                value="LeCarnet-8M",
                label="Select Model",
            )
            max_tokens = gr.Slider(1, 512, value=512, step=1, label="Max New Tokens")
            temperature = gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature")
            top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p Sampling")
            clear_button = gr.Button("Clear Chat")
        # Chat column
        with gr.Column(scale=4):
            chatbot = gr.Chatbot(
                bubble_full_width=False,
                height=500,
            )
            msg_input = gr.Textbox(
                placeholder="Type your message and press Enter...",
                label="User Input",
            )
            gr.Examples(
                examples=[
                    ["Il était une fois un petit renard nommé Roux. Roux aimait jouer dans la forêt."],
                    ["Dans un petit village, il y avait un jardin magnifique."],
                    ["Il était une fois un petit garçon nommé Tom. Tom aimait beaucoup dessiner."],
                ],
                inputs=msg_input,
                label="Example Prompts",
            )

    # Event handlers
    # Fix: the original used outputs=[], which silently discarded the empty
    # list update_model() returns — the chat was never cleared on model switch.
    model_selector.change(fn=update_model, inputs=[model_selector], outputs=[chatbot])
    # Submitting the textbox first records the user bubble (no queue, instant),
    # then streams the bot reply into the chat.
    msg_input.submit(
        fn=user,
        inputs=[msg_input, chatbot],
        outputs=[msg_input, chatbot],
        queue=False,
    ).then(
        fn=bot,
        inputs=[chatbot, max_tokens, temperature, top_p],
        outputs=[chatbot],
    )
    clear_button.click(
        fn=lambda: [],
        inputs=None,
        outputs=chatbot,
        queue=False,
    )
if __name__ == "__main__":
    # Queue at most 10 concurrent generations with 10 waiting slots; SSR is
    # disabled and the worker thread pool capped to match the concurrency limit.
    demo.queue(default_concurrency_limit=10, max_size=10).launch(ssr_mode=False, max_threads=10)