Spaces:
Sleeping
Sleeping
Merge branch 'main' of https://huggingface.co/spaces/sunbv56/V-LegalQA-Chatbot
Browse files
app.py
CHANGED
|
@@ -1,23 +1,17 @@
|
|
| 1 |
-
import uuid
|
| 2 |
-
import time
|
| 3 |
-
import json
|
| 4 |
-
import gradio as gr
|
| 5 |
import torch
|
| 6 |
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
| 7 |
-
import
|
| 8 |
-
import modelscope_studio.components.antdx as antdx
|
| 9 |
-
import modelscope_studio.components.base as ms
|
| 10 |
-
import modelscope_studio.components.pro as pro
|
| 11 |
-
# Removed: import dashscope
|
| 12 |
-
from config import DEFAULT_LOCALE, DEFAULT_SETTINGS, DEFAULT_THEME, DEFAULT_SUGGESTIONS, save_history, get_text, user_config, bot_config, welcome_config #, api_key # Removed api_key
|
| 13 |
-
# Removed: from dashscope import Generation
|
| 14 |
|
| 15 |
-
#
|
| 16 |
-
|
| 17 |
-
device
|
| 18 |
-
print(f"Using device: {device}")
|
| 19 |
|
| 20 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
|
| 22 |
# Use try-except so that a model which fails to load is reported instead of crashing startup
|
| 23 |
try:
|
|
@@ -25,7 +19,6 @@ try:
|
|
| 25 |
print(f"Loading model: {model_name_1}...")
|
| 26 |
model_1 = AutoModelForSeq2SeqLM.from_pretrained(model_name_1).to(device)
|
| 27 |
tokenizer_1 = AutoTokenizer.from_pretrained(model_name_1)
|
| 28 |
-
loaded_models[model_name_1] = {"model": model_1, "tokenizer": tokenizer_1}
|
| 29 |
print(f"Model {model_name_1} loaded successfully.")
|
| 30 |
except Exception as e:
|
| 31 |
print(f"Error loading model {model_name_1}: {e}")
|
|
@@ -35,1022 +28,200 @@ try:
|
|
| 35 |
print(f"Loading model: {model_name_2}...")
|
| 36 |
model_2 = AutoModelForSeq2SeqLM.from_pretrained(model_name_2).to(device)
|
| 37 |
tokenizer_2 = AutoTokenizer.from_pretrained(model_name_2)
|
| 38 |
-
loaded_models[model_name_2] = {"model": model_2, "tokenizer": tokenizer_2}
|
| 39 |
print(f"Model {model_name_2} loaded successfully.")
|
| 40 |
except Exception as e:
|
| 41 |
print(f"Error loading model {model_name_2}: {e}")
|
| 42 |
|
| 43 |
# Skip loading model_3 (ViLawT5_RL)
|
|
|
|
| 44 |
|
| 45 |
# Load the "sunbv56/V-LegalQA" checkpoint and register it in loaded_models.
# Any failure is printed and startup continues without this model.
try:
    model_name_4 = "sunbv56/V-LegalQA"
    print(f"Loading model: {model_name_4}...")
    model_4 = AutoModelForSeq2SeqLM.from_pretrained(model_name_4).to(device)
    tokenizer_4 = AutoTokenizer.from_pretrained(model_name_4)
    loaded_models[model_name_4] = dict(model=model_4, tokenizer=tokenizer_4)
    print(f"Model {model_name_4} loaded successfully.")
except Exception as load_error:
    print(f"Error loading model {model_name_4}: {load_error}")
|
| 54 |
|
| 55 |
-
# Report loudly when no model could be loaded. The app keeps running here;
# raising or exiting is intentionally left to the operator.
if not loaded_models:
    banner = "=" * 50
    print("\n" + banner)
    print("FATAL ERROR: No models could be loaded. The application cannot run.")
    print("Please check model names, network connection, and available disk space.")
    print(banner + "\n")

# --- Update model options based on loaded models ---
# Descriptor for the "model" setting; "choices" is filled in just below.
MODEL_OPTIONS_MAP = {
    "label": get_text("Model", "模型"),
    "name": "model",
    "choices": [],
    "info": get_text("Select the model you want to use", "请选择需要使用的模型"),
}

# One option per loaded model; the label is the part after the last '/'.
AVAILABLE_MODEL_OPTIONS = [
    {"label": model_id.split('/')[-1], "value": model_id}
    for model_id in loaded_models
]

MODEL_OPTIONS_MAP["choices"] = AVAILABLE_MODEL_OPTIONS

# The default model is the first one that loaded, or None when nothing did.
if not AVAILABLE_MODEL_OPTIONS:
    DEFAULT_SETTINGS['model'] = None
    print("Warning: No models loaded, model selection will be empty.")
else:
    DEFAULT_SETTINGS['model'] = AVAILABLE_MODEL_OPTIONS[0]['value']
|
| 92 |
-
|
| 93 |
-
# --- Gradio UI and Events ---
|
| 94 |
-
|
| 95 |
-
# Removed: dashscope.api_key = api_key
|
| 96 |
-
|
| 97 |
-
# Removed: format_history function (not needed for simple seq2seq input)
|
| 98 |
-
|
| 99 |
-
class Gradio_Events:
|
| 100 |
-
|
| 101 |
-
@staticmethod
|
| 102 |
-
def submit(state_value):
    """Generate the assistant reply for the active conversation.

    This is a generator. Every yielded value is a dict mapping the
    ``chatbot`` and ``state`` names (defined outside this block, presumably
    Gradio components) to ``gr.update`` payloads.

    Flow:
      1. If the conversation's selected model is missing from
         ``loaded_models``, append an error message and yield once.
      2. Otherwise append a "pending" assistant placeholder and yield it.
      3. Run seq2seq generation on the last user message, then yield the
         finished reply, or an error entry if generation raised.

    Args:
        state_value: dict holding ``conversation_id`` and
            ``conversation_contexts``; the active context's ``history`` list
            is mutated in place.
    """
    started_at = time.time()
    context = state_value["conversation_contexts"][state_value["conversation_id"]]
    history = context["history"]
    settings = context["settings"]

    def ui_update():
        # Same payload shape for every step: refreshed chat plus the state.
        return {
            chatbot: gr.update(value=history),
            state: gr.update(value=state_value),
        }

    model_name = settings.get("model")

    # Guard: a model must be selected and actually loaded.
    if not model_name or model_name not in loaded_models:
        error_msg = f"Error: Model '{model_name}' is not available or not selected."
        print(error_msg)
        history.append({
            "role": "assistant",
            "content": [{"type": "text", "content": f'<span style="color: var(--color-red-500)">{error_msg}</span>'}],
            "key": str(uuid.uuid4()),
            "header": "Error",
            "loading": False,
            "status": "error"
        })
        yield ui_update()
        return  # Stop processing this submission

    entry = loaded_models[model_name]
    model = entry["model"]
    tokenizer = entry["tokenizer"]

    # Display label for the message header; falls back to the raw model name.
    model_label = model_name
    for option in AVAILABLE_MODEL_OPTIONS:
        if option['value'] == model_name:
            model_label = option['label']
            break

    # --- Prepare input for the seq2seq model ---
    # Only the most recent user message is used as the question.
    if history and history[-1]["role"] == "user":
        user_input = history[-1]["content"]
    else:
        # Not expected when submit runs right after add_message.
        user_input = "Hello"
        print("Warning: Could not find the last user message, using default.")

    prompt = f"question: {user_input}"
    print(f"Using model: {model_name}")
    print(f"Input prompt: {prompt}")

    # Placeholder shown while the model is generating.
    history.append({
        "role": "assistant",
        "content": [],
        "key": str(uuid.uuid4()),
        "header": model_label,
        "loading": True,
        "status": "pending"
    })
    yield ui_update()

    try:
        # --- Tokenize and generate ---
        inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=512).to(device)

        generation_kwargs = {
            "max_length": 512,
            "num_beams": 5,
            "early_stopping": True,
        }
        print(f"Generating with kwargs: {generation_kwargs}")

        with torch.no_grad():  # inference only, no gradients needed
            outputs = model.generate(**inputs, **generation_kwargs)

        # --- Decode the response ---
        response_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        print(f"Raw response: {response_text}")

        # --- Fill in the placeholder ---
        reply = history[-1]
        reply["content"] = [{"type": "text", "content": response_text}]
        reply["loading"] = False
        reply["status"] = "done"
        cost_time = f"{time.time() - started_at:.2f}"
        reply["footer"] = get_text(f"{cost_time}s", f"用时{cost_time}s")

        yield ui_update()

    except Exception as exc:
        print(f"Error during generation with model {model_name}: {exc}")
        failed = history[-1]
        failed["loading"] = False
        failed["status"] = "error"
        failed["content"] = [{
            "type": "text",
            "content": f'<span style="color: var(--color-red-500)">Error during generation: {str(exc)}</span>'
        }]
        yield ui_update()
        # The error is reported in the chat and deliberately not re-raised.
