import spaces
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import traceback

# Try to import peft; if it is unavailable, only base models can be used.
try:
    from peft import PeftModel
    PEFT_AVAILABLE = True
except ImportError:
    print("Warning: peft library not found. LoRA adapters will not be available.")
    PEFT_AVAILABLE = False
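# Note: PEFT_AVAILABLE is checked again inside load_model(); 'lora' entries
# fall back to their base model when peft is missing.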
# === Define all your available models here ===
# This dictionary allows you to define both base models and LoRA adapters.
# 'type': can be 'base' for a standalone model or 'lora' for an adapter.
# 'id': the Hugging Face model/adapter ID.
# 'base_model_id': for LoRA adapters, specifies which base model to use.
AVAILABLE_MODELS = {
    "BokantLM0.1-0.5B": {
        "type": "base",
        "id": "llaa33219/BokantLM0.1-0.5B",
    },
    "BokantLM0.1-135M-Deepseek": {
        "type": "base",
        "id": "llaa33219/BokantLM0.1-135M-Deepseek",
    },
    "BokantLM0.1-135M-claude-3.7-sonnet": {
        "type": "lora",
        "id": "llaa33219/BokantLM0.1-135M-SmolLM2-135M-LoRA-claude-3.7-sonnet",
        "base_model_id": "HuggingFaceTB/SmolLM2-135M",
    },
    "Vere1Ko-360M": {
        "type": "base",
        "id": "llaa33219/Vere1Ko-360M",
    },
    "Vere1Ko-0.6B": {
        "type": "base",
        "id": "llaa33219/Vere1Ko-0.6B",
    },
    "Solar-Open-100B-Tmesis": {
        "type": "base",
        "id": "llaa33219/Solar-Open-100B-pruned-20pct",
    },
    # --- You can add more models here ---
    # Example of another base model:
    # "Another Base Model (e.g., Ko-LLaMA)": {
    #     "type": "base",
    #     "id": "beomi/KoAlpaca-Polyglot-5.8B"
    # },
    # Example of another LoRA adapter:
    # "Another LoRA Finetune": {
    #     "type": "lora",
    #     "id": "path/to/your/other-lora-adapter",
    #     "base_model_id": "Qwen/Qwen2.5-3B-Instruct"
    # },
}
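# The dictionary keys are the display names shown in the dropdown below;
# the first entry is used as the default model and is pre-loaded at startup.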
# Global variables for model caching
current_model_name = None
current_tokenizer = None
current_model = None

def load_model(name):
    """
    Loads a model based on the selection. It can load a base model directly,
    or load a base model and then apply a LoRA adapter to it.
    """
    global current_model_name, current_tokenizer, current_model

    if current_model_name == name:
        # Model is already loaded, no need to do anything
        return current_tokenizer, current_model

    print(f"Switching to model: {name}")

    # Clear the previous model from memory
    if current_model is not None:
        del current_model
        del current_tokenizer
        current_model = None
        current_tokenizer = None
        if torch.cuda.is_available():  # guard: only flush the cache when CUDA is present
            torch.cuda.empty_cache()
        print("Cleared previous model from memory.")
    try:
        model_info = AVAILABLE_MODELS[name]
        model_type = model_info["type"]
        model_id = model_info["id"]

        # --- Case 1: Load a LoRA adapter model ---
        if model_type == 'lora' and PEFT_AVAILABLE:
            base_model_id = model_info["base_model_id"]
            adapter_id = model_id
            print(f"Loading LoRA model. Base: '{base_model_id}', Adapter: '{adapter_id}'")

            # Load the tokenizer from the adapter (it might have special tokens)
            current_tokenizer = AutoTokenizer.from_pretrained(adapter_id, trust_remote_code=True)

            # Load the base model
            base_model = AutoModelForCausalLM.from_pretrained(
                base_model_id,
                torch_dtype=torch.float16,
                trust_remote_code=True,
                low_cpu_mem_usage=True
            )
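            # If the adapter's tokenizer added tokens (e.g. chat-template special
            # tokens), the base model's embedding matrix must be grown to match
            # before the adapter weights are applied, or loading will fail.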
            # Resize token embeddings if the adapter's vocab differs from the base model's
            if base_model.config.vocab_size != len(current_tokenizer):
                print(f"Resizing token embeddings from {base_model.config.vocab_size} to {len(current_tokenizer)}")
                base_model.resize_token_embeddings(len(current_tokenizer))

            # Load and merge the LoRA adapter
            print(f"Loading and merging LoRA adapter: {adapter_id}")
            lora_model = PeftModel.from_pretrained(
                base_model,
                adapter_id,
                torch_dtype=torch.float16
            )
            current_model = lora_model.merge_and_unload()
            print("Successfully merged LoRA adapter.")
        # --- Case 2: Load a base model directly ---
        else:
            if model_type == 'lora' and not PEFT_AVAILABLE:
                print(f"PEFT not available. Cannot load LoRA adapter '{name}'. Falling back to its base model.")
                # Fall back to the adapter's base model (or the first model in
                # the list if no base is specified) when PEFT is missing.
                model_id = model_info.get("base_model_id", list(AVAILABLE_MODELS.values())[0]['id'])

            print(f"Loading base model: {model_id}")
            current_tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
            current_model = AutoModelForCausalLM.from_pretrained(
                model_id,
                torch_dtype=torch.float16,
                trust_remote_code=True,
                low_cpu_mem_usage=True
            )
        # Common post-processing for any loaded model
        if current_tokenizer.pad_token is None:
            current_tokenizer.pad_token = current_tokenizer.eos_token
            print("Set pad_token to eos_token.")

        current_model_name = name
        print(f"✅ Successfully loaded model: {name}")

    except Exception as e:
        print(f"❌ Failed to load model {name}: {e}")
        traceback.print_exc()
        # Clean up on failure
        current_model_name = None
        current_model = None
        current_tokenizer = None
        raise e  # Re-raise the exception to be caught by the chat function

    return current_tokenizer, current_model
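# Minimal usage sketch of load_model outside the Gradio app (illustrative only,
# using the first model key; generation on CPU in fp16 will be slow):
#   tokenizer, model = load_model(list(AVAILABLE_MODELS.keys())[0])
#   ids = tokenizer("Hello", return_tensors="pt").input_ids
#   print(tokenizer.decode(model.generate(ids, max_new_tokens=20)[0]))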
@spaces.GPU  # ZeroGPU: allocates a GPU for the duration of this call (this is what the `spaces` import is for)
def chat_fn(message, history, selected_model):
    try:
        tokenizer, model = load_model(selected_model)

        # Ensure the model is on the correct device (GPU)
        if not next(model.parameters()).is_cuda:
            model = model.cuda()

        # Build the conversation history for the chat template
        conversation = []
        for user_msg, bot_msg in history:
            conversation.append({"role": "user", "content": user_msg})
            conversation.append({"role": "assistant", "content": bot_msg})
        conversation.append({"role": "user", "content": message})
        # Apply the model's specific chat template
        try:
            input_ids = tokenizer.apply_chat_template(
                conversation=conversation,
                tokenize=True,
                add_generation_prompt=True,
                return_tensors="pt"
            ).cuda()
        except Exception as e:
            print(f"Chat template error: {e}. Falling back to simple encoding.")
            text = f"User: {message}\nAssistant:"
            input_ids = tokenizer.encode(text, return_tensors="pt").cuda()

        # Generate the response
        with torch.no_grad():
            # Create an attention mask
            attention_mask = torch.ones_like(input_ids)
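            # An all-ones mask is correct here: the prompt is a single
            # unpadded sequence, so every position is a real token.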
            output_ids = model.generate(
                input_ids,
                max_new_tokens=4096,
                temperature=0.7,
                do_sample=True,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
                use_cache=True,
                attention_mask=attention_mask
            )

        # Decode the generated tokens into text, skipping the prompt
        response = tokenizer.decode(
            output_ids[0][input_ids.shape[1]:],
            skip_special_tokens=True
        ).strip()
        return response

    except Exception as e:
        print(f"Error in chat_fn: {str(e)}")
        traceback.print_exc()
        return f"Sorry, an error occurred: {str(e)}"
def respond(message, chat_history, selected_model):
    if not message.strip():
        # If the message is empty, do nothing
        return chat_history, ""

    # Get the bot's response
    bot_message = chat_fn(message, chat_history, selected_model)

    # Update the chat history
    chat_history.append([message, bot_message])
    return chat_history, ""  # Return the updated history and clear the input box
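# Note: chat_history uses Gradio's tuple-style history (a list of
# [user, assistant] pairs), which matches the default gr.Chatbot below.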
# --- Gradio Interface ---
title = "Multi-Model Chatbot (with LoRA Support)" if PEFT_AVAILABLE else "Multi-Model Chatbot (Base Models Only)"

with gr.Blocks(title="Multi-Model Chat", theme=gr.themes.Soft()) as demo:
    gr.Markdown(f"<h1><center>🗨️ {title}</center></h1>")
    gr.Markdown("<center>Select a model from the dropdown and start chatting. The app will load the model on the first message.</center>")

    with gr.Row():
        model_select = gr.Dropdown(
            choices=list(AVAILABLE_MODELS.keys()),
            value=list(AVAILABLE_MODELS.keys())[0],  # Default to the first model in the list
            label="Choose Model",
            interactive=True
        )

    chatbot = gr.Chatbot(
        height=500,
        label="Chat",
        show_copy_button=True,
        bubble_full_width=False
    )

    with gr.Row():
        msg = gr.Textbox(
            label="Message",
            placeholder="Type your message here...",
            scale=4
        )
        send_btn = gr.Button("Send", scale=1, variant="primary")

    clear_btn = gr.Button("Clear Chat", variant="secondary")
    # --- Event Handlers ---
    def clear_chat():
        return [], ""

    # Send the message on button click or Enter key press
    send_btn.click(
        respond,
        inputs=[msg, chatbot, model_select],
        outputs=[chatbot, msg]
    )
    msg.submit(
        respond,
        inputs=[msg, chatbot, model_select],
        outputs=[chatbot, msg]
    )

    # Clear chat button
    clear_btn.click(clear_chat, outputs=[chatbot, msg])
if __name__ == "__main__":
    # Pre-load the default model to speed up the first interaction
    try:
        print("Pre-loading the default model...")
        default_model_name = list(AVAILABLE_MODELS.keys())[0]
        load_model(default_model_name)
        print("✅ Default model pre-loaded successfully.")
    except Exception as e:
        print(f"⚠️ Could not pre-load the default model: {e}")

    demo.launch(
        share=False,  # Set to True to get a public link (on Hugging Face Spaces or Colab)
        server_name="0.0.0.0",
        server_port=7860
    )