Spaces:
Sleeping
Sleeping
| import os | |
| import torch | |
| import threading | |
| import gradio as gr | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer | |
# Hugging Face token. Use .get() rather than indexing: the LeCarnet models
# are fetched from a public namespace, so a missing token should not hard-crash
# the Space at import time with a KeyError (from_pretrained accepts token=None).
hf_token = os.environ.get("HUGGINGFACEHUB_API_TOKEN")

# Shared-CPU Space serving several requests via threads: cap intra-op
# parallelism so one generation does not starve the others.
torch.set_num_threads(1)

# Globals holding the currently active model; swapped by load_model().
tokenizer = None
model = None
current_model_name = None
# Load selected model
def load_model(model_name):
    """Load tokenizer and model for *model_name* into the module globals.

    No-op when *model_name* is already the active model.  Both artifacts are
    loaded into locals first and published together: the original assigned the
    global ``tokenizer`` before loading ``model``, so a failure in the second
    download left a new tokenizer paired with the previous model.

    Parameters
    ----------
    model_name : str
        Short model name (e.g. "LeCarnet-8M"); resolved under the
        "MaxLSB/" Hub namespace.
    """
    global tokenizer, model, current_model_name

    # Only load if it's a different model.
    if current_model_name == model_name:
        return

    full_model_name = f"MaxLSB/{model_name}"
    print(f"Loading model: {full_model_name}")

    # Stage into locals; publish the globals only after both loads succeed.
    new_tokenizer = AutoTokenizer.from_pretrained(full_model_name, token=hf_token)
    new_model = AutoModelForCausalLM.from_pretrained(full_model_name, token=hf_token)
    new_model.eval()  # inference only

    tokenizer, model = new_tokenizer, new_model
    current_model_name = model_name
    print(f"Model loaded: {current_model_name}")
# Warm-start: load the default model at import time so the first request
# does not pay the download/initialization cost.
load_model("LeCarnet-8M")
# Streaming generation function
def respond(message, max_tokens, temperature, top_p, selected_model):
    """Yield progressively longer completions for *message*.

    Runs ``model.generate`` on a background thread and streams partial text
    through a TextIteratorStreamer.  Each yielded string is the full
    accumulated response prefixed with the active model's name in Markdown.

    Fix vs. original: the worker thread was started but never joined, so a
    thread object was leaked per request; the try/finally reaps it even if
    the consumer abandons the generator early.
    """
    # Ensure the correct model is loaded before generation.
    load_model(selected_model)

    inputs = tokenizer(message, return_tensors="pt")
    # skip_prompt=False: this is a completion-style model, so the prompt is
    # intentionally echoed at the start of the streamed text.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=False, skip_special_tokens=True)

    generate_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        eos_token_id=tokenizer.eos_token_id,
    )

    def run():
        # Inference only: no autograd graph needed.
        with torch.no_grad():
            model.generate(**generate_kwargs)

    thread = threading.Thread(target=run)
    thread.start()
    try:
        response = ""
        for new_text in streamer:
            response += new_text
            yield f"**{current_model_name}**\n\n{response}"
    finally:
        # Streamer is exhausted only when generate() finishes, so this join
        # is cheap on the normal path; it also prevents thread buildup.
        thread.join()
# User input handler
def user(message, chat_history):
    """Record *message* as a new chat turn with the bot reply pending,
    and clear the input textbox."""
    pending_turn = [message, None]
    chat_history.append(pending_turn)
    # First output ("") empties the textbox component.
    return "", chat_history
# Bot response handler - streams the reply for the latest user turn,
# forwarding the currently selected model to respond().
def bot(chatbot, max_tokens, temperature, top_p, selected_model):
    """Fill in the bot side of the most recent chat turn, yielding the
    updated history after each streamed chunk."""
    prompt = chatbot[-1][0]
    for partial in respond(prompt, max_tokens, temperature, top_p, selected_model):
        chatbot[-1][1] = partial
        yield chatbot
# Model selector handler
def update_model(model_name):
    """Eagerly load the newly selected model and echo the selection back
    so the dropdown keeps its value."""
    selected = model_name
    load_model(selected)
    return selected
# Clear chat handler
def clear_chat():
    """Reset the chatbot component (returning None empties its history)."""
    return None
# Gradio UI
# NOTE(review): layout order matters here — msg_input is created early with
# render=False so it can be placed inside the chat column via .render().
with gr.Blocks(title="LeCarnet - Chat Interface") as demo:
    # Centered page title.
    with gr.Row():
        gr.HTML("""
            <div style="text-align: center; width: 100%;">
                <h1 style="margin: 0;">LeCarnet Demo</h1>
            </div>
        """)
    # Declared here (unrendered) so Examples below can reference it before
    # it is actually placed in the layout.
    msg_input = gr.Textbox(
        placeholder="Il était une fois un petit garçon",
        label="User Input",
        render=False
    )
    with gr.Row():
        # Left column: model choice, sampling controls, examples.
        with gr.Column(scale=1, min_width=150):
            model_selector = gr.Dropdown(
                choices=["LeCarnet-3M", "LeCarnet-8M", "LeCarnet-21M"],
                value="LeCarnet-8M",
                label="Select Model"
            )
            max_tokens = gr.Slider(1, 512, value=512, step=1, label="Max New Tokens")
            temperature = gr.Slider(0.1, 2.0, value=0.4, step=0.1, label="Temperature")
            top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p Sampling")
            clear_button = gr.Button("Clear Chat")
            # French story-opening prompts suited to a completion model.
            gr.Examples(
                examples=[
                    ["Il était une fois un petit phoque nommé Zoom. Zoom était très habile et aimait jouer dans l'eau."],
                    ["Il était une fois un petit écureuil nommé Pipo. Pipo adorait grimper aux arbres."],
                    ["Il était une fois un petit garçon nommé Tom. Tom aimait beaucoup dessiner."],
                ],
                inputs=msg_input,
                label="Example Prompts"
            )
        # Right column: chat window with the input box below it.
        with gr.Column(scale=4):
            chatbot = gr.Chatbot(
                bubble_full_width=False,
                height=500
            )
            msg_input.render()
    # Event Handlers
    # Switching models loads eagerly so the next message has no load latency.
    model_selector.change(
        fn=update_model,
        inputs=[model_selector],
        outputs=[model_selector],
    )
    # Submit: append the user turn immediately (queue=False for snappy UI),
    # then stream the bot reply through the queued generator.
    msg_input.submit(
        fn=user,
        inputs=[msg_input, chatbot],
        outputs=[msg_input, chatbot],
        queue=False
    ).then(
        fn=bot,
        inputs=[chatbot, max_tokens, temperature, top_p, model_selector],  # Pass model_selector
        outputs=[chatbot]
    )
    clear_button.click(
        fn=clear_chat,
        inputs=None,
        outputs=chatbot,
        queue=False
    )
# Queue limits protect the shared-CPU Space; ssr_mode=False disables Gradio's
# server-side-rendering path, and max_threads bounds the worker pool.
if __name__ == "__main__":
    demo.queue(default_concurrency_limit=10, max_size=10).launch(ssr_mode=False, max_threads=10)