|
| 222 |
-
|
| 223 |
-
@staticmethod
|
| 224 |
-
def add_message(input_value, settings_form_value, thinking_btn_state_value,
                state_value):
    """Record a user message and drive generation of the assistant reply.

    This is a generator. It yields the UI-lock payload from
    ``Gradio_Events.preprocess_submit``, then every update produced by
    ``Gradio_Events.submit``, and finally the UI-unlock payload from
    ``Gradio_Events.postprocess_submit``.

    Args:
        input_value: text typed by the user; empty or whitespace-only input
            ends the generator without yielding anything.
        settings_form_value: settings dict taken from the settings form.
        thinking_btn_state_value: dict providing ``enable_thinking``.
        state_value: application state dict (``conversation_id``,
            ``conversations``, ``conversation_contexts``); mutated in place.
    """
    if not input_value or input_value.strip() == "":
        print("Empty input, skipping.")
        # NOTE(review): inside a generator this value is attached to
        # StopIteration rather than yielded, so no update is produced.
        return gr.skip()

    if not state_value["conversation_id"]:
        # First message of a brand-new conversation: create its context.
        random_id = str(uuid.uuid4())
        history = []
        state_value["conversation_id"] = random_id
        # Fall back to the defaults (and the first available model) when the
        # form supplied nothing.
        current_settings = settings_form_value if settings_form_value else DEFAULT_SETTINGS.copy()
        if not current_settings.get('model') and AVAILABLE_MODEL_OPTIONS:
            current_settings['model'] = AVAILABLE_MODEL_OPTIONS[0]['value']

        state_value["conversation_contexts"][
            state_value["conversation_id"]] = {
                "history": history,
                "settings": current_settings,
                "enable_thinking": thinking_btn_state_value["enable_thinking"]
            }
        state_value["conversations"].append({
            # Sidebar label: first 50 characters of the message.
            "label": input_value[:50] + ('...' if len(input_value) > 50 else ''),
            "key": random_id
        })
    else:
        # Existing conversation: refresh its settings before adding the message.
        state_value["conversation_contexts"][
            state_value["conversation_id"]]["settings"] = settings_form_value
        state_value["conversation_contexts"][
            state_value["conversation_id"]]["enable_thinking"] = thinking_btn_state_value["enable_thinking"]

    history = state_value["conversation_contexts"][
        state_value["conversation_id"]]["history"]

    # The user message must be in history before the preprocess payload is
    # built, because that payload carries the chat contents.
    history.append({
        "role": "user",
        "content": input_value,
        "key": str(uuid.uuid4())
    })

    yield Gradio_Events.preprocess_submit(clear_input=True)(state_value)

    selected_model = state_value["conversation_contexts"][state_value["conversation_id"]]["settings"].get('model')
    if not selected_model or selected_model not in loaded_models:
        # No usable model: report it in the chat and unlock the UI at once.
        error_msg = f"Error: Model '{selected_model}' not available or not selected. Cannot generate response."
        print(error_msg)
        history.append({
            "role": "assistant",
            "content": [{"type": "text", "content": f'<span style="color: var(--color-red-500)">{error_msg}</span>'}],
            "key": str(uuid.uuid4()),
            "header": "Error",
            "loading": False,
            "status": "error"
        })
        post_process_update = Gradio_Events.postprocess_submit(state_value)
        post_process_update[chatbot] = gr.update(value=history)
        yield post_process_update

    else:
        try:
            # submit yields the intermediate (loading) and the final state.
            for update in Gradio_Events.submit(state_value):
                yield update
        except Exception as e:
            # submit handles generation errors itself; this covers failures
            # raised before it could add or update its placeholder.
            print(f"Error during submission process: {e}")
            history = state_value["conversation_contexts"][state_value["conversation_id"]]["history"]
            if not history or history[-1].get("role") != "assistant":
                history.append({
                    "role": "assistant",
                    "content": [{"type": "text", "content": f'<span style="color: var(--color-red-500)">Error: {e}</span>'}],
                    "key": str(uuid.uuid4()), "header": "Error", "loading": False, "status": "error"
                })
            else:
                history[-1]["loading"] = False
                history[-1]["status"] = "error"
                history[-1]["content"] = [{"type": "text", "content": f'<span style="color: var(--color-red-500)">Error: {e}</span>'}]

        # Unlock the UI exactly once, on both the success and the error path.
        # This is deliberately NOT in a ``finally`` clause: yielding from
        # ``finally`` while the generator is being closed raises
        # ``RuntimeError: generator ignored GeneratorExit``.
        yield Gradio_Events.postprocess_submit(state_value)
|
| 325 |
-
|
| 326 |
-
@staticmethod
|
| 327 |
-
def preprocess_submit(clear_input=True):
    """Build the handler that locks the UI while a reply is generated.

    Args:
        clear_input: when true the input box is emptied as well as disabled;
            otherwise it is only disabled.

    Returns:
        A function taking ``state_value`` and returning a dict of
        ``gr.update`` payloads keyed by names defined outside this block
        (presumably Gradio components), or ``gr.skip()`` when the active
        conversation id is empty or unknown.
    """

    def preprocess_submit_handler(state_value):
        active_id = state_value["conversation_id"]
        if not active_id or active_id not in state_value["conversation_contexts"]:
            print("Warning: Invalid conversation ID in preprocess_submit.")
            return gr.skip()

        history = state_value["conversation_contexts"][active_id]["history"]

        if clear_input:
            input_update = gr.update(value="", interactive=False)
        else:
            input_update = gr.update(interactive=False)

        # Every conversation entry is disabled while generation runs.
        locked_items = [
            {**item, "disabled": True}
            for item in state_value["conversations"]
        ]

        return {
            input: input_update,
            conversations: gr.update(active_key=active_id, items=locked_items),
            add_conversation_btn: gr.update(disabled=True),
            clear_btn: gr.update(disabled=True),
            conversation_delete_menu_item: gr.update(disabled=True),
            # Settings cannot be changed during generation.
            setting_btn: gr.update(disabled=True),
            # Message actions are disabled during generation.
            chatbot: gr.update(
                value=history,
                bot_config=bot_config(
                    disabled_actions=['edit', 'retry', 'delete']),
                user_config=user_config(
                    disabled_actions=['edit', 'delete'])),
            state: gr.update(value=state_value),
        }

    return preprocess_submit_handler
|
| 372 |
-
|
| 373 |
-
@staticmethod
|
| 374 |
-
def postprocess_submit(state_value):
    """Return the update payload that unlocks the UI after generation.

    Args:
        state_value: application state dict.

    Returns:
        A dict of ``gr.update`` payloads keyed by names defined outside this
        block (presumably Gradio components). With no valid active
        conversation the chat is cleared and the per-conversation buttons
        stay disabled; otherwise the active history is shown and every
        control is re-enabled.
    """
    active_id = state_value["conversation_id"]
    has_context = bool(active_id) and active_id in state_value["conversation_contexts"]

    if not has_context:
        print("Warning: Invalid conversation ID in postprocess_submit.")
        return {
            input: gr.update(interactive=True),
            # No active conversation, so nothing to delete or clear.
            conversation_delete_menu_item: gr.update(disabled=True),
            clear_btn: gr.update(disabled=True),
            conversations: gr.update(items=state_value.get("conversations", [])),
            add_conversation_btn: gr.update(disabled=False),
            setting_btn: gr.update(disabled=False),
            chatbot: gr.update(value=None, bot_config=bot_config(), user_config=user_config()),
            state: gr.update(value=state_value),
        }

    history = state_value["conversation_contexts"][active_id]["history"]
    unlocked_items = [
        {**item, "disabled": False}
        for item in state_value["conversations"]
    ]
    return {
        input: gr.update(interactive=True),
        conversation_delete_menu_item: gr.update(disabled=False),
        clear_btn: gr.update(disabled=False),
        conversations: gr.update(items=unlocked_items),
        add_conversation_btn: gr.update(disabled=False),
        setting_btn: gr.update(disabled=False),
        # Default bot/user configs are passed, without the disabled_actions
        # lists that preprocess_submit supplies.
        chatbot: gr.update(value=history,
                           bot_config=bot_config(),
                           user_config=user_config()),
        state: gr.update(value=state_value),
    }
