import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, T5ForConditionalGeneration, T5Tokenizer
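
# Two small checkpoints keep this demo CPU-friendly: SmolLM2-135M-Instruct
# (a decoder-only causal LM) and FLAN-T5 small (an encoder-decoder seq2seq
# model). Models are loaded lazily in ensure_model_loaded(), so startup is
# fast and only the models a user actually selects occupy memory.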
class MultiModelChat:
    def __init__(self):
        self.models = {}  # cache of loaded models, keyed by display name

    def ensure_model_loaded(self, model_name):
        """Lazy load a model only when needed"""
        if model_name not in self.models:
            print(f"Loading {model_name} model...")
            if model_name == 'SmolLM2':
                self.models['SmolLM2'] = {
                    'tokenizer': AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M-Instruct"),
                    'model': AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-135M-Instruct")
                }
            elif model_name == 'FLAN-T5':
                self.models['FLAN-T5'] = {
                    'tokenizer': T5Tokenizer.from_pretrained("google/flan-t5-small"),
                    'model': T5ForConditionalGeneration.from_pretrained("google/flan-t5-small")
                }
            # Set pad token for the newly loaded model
            if self.models[model_name]['tokenizer'].pad_token is None:
                self.models[model_name]['tokenizer'].pad_token = self.models[model_name]['tokenizer'].eos_token
            print(f"{model_name} model loaded successfully!")

    def chat(self, message, history, model_choice):
        if model_choice == "SmolLM2":
            return self.chat_smol(message, history)
        elif model_choice == "FLAN-T5":
            return self.chat_flan(message, history)
        # Defensive fallback; the dropdown normally prevents this case.
        return "Unknown model selected."

    def chat_smol(self, message, history):
        self.ensure_model_loaded('SmolLM2')
        tokenizer = self.models['SmolLM2']['tokenizer']
        model = self.models['SmolLM2']['model']
        inputs = tokenizer(f"User: {message}\nAssistant:", return_tensors="pt")
        outputs = model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,  # explicit mask avoids the pad/eos ambiguity warning
            max_new_tokens=80,
            temperature=0.7,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id
        )
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        # The decoded text includes the prompt; keep only what follows "Assistant:".
        return response.split("Assistant:")[-1].strip()
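
    # SmolLM2-135M-Instruct also ships a chat template. A sketch of the
    # template-based prompt (an alternative, not what this demo uses):
    #
    #     messages = [{"role": "user", "content": message}]
    #     prompt = tokenizer.apply_chat_template(
    #         messages, tokenize=False, add_generation_prompt=True
    #     )
    #
    # The plain "User: ...\nAssistant:" prompt above keeps parsing simple.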

    def chat_flan(self, message, history):
        self.ensure_model_loaded('FLAN-T5')
        tokenizer = self.models['FLAN-T5']['tokenizer']
        model = self.models['FLAN-T5']['model']
        inputs = tokenizer(f"Answer the question: {message}", return_tensors="pt")
        outputs = model.generate(inputs.input_ids, attention_mask=inputs.attention_mask, max_length=100)
        return tokenizer.decode(outputs[0], skip_special_tokens=True)
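
    # FLAN-T5 is encoder-decoder: generate() returns only decoder tokens, so
    # unlike chat_smol there is no prompt text to strip from the output.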

chat_app = MultiModelChat()

def respond(message, history, model_choice):
    # Thin wrapper around chat(); not wired into the Blocks UI below, but
    # usable as the fn for a gr.ChatInterface with an extra model input.
    return chat_app.chat(message, history, model_choice)

with gr.Blocks(theme="soft") as demo:
    gr.Markdown("# Multi-Model Tiny Chatbot")
    with gr.Row():
        model_dropdown = gr.Dropdown(
            choices=["SmolLM2", "FLAN-T5"],
            value="SmolLM2",
            label="Select Model"
        )
    chatbot = gr.Chatbot(height=400)
    msg = gr.Textbox(label="Message", placeholder="Type your message here...")
    clear = gr.Button("Clear")

    def user_message(message, history):
        # Step 1: clear the textbox and append the user turn with an empty bot slot.
        return "", history + [[message, None]]

    def bot_message(history, model_choice):
        # Step 2: generate a reply for the newest user turn and fill the empty slot.
        user_msg = history[-1][0]
        bot_response = chat_app.chat(user_msg, history[:-1], model_choice)
        history[-1][1] = bot_response
        return history
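
    # Chaining .submit(...).then(...) splits each turn in two: user_message
    # echoes the user's text immediately, then bot_message runs the slower
    # model call. The chat window updates right away instead of freezing
    # until generation finishes.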

    msg.submit(user_message, [msg, chatbot], [msg, chatbot]).then(
        bot_message, [chatbot, model_dropdown], chatbot
    )
    clear.click(lambda: None, None, chatbot, queue=False)

demo.launch()
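
# Likely Space dependencies (an assumption; requirements.txt is not shown here):
#   gradio
#   transformers
#   torch
#   sentencepiece  # needed by the slow T5Tokenizer used for FLAN-T5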