llaa33219 committed
Commit 1317fd7 · verified · 1 Parent(s): beab2d7

Update app.py

Files changed (1)
  1. app.py  +71 -39
app.py CHANGED
@@ -1,14 +1,14 @@
 import spaces
 import gradio as gr
 from transformers import AutoModelForCausalLM, AutoTokenizer
+from peft import PeftModel
 import torch

 # === List your models here ===
-MODEL_IDS = {
-    "Entrystory-Qwen2.5-3b-Instruct": {
-        "base": "Qwen/Qwen2.5-3B-Instruct",
-        "adapter": "llaa33219/Entrystory-Qwen2.5-3b-Instruct"
-    },
+BASE_MODEL = "Qwen/Qwen2.5-3B-Instruct"
+ADAPTER_MODELS = {
+    "Qwen-Finetuned": "llaa33219/Entrystory-Qwen2.5-3b",
+    # Additional adapters can be added here
 }

 # Global variables for model caching
@@ -27,25 +27,47 @@ def load_model(name):
         del current_model
         torch.cuda.empty_cache()

-    # Load tokenizer
-    current_tokenizer = AutoTokenizer.from_pretrained(
-        MODEL_IDS[name],
-        trust_remote_code=True
-    )
-
-    # Add padding token if not present
-    if current_tokenizer.pad_token is None:
-        current_tokenizer.pad_token = current_tokenizer.eos_token
-
-    # Load model with ZeroGPU-friendly settings
-    current_model = AutoModelForCausalLM.from_pretrained(
-        MODEL_IDS[name],
-        torch_dtype=torch.float16,  # Explicit dtype for ZeroGPU
-        trust_remote_code=True,
-        low_cpu_mem_usage=True
-    )
-
-    current_model_name = name
+    try:
+        adapter_model_id = ADAPTER_MODELS[name]
+
+        # Load tokenizer from adapter (has the right special tokens)
+        current_tokenizer = AutoTokenizer.from_pretrained(
+            adapter_model_id,
+            trust_remote_code=True
+        )
+
+        # Add padding token if not present
+        if current_tokenizer.pad_token is None:
+            current_tokenizer.pad_token = current_tokenizer.eos_token
+
+        # Load base model
+        print(f"Loading base model: {BASE_MODEL}")
+        base_model = AutoModelForCausalLM.from_pretrained(
+            BASE_MODEL,
+            torch_dtype=torch.float16,
+            trust_remote_code=True,
+            low_cpu_mem_usage=True
+        )
+
+        # Load LoRA adapter
+        print(f"Loading LoRA adapter: {adapter_model_id}")
+        current_model = PeftModel.from_pretrained(
+            base_model,
+            adapter_model_id,
+            torch_dtype=torch.float16
+        )
+
+        # Merge adapter with base model for better performance
+        current_model = current_model.merge_and_unload()
+
+        current_model_name = name
+        print(f"Successfully loaded model: {name}")
+
+    except Exception as e:
+        print(f"Failed to load model {name}: {e}")
+        import traceback
+        traceback.print_exc()
+        raise e

     return current_tokenizer, current_model

@@ -54,10 +76,11 @@ def chat_fn(message, history, selected_model):
     try:
         tokenizer, model = load_model(selected_model)

-        # Move model to GPU inside the decorated function
-        model = model.cuda()
+        # Move model to GPU
+        if not next(model.parameters()).is_cuda:
+            model = model.cuda()

-        # Build conversation history for better context
+        # Build conversation history
         conversation = []
         for user_msg, bot_msg in history:
             conversation.append({"role": "user", "content": user_msg})
@@ -65,23 +88,30 @@ def chat_fn(message, history, selected_model):
         conversation.append({"role": "user", "content": message})

         # Apply chat template
-        input_ids = tokenizer.apply_chat_template(
-            conversation=conversation,
-            tokenize=True,
-            add_generation_prompt=True,
-            return_tensors="pt"
-        ).cuda()
+        try:
+            input_ids = tokenizer.apply_chat_template(
+                conversation=conversation,
+                tokenize=True,
+                add_generation_prompt=True,
+                return_tensors="pt"
+            ).cuda()
+        except Exception as e:
+            print(f"Chat template error: {e}")
+            # Fallback to simple tokenization
+            text = f"User: {message}\nAssistant:"
+            input_ids = tokenizer.encode(text, return_tensors="pt").cuda()

-        # Generate response with proper settings
+        # Generate response
         with torch.no_grad():
             output_ids = model.generate(
                 input_ids,
                 max_new_tokens=512,
                 temperature=0.7,
                 do_sample=True,
-                pad_token_id=tokenizer.eos_token_id,
+                pad_token_id=tokenizer.pad_token_id,
                 eos_token_id=tokenizer.eos_token_id,
-                use_cache=True
+                use_cache=True,
+                attention_mask=torch.ones_like(input_ids)
             )

         # Decode response
@@ -94,6 +124,8 @@ def chat_fn(message, history, selected_model):

     except Exception as e:
         print(f"Error in chat_fn: {str(e)}")
+        import traceback
+        traceback.print_exc()
         return f"죄송합니다. 오류가 발생했습니다: {str(e)}"

 def respond(message, chat_history, selected_model):
@@ -110,12 +142,12 @@ def respond(message, chat_history, selected_model):

 # Create Gradio interface
 with gr.Blocks(title="Multi-Model Chat", theme=gr.themes.Soft()) as demo:
-    gr.Markdown("# 🗨️ Multi-Model Chatbot (ZeroGPU ready)")
+    gr.Markdown("# 🗨️ Multi-Model Chatbot (LoRA Adapter Support)")

     with gr.Row():
         model_select = gr.Dropdown(
-            choices=list(MODEL_IDS.keys()),
-            value=list(MODEL_IDS.keys())[0],
+            choices=list(ADAPTER_MODELS.keys()),
+            value=list(ADAPTER_MODELS.keys())[0],
             label="Choose Model",
             interactive=True
         )
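
For reference, here is a minimal standalone sketch of the loading and generation path this commit introduces (base model plus LoRA adapter via peft, merged with merge_and_unload, then one chat-template generation), outside the Gradio/ZeroGPU wrapper. The model and adapter IDs are taken from the diff above; the test prompt is an illustrative placeholder, and it assumes transformers, peft, and torch are installed and a CUDA device is available.

# Sketch of the commit's base-model + LoRA-adapter loading path (assumptions noted above).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

BASE_MODEL = "Qwen/Qwen2.5-3B-Instruct"
ADAPTER_ID = "llaa33219/Entrystory-Qwen2.5-3b"

# Tokenizer is taken from the adapter repo so any added special tokens are picked up
tokenizer = AutoTokenizer.from_pretrained(ADAPTER_ID, trust_remote_code=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Load the base model, attach the LoRA adapter, then merge the adapter weights
base = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    torch_dtype=torch.float16,
    trust_remote_code=True,
    low_cpu_mem_usage=True,
)
model = PeftModel.from_pretrained(base, ADAPTER_ID, torch_dtype=torch.float16)
model = model.merge_and_unload().cuda()

# Single-turn generation via the chat template, mirroring chat_fn in app.py
conversation = [{"role": "user", "content": "Hello, who are you?"}]  # placeholder prompt
input_ids = tokenizer.apply_chat_template(
    conversation, tokenize=True, add_generation_prompt=True, return_tensors="pt"
).cuda()

with torch.no_grad():
    output_ids = model.generate(
        input_ids,
        max_new_tokens=512,
        temperature=0.7,
        do_sample=True,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        attention_mask=torch.ones_like(input_ids),
    )

# Decode only the newly generated tokens
print(tokenizer.decode(output_ids[0][input_ids.shape[-1]:], skip_special_tokens=True))

Merging the adapter up front trades a one-time merge cost for plain dense-model inference afterwards, which is why the app caches the merged model in a global and only moves it to the GPU inside the decorated chat function.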