|
| 411 |
-
|
| 412 |
-
@staticmethod
|
| 413 |
-
def cancel(state_value):
    """Mark a pending reply as cancelled and return the UI-unlock payload.

    This handler does not touch the model call itself. It only updates the
    last history entry, if that entry is still flagged as loading, and
    delegates to ``Gradio_Events.postprocess_submit``.

    Args:
        state_value: application state dict; the active history may be
            mutated in place.

    Returns:
        The payload produced by ``Gradio_Events.postprocess_submit``.
    """
    print("Cancel requested. Unlocking UI.")

    active_id = state_value["conversation_id"]
    if active_id and active_id in state_value["conversation_contexts"]:
        history = state_value["conversation_contexts"][active_id]["history"]
        if history and history[-1].get("loading"):
            pending = history[-1]
            pending["loading"] = False
            pending["status"] = "cancelled"
            pending["footer"] = get_text("Generation cancelled by user", "用户取消生成")
            # The message content is left as it was.

    return Gradio_Events.postprocess_submit(state_value)
|
| 430 |
-
|
| 431 |
-
|
| 432 |
-
@staticmethod
|
| 433 |
-
def delete_message(state_value, e: gr.EventData):
    """Remove one message from the active conversation's history.

    Args:
        state_value: application state dict; mutated in place.
        e: event whose ``_data["payload"][0]["index"]`` is the position of
            the message to delete.

    Returns:
        ``gr.update(value=state_value)`` after a deletion, or ``gr.skip()``
        when there is no valid active conversation or the index is out of
        range.
    """
    index = e._data["payload"][0]["index"]

    active_id = state_value["conversation_id"]
    if not active_id or active_id not in state_value["conversation_contexts"]:
        return gr.skip()  # No active conversation

    context = state_value["conversation_contexts"][active_id]
    history = context["history"]

    if not (0 <= index < len(history)):
        print(f"Warning: Invalid index {index} for deleting message.")
        return gr.skip()

    history.pop(index)
    context["history"] = history

    # Only the state is returned here.
    return gr.update(value=state_value)
|
| 451 |
-
|
| 452 |
-
|
| 453 |
-
@staticmethod
|
| 454 |
-
def edit_message(state_value, chatbot_value, e: gr.EventData):
    """Copy an edited message's content from the chatbot value into history.

    Args:
        state_value: application state dict; mutated in place.
        chatbot_value: current chatbot value; entry ``index`` must contain a
            ``"content"`` key.
        e: event whose ``_data["payload"][0]["index"]`` is the position of
            the edited message.

    Returns:
        ``gr.update(value=state_value)`` after the edit, or ``gr.skip()``
        when there is no valid active conversation or the index/content
        check fails.
    """
    index = e._data["payload"][0]["index"]

    active_id = state_value["conversation_id"]
    if not active_id or active_id not in state_value["conversation_contexts"]:
        return gr.skip()  # No active conversation

    context = state_value["conversation_contexts"][active_id]
    history = context["history"]

    editable = (
        0 <= index < len(history)
        and index < len(chatbot_value)
        and "content" in chatbot_value[index]
    )
    if not editable:
        print(f"Warning: Invalid index {index} or mismatch in chatbot_value structure for editing.")
        return gr.skip()

    # User and assistant messages are handled the same way: the content is
    # stored exactly as the chatbot component provides it.
    history[index]["content"] = chatbot_value[index]["content"]
    context["history"] = history

    return gr.update(value=state_value)
|
| 480 |
-
|
| 481 |
-
@staticmethod
|
| 482 |
-
def regenerate_message(settings_form_value, thinking_btn_state_value,
|
| 483 |
-
state_value, e: gr.EventData):
|
| 484 |
-
index = e._data["payload"][0]["index"]
|
| 485 |
-
if not state_value["conversation_id"] or state_value["conversation_id"] not in state_value["conversation_contexts"]:
|
| 486 |
-
return gr.skip()
|
| 487 |
-
|
| 488 |
-
history = state_value["conversation_contexts"][
|
| 489 |
-
state_value["conversation_id"]]["history"]
|
| 490 |
-
|
| 491 |
-
# Find the user message preceding the assistant message at 'index'
|
| 492 |
-
# Usually, the message to regenerate is assistant, so the input is at index-1
|
| 493 |
-
if index > 0 and history[index]["role"] == "assistant" and history[index-1]["role"] == "user":
|
| 494 |
-
# Trim history up to *before* the assistant message we want to regenerate
|
| 495 |
-
history = history[:index]
|
| 496 |
-
else:
|
| 497 |
-
print("Warning: Cannot regenerate. Expected user message before the selected assistant message.")
|
| 498 |
-
# Fallback: Maybe just remove the selected message and the one before it?
|
| 499 |
-
# Or just remove the selected one and try submitting the last user message again?
|
| 500 |
-
# Safest: just skip regeneration if structure isn't as expected.
|
| 501 |
-
return gr.skip()
|
| 502 |
-
|
| 503 |
-
# Update state with trimmed history and current settings
|
| 504 |
-
state_value["conversation_contexts"][
|
| 505 |
-
state_value["conversation_id"]] = {
|
| 506 |
-
"history": history,
|
| 507 |
-
"settings": settings_form_value,
|
| 508 |
-
"enable_thinking": thinking_btn_state_value["enable_thinking"]
|
| 509 |
-
}
|
| 510 |
-
|
| 511 |
-
# Preprocess UI (lock controls, show loading state potentially)
|
| 512 |
-
# Preprocess needs the user message back in history to display correctly
|
| 513 |
-
# Let's yield preprocess first, then submit
|
| 514 |
-
yield Gradio_Events.preprocess_submit(clear_input=False)(state_value) # Don't clear input field
|
| 515 |
|
| 516 |
-
|
| 517 |
-
|
| 518 |
-
|
| 519 |
-
|
| 520 |
-
|
| 521 |
-
|
| 522 |
-
|
| 523 |
-
|
| 524 |
-
|
| 525 |
-
|
| 526 |
-
|
| 527 |
-
|
| 528 |
-
|
| 529 |
-
|
| 530 |
-
|
| 531 |
-
|
| 532 |
-
|
| 533 |
-
|
| 534 |
-
|
| 535 |
-
|
| 536 |
-
|
| 537 |
-
|
| 538 |
-
|
| 539 |
-
|
| 540 |
-
|
| 541 |
-
|
| 542 |
-
|
| 543 |
-
|
| 544 |
-
|
| 545 |
-
|
| 546 |
-
|
| 547 |
-
|
| 548 |
-
|
| 549 |
-
|
| 550 |
-
|
| 551 |
-
|
| 552 |
-
|
| 553 |
-
|
| 554 |
-
|
| 555 |
-
|
| 556 |
-
|
| 557 |
-
|
| 558 |
-
|
| 559 |
-
|
| 560 |
-
|
| 561 |
-
|
| 562 |
-
|
| 563 |
-
selected_suggestion = e._data["payload"][0]
|
| 564 |
-
# Simple replacement logic:
|
| 565 |
-
# Find the last '/' and replace everything after it, or append if no '/'
|
| 566 |
-
last_slash = input_value.rfind('/')
|
| 567 |
-
if last_slash != -1:
|
| 568 |
-
new_value = input_value[:last_slash] + selected_suggestion
|
| 569 |
-
else:
|
| 570 |
-
new_value = input_value + selected_suggestion # Or just selected_suggestion?
|
| 571 |
-
|
| 572 |
-
# Original logic was: input_value = input_value[:-1] + e._data["payload"][0]
|
| 573 |
-
# This assumes the trigger was the *last* character. Let's stick to that.
|
| 574 |
-
if input_value.endswith('/'):
|
| 575 |
-
new_value = input_value[:-1] + selected_suggestion
|
| 576 |
-
else:
|
| 577 |
-
new_value = selected_suggestion # Or append? Let's try replacing if no trailing /
|
| 578 |
-
|
| 579 |
-
return gr.update(value=new_value)
|
| 580 |
-
|
| 581 |
-
@staticmethod
|
| 582 |
-
def apply_prompt(e: gr.EventData):
    """Return an update whose value is the selected prompt's description.

    Args:
        e: event whose ``_data["payload"][0]["value"]["description"]`` holds
            the prompt text.
    """
    selected = e._data["payload"][0]
    description = selected["value"]["description"]
    return gr.update(value=description)
