llaa33219 committed
Commit 3b44961 · verified · 1 Parent(s): f8cdca6

Update app.py

Files changed (1)
  1. app.py +135 -48
app.py CHANGED
@@ -1,76 +1,163 @@
 import spaces
 import gradio as gr
 from transformers import AutoModelForCausalLM, AutoTokenizer
+import torch

 # === List your models here ===
 MODEL_IDS = {
-    "Entrystory-Qwen2.5-3b": "llaa33219/Entrystory-Qwen2.5-3b",
+    "Qwen-Finetuned": "llaa33219/Entrystory-Qwen2.5-3b",
     # "Another‑Model": "username/another-model",
     # "Third‑Model": "username/third-model"
 }

-# Preload models & tokenizers (lazy loading to save memory on ZeroGPU)
-cached = {}
+# Global variables for model caching
+current_model_name = None
+current_tokenizer = None
+current_model = None

 def load_model(name):
-    if name not in cached:
+    global current_model_name, current_tokenizer, current_model
+
+    if current_model_name != name:
         print(f"Loading model: {name}")
-        tok = AutoTokenizer.from_pretrained(MODEL_IDS[name])
-        mod = AutoModelForCausalLM.from_pretrained(
+
+        # Clear previous model from memory
+        if current_model is not None:
+            del current_model
+            torch.cuda.empty_cache()
+
+        # Load tokenizer
+        current_tokenizer = AutoTokenizer.from_pretrained(
+            MODEL_IDS[name],
+            trust_remote_code=True
+        )
+
+        # Add padding token if not present
+        if current_tokenizer.pad_token is None:
+            current_tokenizer.pad_token = current_tokenizer.eos_token
+
+        # Load model with ZeroGPU-friendly settings
+        current_model = AutoModelForCausalLM.from_pretrained(
             MODEL_IDS[name],
-            device_map="auto",
-            torch_dtype="auto"
-        ).eval()
-        cached[name] = (tok, mod)
-    return cached[name]
+            torch_dtype=torch.float16,  # Explicit dtype for ZeroGPU
+            trust_remote_code=True,
+            low_cpu_mem_usage=True
+        )
+
+        current_model_name = name
+
+    return current_tokenizer, current_model

 @spaces.GPU()
 def chat_fn(message, history, selected_model):
-    tokenizer, model = load_model(selected_model)
+    try:
+        tokenizer, model = load_model(selected_model)
+
+        # Move model to GPU inside the decorated function
+        model = model.cuda()
+
+        # Build conversation history for better context
+        conversation = []
+        for user_msg, bot_msg in history:
+            conversation.append({"role": "user", "content": user_msg})
+            conversation.append({"role": "assistant", "content": bot_msg})
+        conversation.append({"role": "user", "content": message})
+
+        # Apply chat template
+        input_ids = tokenizer.apply_chat_template(
+            conversation=conversation,
+            tokenize=True,
+            add_generation_prompt=True,
+            return_tensors="pt"
+        ).cuda()
+
+        # Generate response with proper settings
+        with torch.no_grad():
+            output_ids = model.generate(
+                input_ids,
+                max_new_tokens=512,
+                temperature=0.7,
+                do_sample=True,
+                pad_token_id=tokenizer.eos_token_id,
+                eos_token_id=tokenizer.eos_token_id,
+                use_cache=True
+            )
+
+        # Decode response
+        response = tokenizer.decode(
+            output_ids[0][input_ids.shape[1]:],
+            skip_special_tokens=True
+        ).strip()
+
+        return response
+
+    except Exception as e:
+        print(f"Error in chat_fn: {str(e)}")
+        return f"Sorry, an error occurred: {str(e)}"
+
+def respond(message, chat_history, selected_model):
+    if not message.strip():
+        return chat_history, ""

-    # Build chat template (single‑turn for simplicity)
-    messages = [{"role": "user", "content": message}]
-    input_ids = tokenizer.apply_chat_template(
-        conversation=messages,
-        tokenize=True,
-        add_generation_prompt=True,
-        return_tensors="pt"
-    ).to(model.device)
+    # Get bot response
+    bot_message = chat_fn(message, chat_history, selected_model)

-    output_ids = model.generate(input_ids, max_new_tokens=512)
-    response = tokenizer.decode(
-        output_ids[0][input_ids.shape[1]:],
-        skip_special_tokens=True
-    )
-    return response
+    # Update chat history
+    chat_history.append([message, bot_message])
+
+    return chat_history, ""

-with gr.Blocks(title="Multi‑Model Chat") as demo:
-    gr.Markdown("# 🗨️ MultiModel Chatbot (ZeroGPU ready)")
+# Create Gradio interface
+with gr.Blocks(title="Multi-Model Chat", theme=gr.themes.Soft()) as demo:
+    gr.Markdown("# 🗨️ Multi-Model Chatbot (ZeroGPU ready)")

-    model_select = gr.Dropdown(
-        list(MODEL_IDS.keys()),
-        value=list(MODEL_IDS.keys())[0],
-        label="Choose Model"
+    with gr.Row():
+        model_select = gr.Dropdown(
+            choices=list(MODEL_IDS.keys()),
+            value=list(MODEL_IDS.keys())[0],
+            label="Choose Model",
+            interactive=True
+        )
+
+    chatbot = gr.Chatbot(
+        height=400,
+        label="Chat",
+        show_copy_button=True
     )

-    # Create chatbot and text input components
-    chatbot = gr.Chatbot()
-    msg = gr.Textbox(label="Message", placeholder="Type your message here...")
-    clear = gr.ClearButton([msg, chatbot])
+    with gr.Row():
+        msg = gr.Textbox(
+            label="Message",
+            placeholder="Type your message here...",
+            scale=4
+        )
+        send_btn = gr.Button("Send", scale=1, variant="primary")
+
+    clear_btn = gr.Button("Clear Chat", variant="secondary")

-    def respond(message, chat_history):
-        # Get the current selected model
-        current_model = model_select.value
-        bot_message = chat_fn(message, chat_history, current_model)
-        chat_history.append((message, bot_message))
-        return "", chat_history
+    # Event handlers
+    def clear_chat():
+        return [], ""

-    msg.submit(respond, [msg, chatbot], [msg, chatbot])
+    # Send message on button click or enter
+    send_btn.click(
+        respond,
+        inputs=[msg, chatbot, model_select],
+        outputs=[chatbot, msg]
+    )
+
+    msg.submit(
+        respond,
+        inputs=[msg, chatbot, model_select],
+        outputs=[chatbot, msg]
+    )

-    # Alternative: Use ChatInterface with state management
-    # chat = gr.ChatInterface(
-    #     fn=lambda msg, hist: chat_fn(msg, hist, model_select.value),
-    # )
+    # Clear chat
+    clear_btn.click(clear_chat, outputs=[chatbot, msg])

 if __name__ == "__main__":
-    demo.launch()
+    demo.launch(
+        share=False,
+        server_name="0.0.0.0",
+        server_port=7860
+    )
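
For reference, below is a minimal sketch of the same load-and-generate flow the updated chat_fn runs, usable as a local smoke test outside the Space. It assumes a local CUDA GPU and a recent transformers/torch install; the prompt text is illustrative and not part of the commit.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "llaa33219/Entrystory-Qwen2.5-3b"  # same checkpoint as MODEL_IDS above
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    trust_remote_code=True
).cuda()

# Single-turn conversation, templated the same way chat_fn templates history
conversation = [{"role": "user", "content": "Hello!"}]  # illustrative prompt
input_ids = tokenizer.apply_chat_template(
    conversation,
    tokenize=True,
    add_generation_prompt=True,
    return_tensors="pt"
).cuda()

# Sample a reply, then decode only the tokens generated after the prompt
with torch.no_grad():
    output_ids = model.generate(
        input_ids,
        max_new_tokens=64,
        do_sample=True,
        temperature=0.7,
        pad_token_id=tokenizer.eos_token_id
    )
print(tokenizer.decode(output_ids[0][input_ids.shape[1]:], skip_special_tokens=True))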