import spaces
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import traceback
# Try to import peft, if not available use base model only
try:
    from peft import PeftModel
    PEFT_AVAILABLE = True
except ImportError:
    print("Warning: peft library not found. LoRA adapters will not be available.")
    PEFT_AVAILABLE = False

# === Define all your available models here ===
# This new dictionary allows you to define both base models and LoRA adapters.
# 'type': can be 'base' for a standalone model or 'lora' for an adapter.
# 'id': the Hugging Face model/adapter ID.
# 'base_model_id': for LoRA adapters, specifies which base model to use.
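# Note (assumption about repo layout): a 'lora' entry is expected to point at a
# PEFT adapter repository (adapter_config.json plus adapter weights), and its
# 'base_model_id' should name the base checkpoint the adapter was trained on.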
AVAILABLE_MODELS = {
    "BokantLM0.1-0.5B": {
        "type": "base",
        "id": "llaa33219/BokantLM0.1-0.5B",
    },
    "BokantLM0.1-135M-Deepseek": {
        "type": "base",
        "id": "llaa33219/BokantLM0.1-135M-Deepseek",
    },
    "BokantLM0.1-135M-claude-3.7-sonnet": {
        "type": "lora",
        "id": "llaa33219/BokantLM0.1-135M-SmolLM2-135M-LoRA-claude-3.7-sonnet",
        "base_model_id": "HuggingFaceTB/SmolLM2-135M"
    },
    "Vere1Ko-360M": {
        "type": "base",
        "id": "llaa33219/Vere1Ko-360M"
    },
    "Vere1Ko-0.6B": {
        "type": "base",
        "id": "llaa33219/Vere1Ko-0.6B"
    },
    "Solar-Open-100B-Tmesis": {
        "type": "base",
        "id": "llaa33219/Solar-Open-100B-pruned-20pct"
    },
    # --- You can add more models here ---
    # Example of another base model:
    # "Another Base Model (e.g., Ko-LLaMA)": {
    #     "type": "base",
    #     "id": "beomi/KoAlpaca-Polyglot-5.8B"
    # },
    # Example of another LoRA adapter:
    # "Another LoRA Finetune": {
    #     "type": "lora",
    #     "id": "path/to/your/other-lora-adapter",
    #     "base_model_id": "Qwen/Qwen2.5-3B-Instruct"
    # },
}

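# The three globals below act as a single-slot cache: only one model is kept in
# memory at a time, and switching models frees the previous one first.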
# Global variables for model caching
current_model_name = None
current_tokenizer = None
current_model = None
def load_model(name):
"""
Loads a model based on the selection. It can load a base model directly
or load a base model and then apply a LoRA adapter to it.
"""
global current_model_name, current_tokenizer, current_model
if current_model_name == name:
# Model is already loaded, no need to do anything
return current_tokenizer, current_model
print(f"Switching to model: {name}")
# Clear previous model from memory
if current_model is not None:
del current_model
del current_tokenizer
current_model = None
current_tokenizer = None
torch.cuda.empty_cache()
print("Cleared previous model from memory.")
try:
model_info = AVAILABLE_MODELS[name]
model_type = model_info["type"]
model_id = model_info["id"]
# --- Case 1: Load a LoRA adapter model ---
if model_type == 'lora' and PEFT_AVAILABLE:
base_model_id = model_info["base_model_id"]
adapter_id = model_id
print(f"Loading LoRA model. Base: '{base_model_id}', Adapter: '{adapter_id}'")
# Load tokenizer from the adapter (it might have special tokens)
current_tokenizer = AutoTokenizer.from_pretrained(adapter_id, trust_remote_code=True)
# Load base model
base_model = AutoModelForCausalLM.from_pretrained(
base_model_id,
torch_dtype=torch.float16,
trust_remote_code=True,
low_cpu_mem_usage=True
)
# Resize token embeddings if the adapter's vocab differs from the base model's
if base_model.config.vocab_size != len(current_tokenizer):
print(f"Resizing token embeddings from {base_model.config.vocab_size} to {len(current_tokenizer)}")
base_model.resize_token_embeddings(len(current_tokenizer))
# Load and merge the LoRA adapter
print(f"Loading and merging LoRA adapter: {adapter_id}")
lora_model = PeftModel.from_pretrained(
base_model,
adapter_id,
torch_dtype=torch.float16
)
current_model = lora_model.merge_and_unload()
print("Successfully merged LoRA adapter.")
# --- Case 2: Load a base model directly ---
else:
if model_type == 'lora' and not PEFT_AVAILABLE:
print(f"PEFT not available. Cannot load LoRA adapter '{name}'. Falling back to its base model.")
# Fallback to the base model if PEFT is missing
model_id = model_info.get("base_model_id", list(AVAILABLE_MODELS.values())[0]['id'])
print(f"Loading base model: {model_id}")
current_tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
current_model = AutoModelForCausalLM.from_pretrained(
model_id,
torch_dtype=torch.float16,
trust_remote_code=True,
low_cpu_mem_usage=True
)
# Common post-processing for any loaded model
if current_tokenizer.pad_token is None:
current_tokenizer.pad_token = current_tokenizer.eos_token
print("Set pad_token to eos_token.")
current_model_name = name
print(f"โœ… Successfully loaded model: {name}")
except Exception as e:
print(f"โŒ Failed to load model {name}: {e}")
traceback.print_exc()
# Clean up on failure
current_model_name = None
current_model = None
current_tokenizer = None
raise e # Re-raise the exception to be caught by the chat function
return current_tokenizer, current_model
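# Illustrative usage of load_model (model name taken from AVAILABLE_MODELS above):
#   tokenizer, model = load_model("BokantLM0.1-0.5B")
# Calling it again with the same name returns the cached tokenizer and model.

# On ZeroGPU Spaces, the @spaces.GPU() decorator below allocates a GPU only while
# chat_fn runs (elsewhere it is effectively a no-op); models are loaded on CPU in
# load_model and moved to CUDA inside chat_fn.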
@spaces.GPU()
def chat_fn(message, history, selected_model):
    try:
        tokenizer, model = load_model(selected_model)
        # Ensure model is on the correct device (GPU)
        if not next(model.parameters()).is_cuda:
            model = model.cuda()
        # Build conversation history for the chat template
        conversation = []
        for user_msg, bot_msg in history:
            conversation.append({"role": "user", "content": user_msg})
            conversation.append({"role": "assistant", "content": bot_msg})
        conversation.append({"role": "user", "content": message})
        # Apply the model's specific chat template
        try:
            input_ids = tokenizer.apply_chat_template(
                conversation=conversation,
                tokenize=True,
                add_generation_prompt=True,
                return_tensors="pt"
            ).cuda()
        except Exception as e:
            print(f"Chat template error: {e}. Falling back to simple encoding.")
            text = f"User: {message}\nAssistant:"
            input_ids = tokenizer.encode(text, return_tensors="pt").cuda()
        # Generate response
        with torch.no_grad():
            # Create attention mask
            attention_mask = torch.ones_like(input_ids)
            output_ids = model.generate(
                input_ids,
                max_new_tokens=4096,
                temperature=0.7,
                do_sample=True,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
                use_cache=True,
                attention_mask=attention_mask
            )
        # Decode the generated tokens into text, skipping the prompt
        response = tokenizer.decode(
            output_ids[0][input_ids.shape[1]:],
            skip_special_tokens=True
        ).strip()
        return response
    except Exception as e:
        print(f"Error in chat_fn: {str(e)}")
        traceback.print_exc()
        return f"Sorry, an error occurred: {str(e)}"

def respond(message, chat_history, selected_model):
    if not message.strip():
        # If the message is empty, do nothing
        return chat_history, ""
    # Get the bot's response
    bot_message = chat_fn(message, chat_history, selected_model)
    # Update chat history
    chat_history.append([message, bot_message])
    return chat_history, ""  # Return updated history and clear the input box

# --- Gradio Interface ---
title = "Multi-Model Chatbot (with LoRA Support)" if PEFT_AVAILABLE else "Multi-Model Chatbot (Base Models Only)"
with gr.Blocks(title="Multi-Model Chat", theme=gr.themes.Soft()) as demo:
    gr.Markdown(f"<h1><center>🗨️ {title}</center></h1>")
    gr.Markdown("<center>Select a model from the dropdown and start chatting. The app will load the model on the first message.</center>")
    with gr.Row():
        model_select = gr.Dropdown(
            choices=list(AVAILABLE_MODELS.keys()),
            value=list(AVAILABLE_MODELS.keys())[0],  # Default to the first model in the list
            label="Choose Model",
            interactive=True
        )
    chatbot = gr.Chatbot(
        height=500,
        label="Chat",
        show_copy_button=True,
        bubble_full_width=False
    )
    with gr.Row():
        msg = gr.Textbox(
            label="Message",
            placeholder="Type your message here...",
            scale=4
        )
        send_btn = gr.Button("Send", scale=1, variant="primary")
    clear_btn = gr.Button("Clear Chat", variant="secondary")

    # --- Event Handlers ---
    def clear_chat():
        return [], ""

    # Send message on button click or enter key press
    send_btn.click(
        respond,
        inputs=[msg, chatbot, model_select],
        outputs=[chatbot, msg]
    )
    msg.submit(
        respond,
        inputs=[msg, chatbot, model_select],
        outputs=[chatbot, msg]
    )
    # Clear chat button
    clear_btn.click(clear_chat, outputs=[chatbot, msg])

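# Pre-loading below is a startup optimization only; if it fails, the app still
# launches and the selected model is loaded on the first message instead.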
if __name__ == "__main__":
    # Pre-load the default model to speed up the first interaction
    try:
        print("Pre-loading the default model...")
        default_model_name = list(AVAILABLE_MODELS.keys())[0]
        load_model(default_model_name)
        print("✅ Default model pre-loaded successfully.")
    except Exception as e:
        print(f"⚠️ Could not pre-load the default model: {e}")
    demo.launch(
        share=False,  # Set to True to get a public Gradio link when running locally or in Colab
        server_name="0.0.0.0",
        server_port=7860
    )
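# server_name="0.0.0.0" makes the app reachable from outside the container;
# 7860 is the default port expected by Gradio and Hugging Face Spaces.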