|
| 585 |
-
|
| 586 |
-
@staticmethod
|
| 587 |
-
def new_chat(thinking_btn_state, state_value):
    """Switch the UI to an empty, unsaved conversation.

    Args:
        thinking_btn_state: dict whose ``enable_thinking`` flag is reset to
            True.
        state_value: application state dict; its ``conversation_id`` is
            cleared.

    Returns:
        ``gr.skip()`` when no conversation is active. Otherwise a 5-tuple of
        ``gr.update`` payloads, in order: ``active_key=None``,
        ``value=None``, the default settings, the thinking state, the
        application state.
    """
    if not state_value.get("conversation_id"):
        # Already on a new chat: nothing to do.
        return gr.skip()

    state_value["conversation_id"] = ""
    thinking_btn_state["enable_thinking"] = True

    # Defaults for the new chat, with the first available model filled in
    # when the defaults name none.
    new_chat_settings = DEFAULT_SETTINGS.copy()
    if AVAILABLE_MODEL_OPTIONS and not new_chat_settings.get('model'):
        new_chat_settings['model'] = AVAILABLE_MODEL_OPTIONS[0]['value']

    return (
        gr.update(active_key=None),
        gr.update(value=None),
        gr.update(value=new_chat_settings),
        gr.update(value=thinking_btn_state),
        gr.update(value=state_value),
    )
|
| 608 |
-
|
| 609 |
-
@staticmethod
|
| 610 |
-
def select_conversation(thinking_btn_state_value, state_value,
                        e: gr.EventData):
    """Make the clicked conversation active and restore its saved context.

    Args:
        thinking_btn_state_value: dict whose ``enable_thinking`` flag is
            restored from the selected conversation.
        state_value: application state dict; ``conversation_id`` is updated.
        e: event whose ``_data["payload"][0]`` is the conversation key.

    Returns:
        ``gr.skip()`` when the key is empty, unknown, or already active.
        Otherwise a 5-tuple of ``gr.update`` payloads, in order: the active
        key, the history, the settings, the thinking state, the application
        state.
    """
    active_key = e._data["payload"][0]
    current_id = state_value.get("conversation_id")

    nothing_to_do = (
        current_id == active_key
        or not active_key
        or active_key not in state_value.get("conversation_contexts", {})
    )
    if nothing_to_do:
        print(f"Skipping conversation selection: current={current_id}, target={active_key}")
        return gr.skip()

    print(f"Switching conversation from '{current_id}' to '{active_key}'")
    state_value["conversation_id"] = active_key
    context = state_value["conversation_contexts"][active_key]

    # Restore the per-conversation thinking flag (True when absent).
    thinking_btn_state_value["enable_thinking"] = context.get("enable_thinking", True)
    restored_settings = context.get("settings", DEFAULT_SETTINGS.copy())

    # A saved model that is not currently loaded is replaced by the default.
    saved_model = restored_settings.get('model')
    if saved_model not in loaded_models:
        print(f"Warning: Model '{restored_settings.get('model')}' in selected conversation is no longer loaded. Resetting to default.")
        restored_settings['model'] = DEFAULT_SETTINGS.get('model')

    return (
        gr.update(active_key=active_key),
        gr.update(value=context.get("history", [])),
        gr.update(value=restored_settings),
        gr.update(value=thinking_btn_state_value),
        gr.update(value=state_value),
    )
|
| 639 |
-
|
| 640 |
-
|
| 641 |
-
@staticmethod
|
| 642 |
-
def click_conversation_menu(state_value, e: gr.EventData):
    """Handle a click in a conversation's context menu.

    Only the ``"delete"`` operation is implemented.

    Args:
        state_value: application state dict; mutated in place on delete.
        e: event whose ``_data["payload"]`` holds at least two dicts: the
            conversation (``key``) and the menu operation (``key``).

    Returns:
        ``gr.skip()`` for a malformed payload or an unhandled operation.
        For a delete, a 4-tuple in this order: conversation-list update,
        chat update, settings update, state update. The chat and settings
        entries are ``gr.skip()`` when the deleted conversation was not the
        active one.
    """
    payload = e._data["payload"]
    if not payload or len(payload) < 2:
        print("Warning: Invalid payload for conversation menu click.")
        return gr.skip()

    conversation_id = payload[0].get("key")
    operation = payload[1].get("key")
    if not conversation_id or not operation:
        print("Warning: Missing key or operation in conversation menu click.")
        return gr.skip()

    if operation != "delete":
        # Other operations (for example a rename) are not handled here.
        return gr.skip()

    print(f"Deleting conversation: {conversation_id}")
    state_value.get("conversation_contexts", {}).pop(conversation_id, None)
    state_value["conversations"] = [
        item for item in state_value.get("conversations", [])
        if item.get("key") != conversation_id
    ]

    if state_value.get("conversation_id") != conversation_id:
        # A background conversation was removed: refresh the list only.
        return (
            gr.update(items=state_value["conversations"]),
            gr.skip(),
            gr.skip(),
            gr.update(value=state_value),
        )

    # The active conversation was removed: clear the chat and reset settings.
    state_value["conversation_id"] = ""
    new_chat_settings = DEFAULT_SETTINGS.copy()
    if AVAILABLE_MODEL_OPTIONS and not new_chat_settings.get('model'):
        new_chat_settings['model'] = AVAILABLE_MODEL_OPTIONS[0]['value']

    return (
        gr.update(items=state_value["conversations"], active_key=None),
        gr.update(value=None),
        gr.update(value=new_chat_settings),
        gr.update(value=state_value),
    )
|
| 687 |
-
|
| 688 |
-
@staticmethod
|
| 689 |
-
def toggle_settings_header(settings_header_state_value):
    """Flip the ``"open"`` flag of the settings-header state.

    Args:
        settings_header_state_value: Mutable mapping holding the header
            state. A missing ``"open"`` key is treated as ``False``.

    Returns:
        A ``gr.update`` carrying the same (mutated) mapping as its value.
    """
    # Read first, then write, so a missing key becomes True on the first click.
    currently_open = settings_header_state_value.get("open", False)
    settings_header_state_value["open"] = not currently_open
    return gr.update(value=settings_header_state_value)
|
| 693 |
-
|
| 694 |
-
@staticmethod
|
| 695 |
-
def clear_conversation_history(state_value):
    """Empty the message history of the currently active conversation.

    Args:
        state_value: Mutable app state; reads ``"conversation_id"`` and
            ``"conversation_contexts"`` and resets the active context's
            ``"history"`` list in place.

    Returns:
        ``gr.skip()`` when there is no active conversation or its id has no
        stored context; otherwise a pair of updates that blank the chatbot
        value and publish the mutated state.
    """
    active_id = state_value.get("conversation_id")
    has_context = bool(active_id) and active_id in state_value.get("conversation_contexts", {})
    if not has_context:
        print("Skipping clear history: No active or valid conversation.")
        return gr.skip()  # Nothing to clear

    print(f"Clearing history for conversation: {active_id}")
    context = state_value["conversation_contexts"][active_id]
    context["history"] = []

    # First update resets the chatbot display, second one carries the state.
    chatbot_update = gr.update(value=None)
    state_update = gr.update(value=state_value)
    return chatbot_update, state_update
