llaa33219 committed on
Commit 2742dc4 · verified · 1 Parent(s): ff1e6ca

Update app.py

Files changed (1)
1. app.py +153 -112
app.py CHANGED
@@ -2,20 +2,44 @@ import spaces
 import gradio as gr
 from transformers import AutoModelForCausalLM, AutoTokenizer
 import torch
+import traceback
 
 # Try to import peft, if not available use base model only
 try:
     from peft import PeftModel
     PEFT_AVAILABLE = True
 except ImportError:
-    print("Warning: peft not available, will use base model only")
+    print("Warning: peft library not found. LoRA adapters will not be available.")
     PEFT_AVAILABLE = False
 
-# === List your models here ===
-BASE_MODEL = "Qwen/Qwen2.5-3B-Instruct"
-ADAPTER_MODELS = {
-    "Qwen-Finetuned": "llaa33219/Entrystory-Qwen2.5-3b-Instruct",
-    # Other adapters can also be added
+# === Define all your available models here ===
+# This new dictionary allows you to define both base models and LoRA adapters.
+# 'type': can be 'base' for a standalone model or 'lora' for an adapter.
+# 'id': the Hugging Face model/adapter ID.
+# 'base_model_id': for LoRA adapters, specifies which base model to use.
+
+AVAILABLE_MODELS = {
+    "BokantLM0.1-0.5b": {
+        "type": "base",
+        "id": "llaa33219/BokantLM0.1-0.5b",
+    },
+    "Entrystory-Qwen2.5-3b-Instruct": {
+        "type": "lora",
+        "id": "llaa33219/Entrystory-Qwen2.5-3b-Instruct",
+        "base_model_id": "Qwen/Qwen2.5-3B-Instruct"  # This LoRA is based on the Qwen model
+    },
+    # --- You can add more models here ---
+    # Example of another base model:
+    # "Another Base Model (e.g., Ko-LLaMA)": {
+    #     "type": "base",
+    #     "id": "beomi/KoAlpaca-Polyglot-5.8B"
+    # },
+    # Example of another LoRA adapter:
+    # "Another LoRA Finetune": {
+    #     "type": "lora",
+    #     "id": "path/to/your/other-lora-adapter",
+    #     "base_model_id": "Qwen/Qwen2.5-3B-Instruct"
+    # },
 }
 
 # Global variables for model caching
@@ -24,94 +48,98 @@ current_tokenizer = None
 current_model = None
 
 def load_model(name):
+    """
+    Loads a model based on the selection. It can load a base model directly
+    or load a base model and then apply a LoRA adapter to it.
+    """
     global current_model_name, current_tokenizer, current_model
+
+    if current_model_name == name:
+        # Model is already loaded, no need to do anything
+        return current_tokenizer, current_model
+
+    print(f"Switching to model: {name}")
 
-    if current_model_name != name:
-        print(f"Loading model: {name}")
-
-        # Clear previous model from memory
-        if current_model is not None:
-            del current_model
-            torch.cuda.empty_cache()
-
-        try:
-            if PEFT_AVAILABLE:
-                # LoRA adapter loading
-                adapter_model_id = ADAPTER_MODELS[name]
-
-                # Load tokenizer from adapter (has the right special tokens)
-                current_tokenizer = AutoTokenizer.from_pretrained(
-                    adapter_model_id,
-                    trust_remote_code=True
-                )
-
-                # Add padding token if not present
-                if current_tokenizer.pad_token is None:
-                    current_tokenizer.pad_token = current_tokenizer.eos_token
-
-                # Load base model
-                print(f"Loading base model: {BASE_MODEL}")
-                base_model = AutoModelForCausalLM.from_pretrained(
-                    BASE_MODEL,
-                    torch_dtype=torch.float16,
-                    trust_remote_code=True,
-                    low_cpu_mem_usage=True
-                )
-
-                # Resize token embeddings to match the adapter's vocabulary size
-                print(f"Original vocab size: {base_model.config.vocab_size}")
-                print(f"Tokenizer vocab size: {len(current_tokenizer)}")
-
-                if base_model.config.vocab_size != len(current_tokenizer):
-                    print(f"Resizing token embeddings from {base_model.config.vocab_size} to {len(current_tokenizer)}")
-                    base_model.resize_token_embeddings(len(current_tokenizer))
-
-                # Load LoRA adapter with error handling
-                print(f"Loading LoRA adapter: {adapter_model_id}")
-                try:
-                    current_model = PeftModel.from_pretrained(
-                        base_model,
-                        adapter_model_id,
-                        torch_dtype=torch.float16
-                    )
-
-                    # Merge adapter with base model for better performance
-                    current_model = current_model.merge_and_unload()
-                    print(f"Successfully merged LoRA adapter")
-
-                except Exception as adapter_error:
-                    print(f"Failed to load LoRA adapter: {adapter_error}")
-                    print("Falling back to base model only")
-                    current_model = base_model
-
-            else:
-                # Fallback to base model only
-                print(f"peft not available, using base model only: {BASE_MODEL}")
-                current_tokenizer = AutoTokenizer.from_pretrained(
-                    BASE_MODEL,
-                    trust_remote_code=True
-                )
-
-                # Add padding token if not present
-                if current_tokenizer.pad_token is None:
-                    current_tokenizer.pad_token = current_tokenizer.eos_token
-
-                current_model = AutoModelForCausalLM.from_pretrained(
-                    BASE_MODEL,
-                    torch_dtype=torch.float16,
-                    trust_remote_code=True,
-                    low_cpu_mem_usage=True
-                )
-
-            current_model_name = name
-            print(f"Successfully loaded model: {name}")
-
-        except Exception as e:
-            print(f"Failed to load model {name}: {e}")
-            import traceback
-            traceback.print_exc()
-            raise e
-
+    # Clear previous model from memory
+    if current_model is not None:
+        del current_model
+        del current_tokenizer
+        current_model = None
+        current_tokenizer = None
+        torch.cuda.empty_cache()
+        print("Cleared previous model from memory.")
+
+    try:
+        model_info = AVAILABLE_MODELS[name]
+        model_type = model_info["type"]
+        model_id = model_info["id"]
+
+        # --- Case 1: Load a LoRA adapter model ---
+        if model_type == 'lora' and PEFT_AVAILABLE:
+            base_model_id = model_info["base_model_id"]
+            adapter_id = model_id
+
+            print(f"Loading LoRA model. Base: '{base_model_id}', Adapter: '{adapter_id}'")
+
+            # Load tokenizer from the adapter (it might have special tokens)
+            current_tokenizer = AutoTokenizer.from_pretrained(adapter_id, trust_remote_code=True)
+
+            # Load base model
+            base_model = AutoModelForCausalLM.from_pretrained(
+                base_model_id,
+                torch_dtype=torch.float16,
+                trust_remote_code=True,
+                low_cpu_mem_usage=True
+            )
+
+            # Resize token embeddings if the adapter's vocab differs from the base model's
+            if base_model.config.vocab_size != len(current_tokenizer):
+                print(f"Resizing token embeddings from {base_model.config.vocab_size} to {len(current_tokenizer)}")
+                base_model.resize_token_embeddings(len(current_tokenizer))
+
+            # Load and merge the LoRA adapter
+            print(f"Loading and merging LoRA adapter: {adapter_id}")
+            lora_model = PeftModel.from_pretrained(
+                base_model,
+                adapter_id,
+                torch_dtype=torch.float16
+            )
+            current_model = lora_model.merge_and_unload()
+            print("Successfully merged LoRA adapter.")
+
+        # --- Case 2: Load a base model directly ---
+        else:
+            if model_type == 'lora' and not PEFT_AVAILABLE:
+                print(f"PEFT not available. Cannot load LoRA adapter '{name}'. Falling back to its base model.")
+                # Fall back to the base model if PEFT is missing
+                model_id = model_info.get("base_model_id", list(AVAILABLE_MODELS.values())[0]['id'])
+
+            print(f"Loading base model: {model_id}")
+            current_tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
+            current_model = AutoModelForCausalLM.from_pretrained(
+                model_id,
+                torch_dtype=torch.float16,
+                trust_remote_code=True,
+                low_cpu_mem_usage=True
+            )
+
+        # Common post-processing for any loaded model
+        if current_tokenizer.pad_token is None:
+            current_tokenizer.pad_token = current_tokenizer.eos_token
+            print("Set pad_token to eos_token.")
+
+        current_model_name = name
+        print(f"✅ Successfully loaded model: {name}")
+
+    except Exception as e:
+        print(f"❌ Failed to load model {name}: {e}")
+        traceback.print_exc()
+        # Clean up on failure
+        current_model_name = None
+        current_model = None
+        current_tokenizer = None
+        raise e  # Re-raise the exception to be caught by the chat function
+
     return current_tokenizer, current_model
 
  @spaces.GPU()
@@ -119,18 +147,18 @@ def chat_fn(message, history, selected_model):
     try:
         tokenizer, model = load_model(selected_model)
 
-        # Move model to GPU
+        # Ensure model is on the correct device (GPU)
         if not next(model.parameters()).is_cuda:
             model = model.cuda()
 
-        # Build conversation history
+        # Build conversation history for the chat template
         conversation = []
         for user_msg, bot_msg in history:
             conversation.append({"role": "user", "content": user_msg})
             conversation.append({"role": "assistant", "content": bot_msg})
         conversation.append({"role": "user", "content": message})
 
-        # Apply chat template
+        # Apply the model's specific chat template
         try:
             input_ids = tokenizer.apply_chat_template(
                 conversation=conversation,
@@ -139,13 +167,15 @@ def chat_fn(message, history, selected_model):
                 return_tensors="pt"
             ).cuda()
         except Exception as e:
-            print(f"Chat template error: {e}")
-            # Fallback to simple tokenization
+            print(f"Chat template error: {e}. Falling back to simple encoding.")
             text = f"User: {message}\nAssistant:"
             input_ids = tokenizer.encode(text, return_tensors="pt").cuda()
 
         # Generate response
         with torch.no_grad():
+            # Create attention mask
+            attention_mask = torch.ones_like(input_ids)
+
             output_ids = model.generate(
                 input_ids,
                 max_new_tokens=512,
@@ -154,10 +184,10 @@ def chat_fn(message, history, selected_model):
                 pad_token_id=tokenizer.pad_token_id,
                 eos_token_id=tokenizer.eos_token_id,
                 use_cache=True,
-                attention_mask=torch.ones_like(input_ids)
+                attention_mask=attention_mask
             )
 
-        # Decode response
+        # Decode the generated tokens into text, skipping the prompt
         response = tokenizer.decode(
             output_ids[0][input_ids.shape[1]:],
             skip_special_tokens=True
@@ -167,56 +197,58 @@ def chat_fn(message, history, selected_model):
 
     except Exception as e:
         print(f"Error in chat_fn: {str(e)}")
-        import traceback
         traceback.print_exc()
         return f"죄송합니다. 오류가 발생했습니다: {str(e)}"
 
 def respond(message, chat_history, selected_model):
     if not message.strip():
+        # If the message is empty, do nothing
        return chat_history, ""
 
-    # Get bot response
+    # Get the bot's response
     bot_message = chat_fn(message, chat_history, selected_model)
 
     # Update chat history
     chat_history.append([message, bot_message])
 
-    return chat_history, ""
+    return chat_history, ""  # Return updated history and clear the input box
 
-# Create Gradio interface
-title = "Multi-Model Chatbot (LoRA Adapter Support)" if PEFT_AVAILABLE else "Multi-Model Chatbot (Base Model Only)"
+# --- Gradio Interface ---
+title = "Multi-Model Chatbot (with LoRA Support)" if PEFT_AVAILABLE else "Multi-Model Chatbot (Base Models Only)"
 with gr.Blocks(title="Multi-Model Chat", theme=gr.themes.Soft()) as demo:
-    gr.Markdown(f"# 🗨️ {title}")
+    gr.Markdown(f"<h1><center>🗨️ {title}</center></h1>")
+    gr.Markdown("<center>Select a model from the dropdown and start chatting. The app will load the model on the first message.</center>")
 
     with gr.Row():
         model_select = gr.Dropdown(
-            choices=list(ADAPTER_MODELS.keys()),
-            value=list(ADAPTER_MODELS.keys())[0],
+            choices=list(AVAILABLE_MODELS.keys()),
+            value=list(AVAILABLE_MODELS.keys())[0],  # Default to the first model in the list
             label="Choose Model",
             interactive=True
         )
 
     chatbot = gr.Chatbot(
-        height=400,
+        height=500,
         label="Chat",
-        show_copy_button=True
+        show_copy_button=True,
+        bubble_full_width=False
     )
 
     with gr.Row():
         msg = gr.Textbox(
             label="Message",
-            placeholder="Type your message here...",
+            placeholder="여기에 메시지를 입력하세요...",
             scale=4
         )
         send_btn = gr.Button("Send", scale=1, variant="primary")
 
     clear_btn = gr.Button("Clear Chat", variant="secondary")
 
-    # Event handlers
+    # --- Event Handlers ---
     def clear_chat():
         return [], ""
 
-    # Send message on button click or enter
+    # Send message on button click or enter key press
     send_btn.click(
         respond,
         inputs=[msg, chatbot, model_select],
@@ -229,12 +261,21 @@ with gr.Blocks(title="Multi-Model Chat", theme=gr.themes.Soft()) as demo:
         outputs=[chatbot, msg]
     )
 
-    # Clear chat
+    # Clear chat button
     clear_btn.click(clear_chat, outputs=[chatbot, msg])
 
 if __name__ == "__main__":
+    # Pre-load the default model to speed up the first interaction
+    try:
+        print("Pre-loading the default model...")
+        default_model_name = list(AVAILABLE_MODELS.keys())[0]
+        load_model(default_model_name)
+        print("✅ Default model pre-loaded successfully.")
+    except Exception as e:
+        print(f"⚠️ Could not pre-load the default model: {e}")
+
     demo.launch(
-        share=False,
+        share=False,  # Set to True to get a public link (on Hugging Face Spaces or Colab)
         server_name="0.0.0.0",
         server_port=7860
     )
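For reviewers who want to sanity-check the new model registry before deploying, a minimal sketch (an illustration only, not part of the commit; it assumes app.py is importable, which builds the Gradio Blocks but neither launches the server nor downloads any weights):

# Sketch: validate AVAILABLE_MODELS entries without loading a model.
from app import AVAILABLE_MODELS

for name, info in AVAILABLE_MODELS.items():
    assert info["type"] in ("base", "lora"), f"{name}: unknown type {info['type']!r}"
    assert "id" in info, f"{name}: missing 'id'"
    if info["type"] == "lora":
        # load_model() reads base_model_id for LoRA entries, so require it here
        assert "base_model_id" in info, f"{name}: LoRA entries need 'base_model_id'"
print(f"{len(AVAILABLE_MODELS)} registry entries look well-formed.")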
+ print(f"Loading LoRA model. Base: '{base_model_id}', Adapter: '{adapter_id}'")
 
83
 
84
+ # Load tokenizer from the adapter (it might have special tokens)
85
+ current_tokenizer = AutoTokenizer.from_pretrained(adapter_id, trust_remote_code=True)
86
+
87
+ # Load base model
88
+ base_model = AutoModelForCausalLM.from_pretrained(
89
+ base_model_id,
90
+ torch_dtype=torch.float16,
91
+ trust_remote_code=True,
92
+ low_cpu_mem_usage=True
93
+ )
94
+
95
+ # Resize token embeddings if the adapter's vocab differs from the base model's
96
+ if base_model.config.vocab_size != len(current_tokenizer):
97
+ print(f"Resizing token embeddings from {base_model.config.vocab_size} to {len(current_tokenizer)}")
98
+ base_model.resize_token_embeddings(len(current_tokenizer))
99
+
100
+ # Load and merge the LoRA adapter
101
+ print(f"Loading and merging LoRA adapter: {adapter_id}")
102
+ lora_model = PeftModel.from_pretrained(
103
+ base_model,
104
+ adapter_id,
105
+ torch_dtype=torch.float16
106
+ )
107
+ current_model = lora_model.merge_and_unload()
108
+ print("Successfully merged LoRA adapter.")
109
+
110
+ # --- Case 2: Load a base model directly ---
111
+ else:
112
+ if model_type == 'lora' and not PEFT_AVAILABLE:
113
+ print(f"PEFT not available. Cannot load LoRA adapter '{name}'. Falling back to its base model.")
114
+ # Fallback to the base model if PEFT is missing
115
+ model_id = model_info.get("base_model_id", list(AVAILABLE_MODELS.values())[0]['id'])
116
+
117
+ print(f"Loading base model: {model_id}")
118
+ current_tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
119
+ current_model = AutoModelForCausalLM.from_pretrained(
120
+ model_id,
121
+ torch_dtype=torch.float16,
122
+ trust_remote_code=True,
123
+ low_cpu_mem_usage=True
124
+ )
125
+
126
+ # Common post-processing for any loaded model
127
+ if current_tokenizer.pad_token is None:
128
+ current_tokenizer.pad_token = current_tokenizer.eos_token
129
+ print("Set pad_token to eos_token.")
130
+
131
+ current_model_name = name
132
+ print(f"โœ… Successfully loaded model: {name}")
133
+
134
+ except Exception as e:
135
+ print(f"โŒ Failed to load model {name}: {e}")
136
+ traceback.print_exc()
137
+ # Clean up on failure
138
+ current_model_name = None
139
+ current_model = None
140
+ current_tokenizer = None
141
+ raise e # Re-raise the exception to be caught by the chat function
142
+
143
  return current_tokenizer, current_model
144
 
145
  @spaces.GPU()
 
147
  try:
148
  tokenizer, model = load_model(selected_model)
149
 
150
+ # Ensure model is on the correct device (GPU)
151
  if not next(model.parameters()).is_cuda:
152
  model = model.cuda()
153
 
154
+ # Build conversation history for the chat template
155
  conversation = []
156
  for user_msg, bot_msg in history:
157
  conversation.append({"role": "user", "content": user_msg})
158
  conversation.append({"role": "assistant", "content": bot_msg})
159
  conversation.append({"role": "user", "content": message})
160
 
161
+ # Apply the model's specific chat template
162
  try:
163
  input_ids = tokenizer.apply_chat_template(
164
  conversation=conversation,
 
167
  return_tensors="pt"
168
  ).cuda()
169
  except Exception as e:
170
+ print(f"Chat template error: {e}. Falling back to simple encoding.")
 
171
  text = f"User: {message}\nAssistant:"
172
  input_ids = tokenizer.encode(text, return_tensors="pt").cuda()
173
 
174
  # Generate response
175
  with torch.no_grad():
176
+ # Create attention mask
177
+ attention_mask = torch.ones_like(input_ids)
178
+
179
  output_ids = model.generate(
180
  input_ids,
181
  max_new_tokens=512,
 
184
  pad_token_id=tokenizer.pad_token_id,
185
  eos_token_id=tokenizer.eos_token_id,
186
  use_cache=True,
187
+ attention_mask=attention_mask
188
  )
189
 
190
+ # Decode the generated tokens into text, skipping the prompt
191
  response = tokenizer.decode(
192
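One note on the generate() change: because the app falls back to pad_token = eos_token, transformers cannot tell padding apart from end-of-sequence on its own, so passing an explicit attention_mask (now built once inside the torch.no_grad() block) avoids the usual warning and keeps the behavior well-defined. A standalone sketch of the same pattern, using a tiny public model purely as a stand-in so it runs quickly on CPU:

# Sketch of the generation pattern used in chat_fn; "sshleifer/tiny-gpt2"
# is only a stand-in model for illustration, not one used by this app.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")
model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")
if tok.pad_token is None:
    tok.pad_token = tok.eos_token  # same fallback as app.py

input_ids = tok.encode("User: Hi\nAssistant:", return_tensors="pt")
with torch.no_grad():
    out = model.generate(
        input_ids,
        attention_mask=torch.ones_like(input_ids),  # explicit mask, as in the new code
        max_new_tokens=16,
        pad_token_id=tok.pad_token_id,
        eos_token_id=tok.eos_token_id,
    )
print(tok.decode(out[0][input_ids.shape[1]:], skip_special_tokens=True))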
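Finally, since load_model() now short-circuits when the requested name is already loaded, repeated calls are cheap and switching models frees the old one first. A quick sketch of that cache contract (this one downloads real weights, so it needs the Space's hardware budget rather than a CI machine):

# Sketch: the caching behavior of the new load_model().
from app import AVAILABLE_MODELS, load_model

first = list(AVAILABLE_MODELS.keys())[0]
tok1, model1 = load_model(first)  # loads and caches the default entry
tok2, model2 = load_model(first)  # cache hit: returns the same objects
assert tok1 is tok2 and model1 is model2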