# Hugging Face Space: Mistral-7B-Instruct chatbot demo
| import gradio as gr | |
| from transformers import pipeline, set_seed | |
| from functools import lru_cache | |
# === 1. Cache the model loader once per session ===
@lru_cache(maxsize=2)
def get_generator(model_name: str):
    """Load and memoize a text-generation pipeline for *model_name*.

    Without the ``lru_cache`` decorator (which the file imports for this
    purpose), every chat turn would reload the multi-GB model from disk.
    With it, each distinct model name is loaded at most once per process.
    """
    return pipeline(
        "text-generation",
        model=model_name,
        trust_remote_code=True,  # some custom model repos require this
        device_map="auto",       # place weights on GPU(s) when available
    )
# === 2. Chat callback ===
def chat(user_input, history, model_name, max_length, temperature, seed):
    """Generate one assistant reply and append the exchange to *history*.

    Args:
        user_input: message typed by the user.
        history: chat history as a list of {"role", "content"} dicts.
        model_name: Hugging Face model id passed to the pipeline loader.
        max_length: total token budget for the generation call.
        temperature: sampling temperature; 0 falls back to greedy decoding.
        seed: RNG seed; values <= 0 disable seeding.

    Returns:
        ``(history, history)`` so one callback updates both the Chatbot
        component and the ``gr.State`` copy of the conversation.
    """
    # gr.Number delivers a float; transformers' set_seed requires an int.
    if seed and int(seed) > 0:
        set_seed(int(seed))
    # Lazy-load (and cache) the model.
    generator = get_generator(model_name)
    # [INST]/<<SYS>> instruction wrapper used by Llama/Mistral instruct
    # checkpoints.
    prompt = (
        "[INST] <<SYS>>\nYou are a helpful assistant.\n<</SYS>>\n\n"
        f"{user_input}\n[/INST]"
    )
    # temperature == 0.0 (reachable via the UI slider) combined with
    # do_sample=True raises in transformers; use greedy decoding instead.
    do_sample = temperature > 0
    gen_kwargs = {
        "max_length": int(max_length),  # sliders may deliver floats
        "do_sample": do_sample,
        "num_return_sequences": 1,
    }
    if do_sample:
        gen_kwargs["temperature"] = float(temperature)
    outputs = generator(prompt, **gen_kwargs)
    # Keep only the text generated after the instruction marker.
    response = outputs[0]["generated_text"].split("[/INST]")[-1].strip()
    # Append to history as dicts for the Chatbot "messages" format.
    history.append({"role": "user", "content": user_input})
    history.append({"role": "assistant", "content": response})
    return history, history
# === 3. Build Gradio UI ===
with gr.Blocks() as demo:
    gr.Markdown("## 🤖 Mistral-7B-Instruct Chatbot (Gradio)")

    # Conversation display plus a session-scoped copy of the history.
    chatbot = gr.Chatbot(type="messages")
    state = gr.State([])

    with gr.Row():
        # Left column: message entry and send button.
        with gr.Column(scale=3):
            inp = gr.Textbox(placeholder="Type your message...", lines=2, show_label=False)
            submit = gr.Button("Send")
        # Right column: generation settings.
        with gr.Column(scale=1):
            gr.Markdown("### Settings")
            model_name = gr.Textbox(value="mistralai/Mistral-7B-Instruct-v0.3", label="Model name")
            max_length = gr.Slider(50, 1024, 256, step=50, label="Max tokens")
            temperature = gr.Slider(0.0, 1.0, 0.7, step=0.05, label="Temperature")
            seed = gr.Number(42, label="Random seed (0 disables)")

    # The gr.State component travels through the callback alongside the
    # Chatbot, so the conversation persists across turns.
    submit.click(
        fn=chat,
        inputs=[inp, state, model_name, max_length, temperature, seed],
        outputs=[chatbot, state],
    )

if __name__ == "__main__":
    demo.launch()