|
| 706 |
-
|
| 707 |
-
@staticmethod
|
| 708 |
-
def update_browser_state(state_value):
    """Build the browser-storage payload from the in-memory app state.

    Only the conversation list and the per-conversation contexts are kept;
    the active ``conversation_id`` is deliberately left out because it is
    transient UI state.

    Args:
        state_value: Mapping holding the current app state.

    Returns:
        A ``gr.update`` whose value is a dict with the keys
        ``"conversations"`` and ``"conversation_contexts"``.
    """
    persisted = {
        "conversations": state_value.get("conversations", []),
        "conversation_contexts": state_value.get("conversation_contexts", {}),
    }
    return gr.update(value=persisted)
|
| 715 |
-
|
| 716 |
-
@staticmethod
|
| 717 |
-
def apply_browser_state(browser_state_value, state_value):
    """Restore conversations from browser storage into the app state.

    Args:
        browser_state_value: Value read from browser storage; may be empty
            or ``None`` on a first visit. When present it is expected to map
            ``"conversations"`` to a list and ``"conversation_contexts"`` to
            a dict.
        state_value: Mutable app state, updated in place.

    Returns:
        A 4-tuple of ``gr.update`` objects, in order: conversation list
        items, chatbot value (always ``None``), settings value (defaults),
        and the updated state. The active conversation id is always reset
        to ``""``.
    """

    def _fresh_settings():
        # Copy of the defaults, falling back to the first available model
        # when the defaults do not name one.
        settings = DEFAULT_SETTINGS.copy()
        if AVAILABLE_MODEL_OPTIONS and not settings.get('model'):
            settings['model'] = AVAILABLE_MODEL_OPTIONS[0]['value']
        return settings

    def _result(conversations_update):
        # The last three outputs are identical on every path.
        return (
            conversations_update,
            gr.update(value=None),
            gr.update(value=_fresh_settings()),
            gr.update(value=state_value),
        )

    # Nothing stored yet (initial load): keep whatever the state already
    # holds, only making sure both containers exist.
    if not browser_state_value:
        print("No browser state found to apply.")
        state_value["conversations"] = state_value.get("conversations") or []
        state_value["conversation_contexts"] = state_value.get("conversation_contexts") or {}
        state_value["conversation_id"] = ""  # No active conversation on a fresh load
        return _result(gr.update(items=[]))

    print("Applying browser state...")
    stored_conversations = browser_state_value.get("conversations")
    stored_contexts = browser_state_value.get("conversation_contexts")

    # Basic validation: both entries must have the expected container type.
    well_formed = isinstance(stored_conversations, list) and isinstance(stored_contexts, dict)
    if not well_formed:
        print("Warning: Invalid browser state format. Ignoring.")
        # Behave as if no browser state had been found, but start clean.
        state_value["conversations"] = []
        state_value["conversation_contexts"] = {}
        state_value["conversation_id"] = ""
        return _result(gr.update(items=[]))

    state_value["conversations"] = stored_conversations
    state_value["conversation_contexts"] = stored_contexts
    state_value["conversation_id"] = ""  # Reset the active conversation on load
    return _result(gr.update(items=stored_conversations, active_key=None))
