import spaces
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import traceback

# Try to import peft, if not available use base model only
try:
    from peft import PeftModel
    PEFT_AVAILABLE = True
except ImportError:
    print("Warning: peft library not found. LoRA adapters will not be available.")
    PEFT_AVAILABLE = False

# === Define all your available models here ===
# This new dictionary allows you to define both base models and LoRA adapters.
# 'type': can be 'base' for a standalone model or 'lora' for an adapter.
# 'id': the Hugging Face model/adapter ID.
# 'base_model_id': for LoRA adapters, specifies which base model to use.

AVAILABLE_MODELS = {
    "BokantLM0.1-0.5B": {
        "type": "base",
        "id": "llaa33219/BokantLM0.1-0.5B",
    },
        "BokantLM0.1-135M-Deepseek": {
        "type": "base",
        "id": "llaa33219/BokantLM0.1-135M-Deepseek",
    },
        "BokantLM0.1-135M-claude-3.7-sonnet": {
        "type": "lora",
        "id": "llaa33219/BokantLM0.1-135M-SmolLM2-135M-LoRA-claude-3.7-sonnet",
        "base_model_id": "HuggingFaceTB/SmolLM2-135M"
    },
        "Vere1Ko-360M": {
        "type": "base",
        "id": "llaa33219/Vere1Ko-360M"
    },
        "Vere1Ko-0.6B": {
        "type": "base",
        "id": "llaa33219/Vere1Ko-0.6B"
    },
    "Solar-Open-100B-Tmesis": {
        "type": "base",
        "id": "llaa33219/Solar-Open-100B-pruned-20pct"
    },
    # --- You can add more models here ---
    # Example of another base model:
    # "Another Base Model (e.g., Ko-LLaMA)": {
    #     "type": "base",
    #     "id": "beomi/KoAlpaca-Polyglot-5.8B"
    # },
    # Example of another LoRA adapter:
    # "Another LoRA Finetune": {
    #     "type": "lora",
    #     "id": "path/to/your/other-lora-adapter",
    #     "base_model_id": "Qwen/Qwen2.5-3B-Instruct"
    # },
}

# Global variables for model caching
current_model_name = None
current_tokenizer = None
current_model = None

def load_model(name):
    """
    Loads a model based on the selection. It can load a base model directly
    or load a base model and then apply a LoRA adapter to it.
    """
    global current_model_name, current_tokenizer, current_model

    if current_model_name == name:
        # Model is already loaded, no need to do anything
        return current_tokenizer, current_model

    print(f"Switching to model: {name}")
    
    # Clear previous model from memory
    if current_model is not None:
        del current_model
        del current_tokenizer
        current_model = None
        current_tokenizer = None
        torch.cuda.empty_cache()
        print("Cleared previous model from memory.")

    try:
        model_info = AVAILABLE_MODELS[name]
        model_type = model_info["type"]
        model_id = model_info["id"]

        # --- Case 1: Load a LoRA adapter model ---
        if model_type == 'lora' and PEFT_AVAILABLE:
            base_model_id = model_info["base_model_id"]
            adapter_id = model_id
            
            print(f"Loading LoRA model. Base: '{base_model_id}', Adapter: '{adapter_id}'")
            
            # Load tokenizer from the adapter (it might have special tokens)
            current_tokenizer = AutoTokenizer.from_pretrained(adapter_id, trust_remote_code=True)
            
            # Load base model
            base_model = AutoModelForCausalLM.from_pretrained(
                base_model_id,
                torch_dtype=torch.float16,
                trust_remote_code=True,
                low_cpu_mem_usage=True
            )
            
            # Resize token embeddings if the adapter's vocab differs from the base model's
            if base_model.config.vocab_size != len(current_tokenizer):
                print(f"Resizing token embeddings from {base_model.config.vocab_size} to {len(current_tokenizer)}")
                base_model.resize_token_embeddings(len(current_tokenizer))
            
            # Load and merge the LoRA adapter
            print(f"Loading and merging LoRA adapter: {adapter_id}")
            lora_model = PeftModel.from_pretrained(
                base_model,
                adapter_id,
                torch_dtype=torch.float16
            )
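            # merge_and_unload() folds the LoRA weights into the base weights and
            # removes the PEFT wrapper, so generation runs as a plain transformers model.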
            current_model = lora_model.merge_and_unload()
            print("Successfully merged LoRA adapter.")

        # --- Case 2: Load a base model directly ---
        else:
            if model_type == 'lora' and not PEFT_AVAILABLE:
                print(f"PEFT not available. Cannot load LoRA adapter '{name}'. Falling back to its base model.")
                # Fallback to the base model if PEFT is missing
                model_id = model_info.get("base_model_id", list(AVAILABLE_MODELS.values())[0]['id'])

            print(f"Loading base model: {model_id}")
            current_tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
            current_model = AutoModelForCausalLM.from_pretrained(
                model_id,
                torch_dtype=torch.float16,
                trust_remote_code=True,
                low_cpu_mem_usage=True
            )

        # Common post-processing for any loaded model
        if current_tokenizer.pad_token is None:
            current_tokenizer.pad_token = current_tokenizer.eos_token
            print("Set pad_token to eos_token.")

        current_model_name = name
        print(f"โœ… Successfully loaded model: {name}")

    except Exception as e:
        print(f"โŒ Failed to load model {name}: {e}")
        traceback.print_exc()
        # Clean up on failure
        current_model_name = None
        current_model = None
        current_tokenizer = None
        raise e # Re-raise the exception to be caught by the chat function

    return current_tokenizer, current_model

@spaces.GPU()
def chat_fn(message, history, selected_model):
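    # Note: on Hugging Face Spaces (ZeroGPU), @spaces.GPU() allocates a GPU only for
    # the duration of this call, which is why the model is moved to CUDA here
    # instead of in load_model().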
    try:
        tokenizer, model = load_model(selected_model)
        
        # Ensure model is on the correct device (GPU)
        if not next(model.parameters()).is_cuda:
            model = model.cuda()
        
        # Build conversation history for the chat template
        conversation = []
        for user_msg, bot_msg in history:
            conversation.append({"role": "user", "content": user_msg})
            conversation.append({"role": "assistant", "content": bot_msg})
        conversation.append({"role": "user", "content": message})
        
        # Apply the model's specific chat template
        try:
            input_ids = tokenizer.apply_chat_template(
                conversation=conversation,
                tokenize=True,
                add_generation_prompt=True,
                return_tensors="pt"
            ).cuda()
        except Exception as e:
            print(f"Chat template error: {e}. Falling back to simple encoding.")
            text = f"User: {message}\nAssistant:"
            input_ids = tokenizer.encode(text, return_tensors="pt").cuda()
        
        # Generate response
        with torch.no_grad():
            # Create attention mask
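            # All-ones is sufficient here: the prompt is a single, unpadded sequence.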
            attention_mask = torch.ones_like(input_ids)
            
            output_ids = model.generate(
                input_ids,
                max_new_tokens=4096,
                temperature=0.7,
                do_sample=True,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
                use_cache=True,
                attention_mask=attention_mask
            )
        
        # Decode the generated tokens into text, skipping the prompt
        response = tokenizer.decode(
            output_ids[0][input_ids.shape[1]:],
            skip_special_tokens=True
        ).strip()
        
        return response
        
    except Exception as e:
        print(f"Error in chat_fn: {str(e)}")
        traceback.print_exc()
        return f"์ฃ„์†กํ•ฉ๋‹ˆ๋‹ค. ์˜ค๋ฅ˜๊ฐ€ ๋ฐœ์ƒํ–ˆ์Šต๋‹ˆ๋‹ค: {str(e)}"

def respond(message, chat_history, selected_model):
    if not message.strip():
        # If the message is empty, do nothing
        return chat_history, ""
    
    # Get the bot's response
    bot_message = chat_fn(message, chat_history, selected_model)
    
    # Update chat history
    chat_history.append([message, bot_message])
    
    return chat_history, "" # Return updated history and clear the input box

# --- Gradio Interface ---
title = "Multi-Model Chatbot (with LoRA Support)" if PEFT_AVAILABLE else "Multi-Model Chatbot (Base Models Only)"
with gr.Blocks(title="Multi-Model Chat", theme=gr.themes.Soft()) as demo:
    gr.Markdown(f"<h1><center>๐Ÿ—จ๏ธ {title}</center></h1>")
    gr.Markdown("<center>Select a model from the dropdown and start chatting. The app will load the model on the first message.</center>")
    
    with gr.Row():
        model_select = gr.Dropdown(
            choices=list(AVAILABLE_MODELS.keys()),
            value=list(AVAILABLE_MODELS.keys())[0], # Default to the first model in the list
            label="Choose Model",
            interactive=True
        )
    
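    # The default (tuple-style) Chatbot stores history as [user, assistant] pairs,
    # the same format that respond() appends to and chat_fn() iterates over.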
    chatbot = gr.Chatbot(
        height=500,
        label="Chat",
        show_copy_button=True,
        bubble_full_width=False
    )
    
    with gr.Row():
        msg = gr.Textbox(
            label="Message",
            placeholder="Type your message here...",
            scale=4
        )
        send_btn = gr.Button("Send", scale=1, variant="primary")
    
    clear_btn = gr.Button("Clear Chat", variant="secondary")
    
    # --- Event Handlers ---
    def clear_chat():
        return [], ""
    
    # Send message on button click or enter key press
    send_btn.click(
        respond,
        inputs=[msg, chatbot, model_select],
        outputs=[chatbot, msg]
    )
    
    msg.submit(
        respond,
        inputs=[msg, chatbot, model_select],
        outputs=[chatbot, msg]
    )
    
    # Clear chat button
    clear_btn.click(clear_chat, outputs=[chatbot, msg])

if __name__ == "__main__":
    # Pre-load the default model to speed up the first interaction
    try:
        print("Pre-loading the default model...")
        default_model_name = list(AVAILABLE_MODELS.keys())[0]
        load_model(default_model_name)
        print("โœ… Default model pre-loaded successfully.")
    except Exception as e:
        print(f"โš ๏ธ Could not pre-load the default model: {e}")

    demo.launch(
        share=False, # Set to True to get a public link (on Hugging Face Spaces or Colab)
        server_name="0.0.0.0",
        server_port=7860
    )