|
| 766 |
-
|
| 767 |
-
|
| 768 |
-
# --- UI Definition ---
|
| 769 |
-
css = """
|
| 770 |
-
/* ... (keep existing CSS) ... */
|
| 771 |
-
.gradio-container {
|
| 772 |
-
padding: 0 !important;
|
| 773 |
-
}
|
| 774 |
-
.gradio-container > main.fillable {
|
| 775 |
-
padding: 0 !important;
|
| 776 |
-
}
|
| 777 |
-
#chatbot {
|
| 778 |
-
height: calc(100vh - 21px - 16px); /* Adjust if header/footer height changes */
|
| 779 |
-
max-height: 1500px;
|
| 780 |
-
}
|
| 781 |
-
#chatbot .chatbot-conversations {
|
| 782 |
-
height: 100vh; /* Full height */
|
| 783 |
-
background-color: var(--ms-gr-ant-color-bg-layout);
|
| 784 |
-
padding-left: 4px;
|
| 785 |
-
padding-right: 4px;
|
| 786 |
-
display: flex; /* Use flexbox for vertical layout */
|
| 787 |
-
flex-direction: column; /* Stack children vertically */
|
| 788 |
-
}
|
| 789 |
-
#chatbot .chatbot-conversations .chatbot-conversations-list {
|
| 790 |
-
padding-left: 0;
|
| 791 |
-
padding-right: 0;
|
| 792 |
-
flex-grow: 1; /* Allow list to take remaining space */
|
| 793 |
-
overflow-y: auto; /* Add scroll if list is long */
|
| 794 |
-
}
|
| 795 |
-
#chatbot .chatbot-chat {
|
| 796 |
-
padding: 32px;
|
| 797 |
-
padding-bottom: 0;
|
| 798 |
-
height: 100%;
|
| 799 |
-
display: flex; /* Use flexbox */
|
| 800 |
-
flex-direction: column; /* Stack chat messages and input vertically */
|
| 801 |
-
}
|
| 802 |
-
@media (max-width: 768px) {
|
| 803 |
-
#chatbot .chatbot-chat {
|
| 804 |
-
padding: 16px; /* Add some padding on mobile */
|
| 805 |
-
padding-bottom: 0;
|
| 806 |
-
}
|
| 807 |
-
#chatbot .chatbot-conversations {
|
| 808 |
-
/* Consider hiding conversation list or making it a drawer on mobile */
|
| 809 |
-
}
|
| 810 |
-
}
|
| 811 |
-
#chatbot .chatbot-chat .chatbot-chat-messages {
|
| 812 |
-
flex: 1; /* Allow chat messages to take available space */
|
| 813 |
-
overflow-y: auto; /* Add scroll to messages */
|
| 814 |
-
}
|
| 815 |
-
#chatbot .setting-form-thinking-budget {
|
| 816 |
-
/* Keep or remove based on whether thinking budget is still relevant */
|
| 817 |
-
/* display: none; /* Example: Hide if not used */
|
| 818 |
-
}
|
| 819 |
-
/* Style for disabled input */
|
| 820 |
-
#input-sender textarea:disabled {
|
| 821 |
-
background-color: var(--ms-gr-ant-color-bg-container-disabled);
|
| 822 |
-
cursor: not-allowed;
|
| 823 |
-
}
|
| 824 |
-
"""
|
| 825 |
-
|
| 826 |
-
# Removed model_options_map_json and the JS function, as options are handled in Python now
|
| 827 |
-
|
| 828 |
-
with gr.Blocks(css=css, fill_width=True) as demo: # Removed js=js
|
| 829 |
-
# Initial state structure
|
| 830 |
-
state = gr.State({
|
| 831 |
-
"conversation_contexts": {},
|
| 832 |
-
"conversations": [],
|
| 833 |
-
"conversation_id": "",
|
| 834 |
-
})
|
| 835 |
-
|
| 836 |
-
with ms.Application(), antdx.XProvider(
|
| 837 |
-
theme=DEFAULT_THEME, locale=DEFAULT_LOCALE), ms.AutoLoading():
|
| 838 |
-
with antd.Row(gutter=[0, 0], wrap=False, elem_id="chatbot"): # Use gutter 0 for closer columns
|
| 839 |
-
# Left Column
|
| 840 |
-
with antd.Col(md=dict(flex="0 0 260px", span=0), # Hide on smaller screens (md breakpoint)
|
| 841 |
-
xs=dict(span=0), # Explicitly hide on extra small
|
| 842 |
-
sm=dict(span=24, order=1, flex="0 0 260px"), # Show on small screens, potentially adjust layout/order
|
| 843 |
-
# Consider using a collapsible drawer for mobile instead
|
| 844 |
-
elem_classes="chatbot-conversations-col" # Add class for potential styling
|
| 845 |
-
):
|
| 846 |
-
with ms.Div(elem_classes="chatbot-conversations"): # This div now uses flex column from CSS
|
| 847 |
-
with antd.Flex(vertical=True,
|
| 848 |
-
gap="small",
|
| 849 |
-
# Removed elem_style=dict(height="100%") - parent div controls height
|
| 850 |
-
):
|
| 851 |
-
# Logo
|
| 852 |
-
Logo()
|
| 853 |
-
|
| 854 |
-
# New Conversation Button
|
| 855 |
-
with antd.Button(value=None,
|
| 856 |
-
color="primary",
|
| 857 |
-
variant="filled",
|
| 858 |
-
block=True) as add_conversation_btn:
|
| 859 |
-
ms.Text(get_text("New Conversation", "新建对话"))
|
| 860 |
-
with ms.Slot("icon"):
|
| 861 |
-
antd.Icon("PlusOutlined")
|
| 862 |
-
|
| 863 |
-
# Conversations List
|
| 864 |
-
with antdx.Conversations(
|
| 865 |
-
elem_classes="chatbot-conversations-list", # Takes remaining space
|
| 866 |
-
active_key="", # Start with no active key
|
| 867 |
-
items=[] # Initial items empty, loaded by state
|
| 868 |
-
) as conversations:
|
| 869 |
-
# Keep menu items definition
|
| 870 |
-
with ms.Slot('menu.items'):
|
| 871 |
-
with antd.Menu.Item(
|
| 872 |
-
label="Delete", key="delete",
|
| 873 |
-
danger=True
|
| 874 |
-
) as conversation_delete_menu_item:
|
| 875 |
-
with ms.Slot("icon"):
|
| 876 |
-
antd.Icon("DeleteOutlined")
|
| 877 |
-
# Right Column
|
| 878 |
-
with antd.Col(flex=1, # Takes remaining horizontal space
|
| 879 |
-
elem_style=dict(height="100%"), # Ensure it fills vertically
|
| 880 |
-
md=dict(span=24, order=0), # Adjust order for mobile if left col shown
|
| 881 |
-
xs=dict(span=24, order=0),
|
| 882 |
-
sm=dict(order=0)
|
| 883 |
-
):
|
| 884 |
-
with antd.Flex(vertical=True,
|
| 885 |
-
gap="small", # Gap between chatbot and sender
|
| 886 |
-
elem_classes="chatbot-chat"): # This flex controls vertical layout of chat+input
|
| 887 |
-
# Chatbot Display Area
|
| 888 |
-
chatbot = pro.Chatbot(elem_classes="chatbot-chat-messages", # Takes flexible space
|
| 889 |
-
# height=0, # Let flexbox control height
|
| 890 |
-
value = None, # Initial value empty, loaded by state
|
| 891 |
-
welcome_config=welcome_config(),
|
| 892 |
-
user_config=user_config(),
|
| 893 |
-
bot_config=bot_config())
|
| 894 |
-
|
| 895 |
-
# Input Area (Sender)
|
| 896 |
-
with antdx.Suggestion(
|
| 897 |
-
items=DEFAULT_SUGGESTIONS,
|
| 898 |
-
should_trigger="""(e, { onTrigger, onKeyDown }) => {
|
| 899 |
-
// Keep existing JS logic for suggestions
|
| 900 |
-
switch(e.key) {
|
| 901 |
-
case '/': onTrigger(); break;
|
| 902 |
-
case 'ArrowRight': case 'ArrowLeft': case 'ArrowUp': case 'ArrowDown': break;
|
| 903 |
-
default: onTrigger(false);
|
| 904 |
-
}
|
| 905 |
-
onKeyDown(e);
|
| 906 |
-
}""") as suggestion:
|
| 907 |
-
with ms.Slot("children"):
|
| 908 |
-
# Use elem_id for easier targeting if needed
|
| 909 |
-
with antdx.Sender(elem_id="input-sender",
|
| 910 |
-
placeholder=get_text(
|
| 911 |
-
"Enter \"/\" to get suggestions, Shift+Enter for newline",
|
| 912 |
-
"输入 \"/\" 获取提示,Shift+Enter 换行"),
|
| 913 |
-
# interactive=True # Default is True
|
| 914 |
-
) as input:
|
| 915 |
-
with ms.Slot("header"):
|
| 916 |
-
# Pass AVAILABLE_MODEL_OPTIONS to SettingsHeader
|
| 917 |
-
settings_header_state, settings_form = SettingsHeader(
|
| 918 |
-
model_options=AVAILABLE_MODEL_OPTIONS, # Pass available options
|
| 919 |
-
default_settings=DEFAULT_SETTINGS # Pass defaults
|
| 920 |
-
)
|
| 921 |
-
with ms.Slot("prefix"):
|
| 922 |
-
with antd.Flex(
|
| 923 |
-
gap=4,
|
| 924 |
-
wrap=True, # Allow wrapping on small screens
|
| 925 |
-
elem_style=dict(maxWidth='80vw') # Adjust max width
|
| 926 |
-
):
|
| 927 |
-
with antd.Button(
|
| 928 |
-
value=None, type="text"
|
| 929 |
-
) as setting_btn:
|
| 930 |
-
with ms.Slot("icon"): antd.Icon("SettingOutlined")
|
| 931 |
-
with antd.Button(
|
| 932 |
-
value=None, type="text"
|
| 933 |
-
) as clear_btn:
|
| 934 |
-
with ms.Slot("icon"): antd.Icon("ClearOutlined")
|
| 935 |
-
# Keep ThinkingButton if UI uses it, otherwise remove
|
| 936 |
-
thinking_btn_state = ThinkingButton()
|
| 937 |
-
|
| 938 |
-
# --- Event Handlers ---
|
| 939 |
-
|
| 940 |
-
# Browser State Handler (if enabled)
|
| 941 |
-
if save_history:
|
| 942 |
-
browser_state = gr.BrowserState(
|
| 943 |
-
# Define the structure expected from the browser
|
| 944 |
-
value={ "conversations": [], "conversation_contexts": {} },
|
| 945 |
-
storage_key="vi_legal_chat_demo_storage" # Use a unique key
|
| 946 |
)
|
| 947 |
-
# When Python state changes, update the browser state
|
| 948 |
-
state.change(fn=Gradio_Events.update_browser_state,
|
| 949 |
-
inputs=[state],
|
| 950 |
-
outputs=[browser_state],
|
| 951 |
-
queue=False) # Run immediately
|
| 952 |
-
|
| 953 |
-
# On page load, apply browser state to Python state and UI
|
| 954 |
-
# Note: Ensure outputs match what apply_browser_state returns
|
| 955 |
-
demo.load(fn=Gradio_Events.apply_browser_state,
|
| 956 |
-
inputs=[browser_state, state],
|
| 957 |
-
outputs=[conversations, chatbot, settings_form, state], # Outputs to update UI
|
| 958 |
-
queue=False) # Run immediately on load
|
| 959 |
-
elif not loaded_models:
|
| 960 |
-
# If history saving is off AND no models loaded, show a message
|
| 961 |
-
def show_no_model_warning():
|
| 962 |
-
gr.Warning("No models were loaded successfully. The application functionality will be limited.")
|
| 963 |
-
# You could also update a specific Gradio component to show the error
|
| 964 |
-
demo.load(fn=show_no_model_warning, inputs=[], outputs=[])
|
| 965 |
-
|
| 966 |
|
| 967 |
-
|
| 968 |
-
|
| 969 |
-
|
| 970 |
-
|
| 971 |
-
|
| 972 |
-
|
| 973 |
-
|
| 974 |
-
|
| 975 |
-
|
| 976 |
-
|
| 977 |
-
|
| 978 |
-
|
| 979 |
-
|
| 980 |
-
|
| 981 |
-
|
| 982 |
-
|
| 983 |
-
|
| 984 |
-
|
| 985 |
-
|
| 986 |
-
|
| 987 |
-
|
| 988 |
-
|
| 989 |
-
|
| 990 |
-
|
| 991 |
-
|
| 992 |
-
|
| 993 |
-
|
| 994 |
-
|
| 995 |
-
|
| 996 |
-
|
| 997 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 998 |
|
| 999 |
-
#
|
| 1000 |
-
|
| 1001 |
-
|
| 1002 |
-
|
| 1003 |
-
|
| 1004 |
-
|
| 1005 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1006 |
],
|
| 1007 |
-
|
| 1008 |
-
|
| 1009 |
-
|
| 1010 |
|
| 1011 |
-
# Input Handler
|
| 1012 |
-
submit_event = input.submit(
|
| 1013 |
-
fn=Gradio_Events.add_message,
|
| 1014 |
-
inputs=[input,
|
| 1015 |
-
settings_form, thinking_btn_state, state],
|
| 1016 |
-
outputs=[ # Outputs from preprocess, submit, and postprocess combined
|
| 1017 |
-
input, conversations, add_conversation_btn, clear_btn,
|
| 1018 |
-
conversation_delete_menu_item, setting_btn, chatbot, state
|
| 1019 |
-
]) # Ensure outputs match yields
|
| 1020 |
|
| 1021 |
-
|
| 1022 |
-
input.cancel(fn=Gradio_Events.cancel,
|
| 1023 |
-
inputs=[state],
|
| 1024 |
-
outputs=[ # Outputs matching postprocess_submit return dict keys
|
| 1025 |
-
input, conversation_delete_menu_item, clear_btn,
|
| 1026 |
-
conversations, add_conversation_btn, setting_btn, chatbot, state
|
| 1027 |
-
],
|
| 1028 |
-
cancels=[submit_event, regenerating_event], # Cancel ongoing submit/regen
|
| 1029 |
-
queue=False) # Run immediately
|
| 1030 |
-
|
| 1031 |
-
# Input Actions Handler
|
| 1032 |
-
setting_btn.click(fn=Gradio_Events.toggle_settings_header,
|
| 1033 |
-
inputs=[settings_header_state],
|
| 1034 |
-
outputs=[settings_header_state])
|
| 1035 |
-
clear_btn.click(fn=Gradio_Events.clear_conversation_history,
|
| 1036 |
-
inputs=[state],
|
| 1037 |
-
outputs=[chatbot, state]) # Update chatbot display and state
|
| 1038 |
-
suggestion.select(fn=Gradio_Events.select_suggestion,
|
| 1039 |
-
inputs=[input],
|
| 1040 |
-
outputs=[input]) # Update input field
|
| 1041 |
-
|
| 1042 |
-
# --- Launch ---
|
| 1043 |
if __name__ == "__main__":
|
| 1044 |
-
|
| 1045 |
-
|
| 1046 |
-
|
| 1047 |
-
|
| 1048 |
-
|
| 1049 |
-
|
| 1050 |
-
print("Launching Gradio Interface...")
|
| 1051 |
-
demo.queue(default_concurrency_limit=10, # Adjust concurrency based on your GPU/CPU resources
|
| 1052 |
-
max_size=20).launch(ssr_mode=False, # Consider True if SEO or initial load speed is critical
|
| 1053 |
-
# share=True, # Uncomment to create a public link (use with caution)
|
| 1054 |
-
# server_name="0.0.0.0" # Uncomment to allow access from network
|
| 1055 |
-
max_threads=40 # Gradio default
|
| 1056 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import torch
|
| 2 |
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
| 3 |
+
import gradio as gr
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
|
| 5 |
+
# Kiểm tra thiết bị (GPU nếu có)
|
| 6 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 7 |
+
print(f"Using device: {device}") # Thêm log để biết thiết bị đang sử dụng
|
|
|
|
| 8 |
|
| 9 |
+
# --- Tải mô hình và tokenizer ---
|
| 10 |
+
# Khởi tạo biến model và tokenizer là None
|
| 11 |
+
model_1, tokenizer_1 = None, None
|
| 12 |
+
model_2, tokenizer_2 = None, None
|
| 13 |
+
# model_3, tokenizer_3 = None, None # Không cần tải model_3 nữa
|
| 14 |
+
model_4, tokenizer_4 = None, None
|
| 15 |
|
| 16 |
# Sử dụng try-except để xử lý lỗi nếu không tải được mô hình
|
| 17 |
try:
|
|
|
|
| 19 |
print(f"Loading model: {model_name_1}...")
|
| 20 |
model_1 = AutoModelForSeq2SeqLM.from_pretrained(model_name_1).to(device)
|
| 21 |
tokenizer_1 = AutoTokenizer.from_pretrained(model_name_1)
|
|
|
|
| 22 |
print(f"Model {model_name_1} loaded successfully.")
|
| 23 |
except Exception as e:
|
| 24 |
print(f"Error loading model {model_name_1}: {e}")
|
|
|
|
| 28 |
print(f"Loading model: {model_name_2}...")
|
| 29 |
model_2 = AutoModelForSeq2SeqLM.from_pretrained(model_name_2).to(device)
|
| 30 |
tokenizer_2 = AutoTokenizer.from_pretrained(model_name_2)
|
|
|
|
| 31 |
print(f"Model {model_name_2} loaded successfully.")
|
| 32 |
except Exception as e:
|
| 33 |
print(f"Error loading model {model_name_2}: {e}")
|
| 34 |
|
| 35 |
# Bỏ qua việc tải model_3 (ViLawT5_RL)
|
| 36 |
+
# ... (phần code tải model_3 bị comment như cũ) ...
|
| 37 |
|
| 38 |
try:
|
| 39 |
model_name_4 = "sunbv56/V-LegalQA"
|
| 40 |
print(f"Loading model: {model_name_4}...")
|
| 41 |
model_4 = AutoModelForSeq2SeqLM.from_pretrained(model_name_4).to(device)
|
| 42 |
tokenizer_4 = AutoTokenizer.from_pretrained(model_name_4)
|
|
|
|
| 43 |
print(f"Model {model_name_4} loaded successfully.")
|
| 44 |
except Exception as e:
|
| 45 |
print(f"Error loading model {model_name_4}: {e}")
|
| 46 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
|
| 48 |
+
# --- Hàm sinh phản hồi ---
|
| 49 |
+
def chatbot_response(question, model_choice, max_new_tokens, temperature, top_k, top_p, repetition_penalty, use_early_stopping, use_do_sample):
    """Generate an answer to ``question`` with the model selected in the UI.

    Args:
        question: Question text; it is prefixed with ``"câu_hỏi: "`` before
            tokenization.
        model_choice: UI label of the model: ``"ViLawT5"``, ``"ViT5"`` or
            ``"V-LegalQA"``.
        max_new_tokens: Upper bound on generated tokens (coerced to ``int``).
        temperature: Sampling temperature (coerced to ``float``).
        top_k: Top-k cutoff (coerced to ``int``).
        top_p: Nucleus cutoff (coerced to ``float``).
        repetition_penalty: Repetition penalty (coerced to ``float``).
        use_early_stopping: Forwarded to ``generate`` as ``early_stopping``.
        use_do_sample: Forwarded to ``generate`` as ``do_sample``. The three
            sampling parameters above are forwarded only when this is true.

    Returns:
        The decoded response, or a human-readable error string when the
        requested model is unavailable or generation raises. This function
        does not propagate exceptions from tokenization or generation.
    """
    # One table drives both the selection and the "what is available" error
    # message, so the two can never drift apart. A slot holds None when the
    # corresponding model failed to load at startup.
    candidates = (
        ("ViLawT5", model_1, tokenizer_1),
        ("ViT5", model_2, tokenizer_2),
        ("V-LegalQA", model_4, tokenizer_4),
    )

    model = None
    tokenizer = None
    for label, candidate_model, candidate_tokenizer in candidates:
        if model_choice == label and candidate_model and candidate_tokenizer:
            model, tokenizer = candidate_model, candidate_tokenizer
            break

    if model is None:
        available_models = [label for label, candidate_model, _ in candidates if candidate_model]
        if not available_models:
            return "Error: No models were loaded successfully. Please check the logs."
        if model_choice not in available_models:
            return f"Error: Model '{model_choice}' was not loaded successfully or is invalid. Available models: {', '.join(available_models)}"
        # The label is valid and the model loaded, but its tokenizer did not.
        return f"Error: An unexpected issue occurred with model '{model_choice}'. Please check the logs."

    print(f"Generating response using {model_choice} with params: max_new_tokens={max_new_tokens}, temp={temperature}, top_k={top_k}, top_p={top_p}, rep_penalty={repetition_penalty}, early_stop={use_early_stopping}, do_sample={use_do_sample}")

    input_text = f"câu_hỏi: {question}"
    try:
        data = tokenizer(
            input_text,
            return_tensors="pt",
            truncation=True,
            return_attention_mask=True,
            add_special_tokens=True,
            padding="max_length",
            max_length=256,  # Consider raising this if questions/contexts get longer
        )

        input_ids = data.input_ids.to(device)
        attention_mask = data.attention_mask.to(device)

        generation_kwargs = {
            "attention_mask": attention_mask,
            "max_new_tokens": int(max_new_tokens),
            "early_stopping": use_early_stopping,
            "do_sample": use_do_sample,
            "repetition_penalty": float(repetition_penalty),
        }
        # temperature/top_k/top_p only affect sample-based decoding. Passing
        # them with do_sample=False changes nothing in the output but makes
        # transformers warn about an inconsistent generation config on every
        # request, so forward them only when sampling is enabled.
        if use_do_sample:
            generation_kwargs.update(
                temperature=float(temperature),
                top_k=int(top_k),
                top_p=float(top_p),
            )

        # Inference only: no gradients needed.
        with torch.no_grad():
            outputs = model.generate(input_ids, **generation_kwargs)

        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        print(f"Raw output shape: {outputs[0].shape}")
        print(f"Decoded response: {response}")
        return response
    except Exception as e:
        # UI boundary: report the failure to the user instead of crashing the
        # request, and dump the traceback to the logs for debugging.
        print(f"Error during generation: {e}")
        import traceback
        traceback.print_exc()
        return f"An error occurred during response generation: {e}"
|
| 126 |
+
|
| 127 |
+
# --- Build the list of models that loaded successfully (ViLawT5_RL is left out) ---
loaded_models = []
if model_1 and tokenizer_1:
    loaded_models.append("ViLawT5")
if model_2 and tokenizer_2:
    loaded_models.append("ViT5")
if model_4 and tokenizer_4:
    loaded_models.append("V-LegalQA")

# Pick the default model: prefer "V-LegalQA", then the first model that
# loaded, and fall back to a placeholder string when nothing loaded at all.
if "V-LegalQA" in loaded_models:
    default_model = "V-LegalQA"
elif loaded_models:
    default_model = loaded_models[0]
else:
    default_model = "No models available"
|
| 135 |
+
|
| 136 |
+
# --- Tạo giao diện với Gradio ---
|
| 137 |
+
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
| 138 |
+
gr.Markdown(
|
| 139 |
+
"""
|
| 140 |
+
# 🤖 AI Chatbot Pháp luật Việt Nam (Demo)
|
| 141 |
+
Chọn mô hình và đặt câu hỏi liên quan đến pháp luật.
|
| 142 |
+
Nhấn **Shift + Enter** để gửi câu hỏi, **Enter** để xuống dòng.
|
| 143 |
+
"""
|
| 144 |
+
)
|
| 145 |
+
|
| 146 |
+
with gr.Row():
|
| 147 |
+
model_choice = gr.Dropdown(
|
| 148 |
+
choices=loaded_models,
|
| 149 |
+
label="Chọn Mô hình AI",
|
| 150 |
+
value=default_model,
|
| 151 |
+
interactive=bool(loaded_models) # Chỉ cho phép tương tác nếu có model
|
| 152 |
+
)
|
| 153 |
|
| 154 |
+
# Đảm bảo 'lines' >= 2 để Shift+Enter có tác dụng rõ ràng
|
| 155 |
+
question_input = gr.Textbox(
|
| 156 |
+
label="Nhập câu hỏi của bạn (Shift+Enter để gửi)",
|
| 157 |
+
placeholder="Ví dụ: Thế nào là tội cố ý gây thương tích?",
|
| 158 |
+
lines=3, # Giữ nguyên hoặc tăng nếu muốn ô nhập cao hơn
|
| 159 |
+
# scale=7 # Ví dụ: làm cho ô nhập rộng hơn nếu cần
|
| 160 |
+
)
|
| 161 |
+
|
| 162 |
+
# --- Cập nhật giá trị mặc định trong Accordion ---
|
| 163 |
+
with gr.Accordion("Tùy chọn Nâng cao (Generation Parameters)", open=False):
|
| 164 |
+
with gr.Row():
|
| 165 |
+
early_stopping_checkbox = gr.Checkbox(label="Enable Early Stopping", value=False, info="Dừng sớm khi gặp token EOS.")
|
| 166 |
+
do_sample_checkbox = gr.Checkbox(label="Enable Sampling (do_sample)", value=False, info="Sử dụng sampling (cần thiết cho temperature, top_k, top_p). Tắt nếu muốn greedy search.")
|
| 167 |
+
with gr.Row():
|
| 168 |
+
max_new_tokens_slider = gr.Slider(minimum=10, maximum=1024, value=512, step=10, label="Max New Tokens", info="Số lượng token tối đa được sinh ra.")
|
| 169 |
+
temperature_slider = gr.Slider(minimum=0.1, maximum=2.0, value=1.0, step=0.1, label="Temperature", info="Độ 'sáng tạo' của câu trả lời (thấp hơn = bảo thủ hơn). Cần bật do_sample.")
|
| 170 |
+
with gr.Row():
|
| 171 |
+
top_k_slider = gr.Slider(minimum=1, maximum=200, value=50, step=1, label="Top-K", info="Chỉ xem xét K token có xác suất cao nhất. Cần bật do_sample.")
|
| 172 |
+
top_p_slider = gr.Slider(minimum=0.0, maximum=1.0, value=1.0, step=0.01, label="Top-P (Nucleus Sampling)", info="Chỉ xem xét các token có tổng xác suất >= P. Cần bật do_sample.")
|
| 173 |
+
repetition_penalty_slider = gr.Slider(minimum=1.0, maximum=3.0, value=1.0, step=0.1, label="Repetition Penalty", info="Phạt các token đã xuất hiện (cao hơn = ít lặp lại hơn).")
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
response_output = gr.Textbox(label="Phản hồi của Chatbot", lines=5, interactive=False)
|
| 177 |
+
|
| 178 |
+
# Nút gửi vẫn giữ lại phòng trường hợp người dùng thích click hơn
|
| 179 |
+
submit_btn = gr.Button("Gửi câu hỏi", variant="primary")
|
| 180 |
+
|
| 181 |
+
# --- THAY ĐỔI QUAN TRỌNG ---
|
| 182 |
+
# Tạo một list các inputs để dùng chung cho cả nút bấm và nhấn Enter
|
| 183 |
+
chatbot_inputs = [
|
| 184 |
+
question_input,
|
| 185 |
+
model_choice,
|
| 186 |
+
max_new_tokens_slider,
|
| 187 |
+
temperature_slider,
|
| 188 |
+
top_k_slider,
|
| 189 |
+
top_p_slider,
|
| 190 |
+
repetition_penalty_slider,
|
| 191 |
+
early_stopping_checkbox,
|
| 192 |
+
do_sample_checkbox
|
| 193 |
+
]
|
| 194 |
+
|
| 195 |
+
# 1. Gửi khi nhấn nút
|
| 196 |
+
submit_btn.click(
|
| 197 |
+
fn=chatbot_response,
|
| 198 |
+
inputs=chatbot_inputs,
|
| 199 |
+
outputs=response_output
|
| 200 |
+
)
|
| 201 |
+
|
| 202 |
+
# 2. Submit when Shift+Enter is pressed inside the question_input Textbox
|
| 203 |
+
# A plain Enter only inserts a newline (default behaviour when lines > 1)
|
| 204 |
+
question_input.submit(
|
| 205 |
+
fn=chatbot_response,
|
| 206 |
+
inputs=chatbot_inputs,
|
| 207 |
+
outputs=response_output
|
| 208 |
+
)
|
| 209 |
+
# -----------------------------
|
| 210 |
+
|
| 211 |
+
gr.Examples(
|
| 212 |
+
examples=[
|
| 213 |
+
["Hợp đồng vô hiệu khi nào?", "V-LegalQA"],
|
| 214 |
+
["Quyền và nghĩa vụ của người lao động là gì?", "ViT5"],
|
| 215 |
+
["Người dưới 18 tuổi có được ký hợp đồng lao động không?\nThời gian làm việc tối đa là bao lâu?", "V-LegalQA"] # Ví dụ multi-line
|
| 216 |
],
|
| 217 |
+
inputs=[question_input, model_choice]
|
| 218 |
+
)
|
|
|
|
| 219 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 220 |
|
| 221 |
+
# --- Chạy Gradio ---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 222 |
if __name__ == "__main__":
    # Warn on the console when every model failed to load; the UI still starts.
    if not loaded_models:
        print("WARNING: No models were loaded successfully. The application might not function correctly.")
        # Consider also calling gr.Info("Không có mô hình nào được tải thành công!") inside the Blocks.
    # Set share=True to create a temporary public share link.
    demo.launch(share=False, debug=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|