52100322-TruongBinhThuan commited on
Commit
89ab967
·
2 Parent(s): d36d08512dfc47

Merge branch 'main' of https://huggingface.co/spaces/sunbv56/V-LegalQA-Chatbot

Browse files
Files changed (1) hide show
  1. app.py +184 -1013
app.py CHANGED
@@ -1,23 +1,17 @@
1
- import uuid
2
- import time
3
- import json
4
- import gradio as gr
5
  import torch
6
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
7
- import modelscope_studio.components.antd as antd
8
- import modelscope_studio.components.antdx as antdx
9
- import modelscope_studio.components.base as ms
10
- import modelscope_studio.components.pro as pro
11
- # Removed: import dashscope
12
- from config import DEFAULT_LOCALE, DEFAULT_SETTINGS, DEFAULT_THEME, DEFAULT_SUGGESTIONS, save_history, get_text, user_config, bot_config, welcome_config #, api_key # Removed api_key
13
- # Removed: from dashscope import Generation
14
 
15
- # --- Model Loading ---
16
- print("Setting up device...")
17
- device = "cuda" if torch.cuda.is_available() else "cpu"
18
- print(f"Using device: {device}")
19
 
20
- loaded_models = {}
 
 
 
 
 
21
 
22
  # Sử dụng try-except để xử lý lỗi nếu không tải được mô hình
23
  try:
@@ -25,7 +19,6 @@ try:
25
  print(f"Loading model: {model_name_1}...")
26
  model_1 = AutoModelForSeq2SeqLM.from_pretrained(model_name_1).to(device)
27
  tokenizer_1 = AutoTokenizer.from_pretrained(model_name_1)
28
- loaded_models[model_name_1] = {"model": model_1, "tokenizer": tokenizer_1}
29
  print(f"Model {model_name_1} loaded successfully.")
30
  except Exception as e:
31
  print(f"Error loading model {model_name_1}: {e}")
@@ -35,1022 +28,200 @@ try:
35
  print(f"Loading model: {model_name_2}...")
36
  model_2 = AutoModelForSeq2SeqLM.from_pretrained(model_name_2).to(device)
37
  tokenizer_2 = AutoTokenizer.from_pretrained(model_name_2)
38
- loaded_models[model_name_2] = {"model": model_2, "tokenizer": tokenizer_2}
39
  print(f"Model {model_name_2} loaded successfully.")
40
  except Exception as e:
41
  print(f"Error loading model {model_name_2}: {e}")
42
 
43
  # Bỏ qua việc tải model_3 (ViLawT5_RL)
 
44
 
45
  try:
46
  model_name_4 = "sunbv56/V-LegalQA"
47
  print(f"Loading model: {model_name_4}...")
48
  model_4 = AutoModelForSeq2SeqLM.from_pretrained(model_name_4).to(device)
49
  tokenizer_4 = AutoTokenizer.from_pretrained(model_name_4)
50
- loaded_models[model_name_4] = {"model": model_4, "tokenizer": tokenizer_4}
51
  print(f"Model {model_name_4} loaded successfully.")
52
  except Exception as e:
53
  print(f"Error loading model {model_name_4}: {e}")
54
 
55
- if not loaded_models:
56
- print("\n" + "="*50)
57
- print("FATAL ERROR: No models could be loaded. The application cannot run.")
58
- print("Please check model names, network connection, and available disk space.")
59
- print("="*50 + "\n")
60
- # Optionally raise an error or exit here if running as a script
61
- # raise RuntimeError("No models loaded successfully!")
62
- # exit() # Or sys.exit(1) after importing sys
63
-
64
- # --- Update Model Options based on loaded models ---
65
- # Original MODEL_OPTIONS_MAP structure from config.py (assuming it looks like this)
66
- # Replace this with your actual definition from config.py if different
67
- MODEL_OPTIONS_MAP = {
68
- "label": get_text("Model", "模型"),
69
- "name": "model",
70
- "choices": [
71
- # Populate this dynamically
72
- ],
73
- "info": get_text("Select the model you want to use", "请选择需要使用的模型"),
74
- }
75
-
76
- # Populate choices dynamically
77
- AVAILABLE_MODEL_OPTIONS = []
78
- for name in loaded_models.keys():
79
- # Use the name itself as the label, or define more descriptive labels
80
- label = name.split('/')[-1] # Get 'ViLawT5_QAChatBot' etc. as label
81
- AVAILABLE_MODEL_OPTIONS.append({"label": label, "value": name})
82
-
83
- MODEL_OPTIONS_MAP["choices"] = AVAILABLE_MODEL_OPTIONS
84
-
85
- # Update DEFAULT_SETTINGS to use the first available model
86
- if AVAILABLE_MODEL_OPTIONS:
87
- DEFAULT_SETTINGS['model'] = AVAILABLE_MODEL_OPTIONS[0]['value']
88
- else:
89
- # Handle the case where no models are loaded - set a default or handle error
90
- DEFAULT_SETTINGS['model'] = None
91
- print("Warning: No models loaded, model selection will be empty.")
92
-
93
- # --- Gradio UI and Events ---
94
-
95
- # Removed: dashscope.api_key = api_key
96
-
97
- # Removed: format_history function (not needed for simple seq2seq input)
98
-
99
- class Gradio_Events:
100
-
101
- @staticmethod
102
- def submit(state_value):
103
- start_time = time.time()
104
- history = state_value["conversation_contexts"][
105
- state_value["conversation_id"]]["history"]
106
- settings = state_value["conversation_contexts"][
107
- state_value["conversation_id"]]["settings"]
108
- # enable_thinking = state_value["conversation_contexts"][
109
- # state_value["conversation_id"]]["enable_thinking"] # Keep if needed for UI, but generation logic changes
110
-
111
- model_name = settings.get("model")
112
-
113
- # Ensure a model is selected and loaded
114
- if not model_name or model_name not in loaded_models:
115
- error_msg = f"Error: Model '{model_name}' is not available or not selected."
116
- print(error_msg)
117
- history.append({
118
- "role": "assistant",
119
- "content": [{"type": "text", "content": f'<span style="color: var(--color-red-500)">{error_msg}</span>'}],
120
- "key": str(uuid.uuid4()),
121
- "header": "Error",
122
- "loading": False,
123
- "status": "error"
124
- })
125
- yield {
126
- chatbot: gr.update(value=history),
127
- state: gr.update(value=state_value),
128
- }
129
- return # Stop processing this submission
130
-
131
- # Get the actual model and tokenizer objects
132
- selected_model_info = loaded_models[model_name]
133
- model = selected_model_info["model"]
134
- tokenizer = selected_model_info["tokenizer"]
135
- model_label = next((item['label'] for item in AVAILABLE_MODEL_OPTIONS if item['value'] == model_name), model_name)
136
-
137
-
138
- # --- Prepare Input for Seq2Seq Model ---
139
- # Use the last user message as input. Adjust if your models need specific formatting.
140
- if len(history) < 1 or history[-1]["role"] != "user":
141
- # This case should ideally not happen if submit is called after add_message
142
- user_input = "Hello" # Default or fetch differently
143
- print("Warning: Could not find the last user message, using default.")
144
- else:
145
- user_input = history[-1]["content"]
146
-
147
- # Simple prompt format (adjust if needed for your specific models)
148
- # Example: Some models might expect "question: <query>" or similar
149
- prompt = f"question: {user_input}" # Adjust this format as needed!
150
- print(f"Using model: {model_name}")
151
- print(f"Input prompt: {prompt}")
152
-
153
- # Add placeholder for assistant response
154
- history.append({
155
- "role":
156
- "assistant",
157
- "content": [],
158
- "key":
159
- str(uuid.uuid4()),
160
- "header": model_label, # Use the label from options
161
- "loading":
162
- True,
163
- "status":
164
- "pending"
165
- })
166
-
167
- yield {
168
- chatbot: gr.update(value=history),
169
- state: gr.update(value=state_value),
170
- }
171
-
172
- try:
173
- # --- Tokenize and Generate ---
174
- inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=512).to(device) # Adjust max_length
175
-
176
- # Generation parameters (tune these for your models)
177
- generation_kwargs = {
178
- "max_length": 512, # Adjust max output length
179
- "num_beams": 5, # Beam search
180
- "early_stopping": True,
181
- # Add other parameters like temperature, top_k, top_p if desired
182
- # "temperature": 0.7,
183
- # "top_k": 50,
184
- }
185
- print(f"Generating with kwargs: {generation_kwargs}")
186
-
187
- with torch.no_grad(): # Important for inference
188
- outputs = model.generate(**inputs, **generation_kwargs)
189
-
190
- # --- Decode Response ---
191
- response_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
192
- print(f"Raw response: {response_text}")
193
-
194
- # --- Update History ---
195
- history[-1]["content"] = [{"type": "text", "content": response_text}]
196
- history[-1]["loading"] = False
197
- history[-1]["status"] = "done"
198
- cost_time = "{:.2f}".format(time.time() - start_time)
199
- history[-1]["footer"] = get_text(f"{cost_time}s", f"用时{cost_time}s")
200
-
201
- yield {
202
- chatbot: gr.update(value=history),
203
- state: gr.update(value=state_value),
204
- }
205
-
206
- except Exception as e:
207
- print(f"Error during generation with model {model_name}: {e}")
208
- history[-1]["loading"] = False
209
- history[-1]["status"] = "error" # Use 'error' status
210
- history[-1]["content"] = [{
211
- "type":
212
- "text",
213
- "content":
214
- f'<span style="color: var(--color-red-500)">Error during generation: {str(e)}</span>'
215
- }]
216
- yield {
217
- chatbot: gr.update(value=history),
218
- state: gr.update(value=state_value)
219
- }
220
- # Re-raise if you want the error to propagate further, or handle it here
221
- # raise e
222
-
223
- @staticmethod
224
- def add_message(input_value, settings_form_value, thinking_btn_state_value, # Keep thinking_btn_state if UI uses it
225
- state_value):
226
- if not input_value or input_value.strip() == "":
227
- print("Empty input, skipping.")
228
- # Optionally return an update to clear the input without submitting
229
- # return { input: gr.update(value="") }
230
- return gr.skip() # Skip the entire process if input is empty
231
-
232
-
233
- if not state_value["conversation_id"]:
234
- random_id = str(uuid.uuid4())
235
- history = []
236
- state_value["conversation_id"] = random_id
237
- # Ensure default settings (including the default model) are applied
238
- current_settings = settings_form_value if settings_form_value else DEFAULT_SETTINGS.copy()
239
- if not current_settings.get('model') and AVAILABLE_MODEL_OPTIONS:
240
- current_settings['model'] = AVAILABLE_MODEL_OPTIONS[0]['value']
241
-
242
- state_value["conversation_contexts"][
243
- state_value["conversation_id"]] = {
244
- "history": history,
245
- "settings": current_settings, # Use current or default settings
246
- "enable_thinking": thinking_btn_state_value["enable_thinking"] # Keep if needed
247
- }
248
- state_value["conversations"].append({
249
- "label": input_value[:50] + ('...' if len(input_value) > 50 else ''), # Truncate label
250
- "key": random_id
251
- })
252
- else:
253
- # Update settings for existing conversation before adding message
254
- state_value["conversation_contexts"][
255
- state_value["conversation_id"]]["settings"] = settings_form_value
256
- state_value["conversation_contexts"][
257
- state_value["conversation_id"]]["enable_thinking"] = thinking_btn_state_value["enable_thinking"]
258
-
259
-
260
- history = state_value["conversation_contexts"][
261
- state_value["conversation_id"]]["history"]
262
-
263
- # Add user message
264
- history.append({
265
- "role": "user",
266
- "content": input_value,
267
- "key": str(uuid.uuid4())
268
- })
269
-
270
- # Update state *before* calling preprocess/submit
271
- # No, preprocess needs the user message *already* in history
272
- # state_value["conversation_contexts"][
273
- # state_value["conversation_id"]]["history"] = history
274
-
275
- yield Gradio_Events.preprocess_submit(clear_input=True)(state_value)
276
-
277
- # Make sure the model is loaded before trying to submit
278
- selected_model = state_value["conversation_contexts"][state_value["conversation_id"]]["settings"].get('model')
279
- if not selected_model or selected_model not in loaded_models:
280
- # Handle case where no model is selected or available *before* calling submit
281
- error_msg = f"Error: Model '{selected_model}' not available or not selected. Cannot generate response."
282
- print(error_msg)
283
- history.append({
284
- "role": "assistant",
285
- "content": [{"type": "text", "content": f'<span style="color: var(--color-red-500)">{error_msg}</span>'}],
286
- "key": str(uuid.uuid4()),
287
- "header": "Error",
288
- "loading": False,
289
- "status": "error"
290
- })
291
- # Need to yield the error message *and* the postprocess state
292
- post_process_update = Gradio_Events.postprocess_submit(state_value)
293
- post_process_update[chatbot] = gr.update(value=history) # Add chatbot update
294
- yield post_process_update
295
-
296
- else:
297
- # Proceed with generation if model is available
298
- try:
299
- # Use a generator pattern even though submit itself doesn't stream *chunks* anymore
300
- # It still yields intermediate states (loading) and the final state
301
- for update in Gradio_Events.submit(state_value):
302
- yield update
303
- except Exception as e:
304
- # This exception might be caught inside submit already,
305
- # but catch here just in case submit itself raises before yielding
306
- print(f"Error during submission process: {e}")
307
- # Manually create an error state if submit failed early
308
- history = state_value["conversation_contexts"][state_value["conversation_id"]]["history"]
309
- if not history or history[-1].get("role") != "assistant":
310
- # Add error message if submit failed before adding assistant placeholder
311
- history.append({
312
- "role": "assistant",
313
- "content": [{"type": "text", "content": f'<span style="color: var(--color-red-500)">Error: {e}</span>'}],
314
- "key": str(uuid.uuid4()), "header": "Error", "loading": False, "status": "error"
315
- })
316
- else: # Add error to the loading message if it exists
317
- history[-1]["loading"] = False
318
- history[-1]["status"] = "error"
319
- history[-1]["content"] = [{"type": "text", "content": f'<span style="color: var(--color-red-500)">Error: {e}</span>'}]
320
- yield Gradio_Events.postprocess_submit(state_value) # Ensure UI is unlocked
321
- # raise e # Optionally re-raise
322
- finally:
323
- # Ensure UI is always returned to a non-loading state
324
- yield Gradio_Events.postprocess_submit(state_value)
325
-
326
- @staticmethod
327
- def preprocess_submit(clear_input=True):
328
-
329
- def preprocess_submit_handler(state_value):
330
- # Check if conversation_id is valid before accessing context
331
- if not state_value["conversation_id"] or state_value["conversation_id"] not in state_value["conversation_contexts"]:
332
- print("Warning: Invalid conversation ID in preprocess_submit.")
333
- # Handle gracefully, maybe skip update or return default state
334
- return gr.skip()
335
-
336
- history = state_value["conversation_contexts"][
337
- state_value["conversation_id"]]["history"]
338
- return {
339
- **({
340
- input:
341
- gr.update(value="", interactive=False) # Clear and disable input
342
- } if clear_input else {input: gr.update(interactive=False)}), # Just disable
343
- conversations:
344
- gr.update(active_key=state_value["conversation_id"],
345
- items=list(
346
- map(
347
- lambda item: {
348
- **item,
349
- # Disable *all* other conversations during generation
350
- "disabled": True # item["key"] != state_value["conversation_id"]
351
- }, state_value["conversations"]))),
352
- add_conversation_btn:
353
- gr.update(disabled=True),
354
- clear_btn:
355
- gr.update(disabled=True),
356
- conversation_delete_menu_item:
357
- gr.update(disabled=True),
358
- # Ensure settings cannot be changed during generation
359
- setting_btn: gr.update(disabled=True),
360
- # Disable chatbot actions during generation
361
- chatbot:
362
- gr.update(value=history,
363
- bot_config=bot_config(
364
- disabled_actions=['edit', 'retry', 'delete']),
365
- user_config=user_config(
366
- disabled_actions=['edit', 'delete'])),
367
- state:
368
- gr.update(value=state_value), # Pass state through
369
- }
370
-
371
- return preprocess_submit_handler
372
-
373
- @staticmethod
374
- def postprocess_submit(state_value):
375
- # Check if conversation_id is valid before accessing context
376
- if not state_value["conversation_id"] or state_value["conversation_id"] not in state_value["conversation_contexts"]:
377
- print("Warning: Invalid conversation ID in postprocess_submit.")
378
- # Return a state that enables controls but maybe shows no chat
379
- return {
380
- input: gr.update(interactive=True),
381
- conversation_delete_menu_item: gr.update(disabled=True), # No active convo
382
- clear_btn: gr.update(disabled=True), # No active convo
383
- conversations: gr.update(items=state_value.get("conversations", [])), # Show list
384
- add_conversation_btn: gr.update(disabled=False),
385
- setting_btn: gr.update(disabled=False), # Re-enable settings button
386
- chatbot: gr.update(value=None, bot_config=bot_config(), user_config=user_config()), # Clear chat
387
- state: gr.update(value=state_value),
388
- }
389
-
390
- history = state_value["conversation_contexts"][
391
- state_value["conversation_id"]]["history"]
392
- return {
393
- input:
394
- gr.update(interactive=True), # Re-enable input
395
- conversation_delete_menu_item:
396
- gr.update(disabled=False),
397
- clear_btn:
398
- gr.update(disabled=False),
399
- conversations: # Re-enable all conversations in the list
400
- gr.update(items=list(map(lambda item: {**item, "disabled": False}, state_value["conversations"]))),
401
- add_conversation_btn:
402
- gr.update(disabled=False),
403
- setting_btn: gr.update(disabled=False), # Re-enable settings button
404
- chatbot:
405
- gr.update(value=history,
406
- bot_config=bot_config(),
407
- user_config=user_config()), # Re-enable chatbot actions
408
- state:
409
- gr.update(value=state_value), # Pass state through
410
- }
411
-
412
- @staticmethod
413
- def cancel(state_value):
414
- # Since generation is not streamed chunk-by-chunk, cancel primarily means
415
- # unlocking the UI if it got stuck somehow.
416
- # The actual model generation might continue in the background if started.
417
- # For true cancellation, you'd need more complex process management.
418
- print("Cancel requested. Unlocking UI.")
419
- # Find the last message, mark it as cancelled if it was loading
420
- if state_value["conversation_id"] and state_value["conversation_id"] in state_value["conversation_contexts"]:
421
- history = state_value["conversation_contexts"][state_value["conversation_id"]]["history"]
422
- if history and history[-1].get("loading"):
423
- history[-1]["loading"] = False
424
- history[-1]["status"] = "cancelled" # Or 'error' or 'done'
425
- history[-1]["footer"] = get_text("Generation cancelled by user", "用户取消生成")
426
- # Optionally clear the content or leave it empty
427
- # history[-1]["content"] = [{"type": "text", "content": "[Cancelled]"}]
428
- # Return the postprocess state to unlock UI elements
429
- return Gradio_Events.postprocess_submit(state_value)
430
-
431
-
432
- @staticmethod
433
- def delete_message(state_value, e: gr.EventData):
434
- index = e._data["payload"][0]["index"]
435
- if not state_value["conversation_id"] or state_value["conversation_id"] not in state_value["conversation_contexts"]:
436
- return gr.skip() # No active conversation
437
-
438
- history = state_value["conversation_contexts"][
439
- state_value["conversation_id"]]["history"]
440
- # Make sure index is valid
441
- if 0 <= index < len(history):
442
- history.pop(index) # Use pop for efficiency
443
- state_value["conversation_contexts"][
444
- state_value["conversation_id"]]["history"] = history
445
- else:
446
- print(f"Warning: Invalid index {index} for deleting message.")
447
- return gr.skip()
448
-
449
- # Return only the state update, chatbot will refresh based on state
450
- return gr.update(value=state_value)
451
-
452
-
453
- @staticmethod
454
- def edit_message(state_value, chatbot_value, e: gr.EventData):
455
- index = e._data["payload"][0]["index"]
456
- if not state_value["conversation_id"] or state_value["conversation_id"] not in state_value["conversation_contexts"]:
457
- return gr.skip() # No active conversation
458
-
459
- history = state_value["conversation_contexts"][
460
- state_value["conversation_id"]]["history"]
461
-
462
- # Check index validity and if chatbot_value structure matches
463
- if 0 <= index < len(history) and index < len(chatbot_value) and "content" in chatbot_value[index]:
464
- # Update content based on the structure from the chatbot component
465
- # It might be just text or a list of dicts like {"type": "text", "content": ...}
466
- new_content = chatbot_value[index]["content"]
467
- # Ensure history stores it in the expected format (likely just the text for user messages)
468
- if history[index]["role"] == "user":
469
- history[index]["content"] = new_content # Assuming user content is stored as a simple string
470
- else:
471
- # If assistant content is stored differently (e.g., list of dicts), adapt here
472
- history[index]["content"] = new_content
473
- state_value["conversation_contexts"][
474
- state_value["conversation_id"]]["history"] = history
475
- else:
476
- print(f"Warning: Invalid index {index} or mismatch in chatbot_value structure for editing.")
477
- return gr.skip()
478
-
479
- return gr.update(value=state_value) # Return updated state
480
-
481
- @staticmethod
482
- def regenerate_message(settings_form_value, thinking_btn_state_value,
483
- state_value, e: gr.EventData):
484
- index = e._data["payload"][0]["index"]
485
- if not state_value["conversation_id"] or state_value["conversation_id"] not in state_value["conversation_contexts"]:
486
- return gr.skip()
487
-
488
- history = state_value["conversation_contexts"][
489
- state_value["conversation_id"]]["history"]
490
-
491
- # Find the user message preceding the assistant message at 'index'
492
- # Usually, the message to regenerate is assistant, so the input is at index-1
493
- if index > 0 and history[index]["role"] == "assistant" and history[index-1]["role"] == "user":
494
- # Trim history up to *before* the assistant message we want to regenerate
495
- history = history[:index]
496
- else:
497
- print("Warning: Cannot regenerate. Expected user message before the selected assistant message.")
498
- # Fallback: Maybe just remove the selected message and the one before it?
499
- # Or just remove the selected one and try submitting the last user message again?
500
- # Safest: just skip regeneration if structure isn't as expected.
501
- return gr.skip()
502
-
503
- # Update state with trimmed history and current settings
504
- state_value["conversation_contexts"][
505
- state_value["conversation_id"]] = {
506
- "history": history,
507
- "settings": settings_form_value,
508
- "enable_thinking": thinking_btn_state_value["enable_thinking"]
509
- }
510
-
511
- # Preprocess UI (lock controls, show loading state potentially)
512
- # Preprocess needs the user message back in history to display correctly
513
- # Let's yield preprocess first, then submit
514
- yield Gradio_Events.preprocess_submit(clear_input=False)(state_value) # Don't clear input field
515
 
516
- # Make sure the model is loaded before trying to submit
517
- selected_model = state_value["conversation_contexts"][state_value["conversation_id"]]["settings"].get('model')
518
- if not selected_model or selected_model not in loaded_models:
519
- # Handle case where no model is selected or available *before* calling submit
520
- error_msg = f"Error: Model '{selected_model}' not available or not selected. Cannot regenerate response."
521
- print(error_msg)
522
- history.append({
523
- "role": "assistant",
524
- "content": [{"type": "text", "content": f'<span style="color: var(--color-red-500)">{error_msg}</span>'}],
525
- "key": str(uuid.uuid4()),
526
- "header": "Error",
527
- "loading": False,
528
- "status": "error"
529
- })
530
- post_process_update = Gradio_Events.postprocess_submit(state_value)
531
- post_process_update[chatbot] = gr.update(value=history) # Add chatbot update
532
- yield post_process_update
533
- else:
534
- # Call submit to generate the new response
535
- try:
536
- for chunk in Gradio_Events.submit(state_value):
537
- yield chunk
538
- except Exception as e:
539
- print(f"Error during regeneration submission: {e}")
540
- # Handle error display similar to add_message
541
- history = state_value["conversation_contexts"][state_value["conversation_id"]]["history"]
542
- if not history or history[-1].get("role") != "assistant":
543
- history.append({
544
- "role": "assistant",
545
- "content": [{"type": "text", "content": f'<span style="color: var(--color-red-500)">Error: {e}</span>'}],
546
- "key": str(uuid.uuid4()), "header": "Error", "loading": False, "status": "error"
547
- })
548
- else:
549
- history[-1]["loading"] = False; history[-1]["status"] = "error"
550
- history[-1]["content"] = [{"type": "text", "content": f'<span style="color: var(--color-red-500)">Error: {e}</span>'}]
551
- yield Gradio_Events.postprocess_submit(state_value)
552
- # raise e
553
- finally:
554
- # Postprocess UI (unlock controls)
555
- yield Gradio_Events.postprocess_submit(state_value)
556
-
557
- @staticmethod
558
- def select_suggestion(input_value, e: gr.EventData):
559
- # This assumes the suggestion replaces the '/' trigger
560
- # Adjust if the behavior should be different (e.g., append)
561
- # The original JS logic suggests '/' triggers the suggestion list
562
- # Selecting might append or replace based on context, let's assume replacement for simplicity
563
- selected_suggestion = e._data["payload"][0]
564
- # Simple replacement logic:
565
- # Find the last '/' and replace everything after it, or append if no '/'
566
- last_slash = input_value.rfind('/')
567
- if last_slash != -1:
568
- new_value = input_value[:last_slash] + selected_suggestion
569
- else:
570
- new_value = input_value + selected_suggestion # Or just selected_suggestion?
571
-
572
- # Original logic was: input_value = input_value[:-1] + e._data["payload"][0]
573
- # This assumes the trigger was the *last* character. Let's stick to that.
574
- if input_value.endswith('/'):
575
- new_value = input_value[:-1] + selected_suggestion
576
- else:
577
- new_value = selected_suggestion # Or append? Let's try replacing if no trailing /
578
-
579
- return gr.update(value=new_value)
580
-
581
- @staticmethod
582
- def apply_prompt(e: gr.EventData):
583
- # Gets value from welcome message prompt selection
584
- return gr.update(value=e._data["payload"][0]["value"]["description"])
585
-
586
- @staticmethod
587
- def new_chat(thinking_btn_state, state_value):
588
- if not state_value.get("conversation_id"): # Check if key exists
589
- # If already on a new chat (no ID), do nothing
590
- return gr.skip()
591
-
592
- # Reset conversation ID and potentially thinking state
593
- state_value["conversation_id"] = ""
594
- thinking_btn_state["enable_thinking"] = True # Reset thinking state if used
595
-
596
- # Prepare default settings for the new chat
597
- new_chat_settings = DEFAULT_SETTINGS.copy()
598
- if AVAILABLE_MODEL_OPTIONS and not new_chat_settings.get('model'):
599
- new_chat_settings['model'] = AVAILABLE_MODEL_OPTIONS[0]['value']
600
-
601
-
602
- # Update UI: clear chatbot, select no active conversation, reset settings form
603
- return gr.update(active_key=None), \
604
- gr.update(value=None), \
605
- gr.update(value=new_chat_settings), \
606
- gr.update(value=thinking_btn_state), \
607
- gr.update(value=state_value)
608
-
609
- @staticmethod
610
- def select_conversation(thinking_btn_state_value, state_value,
611
- e: gr.EventData):
612
- active_key = e._data["payload"][0]
613
- current_id = state_value.get("conversation_id")
614
-
615
- if current_id == active_key or not active_key or (
616
- active_key not in state_value.get("conversation_contexts", {})):
617
- print(f"Skipping conversation selection: current={current_id}, target={active_key}")
618
- return gr.skip() # No change or invalid key
619
-
620
- print(f"Switching conversation from '{current_id}' to '{active_key}'")
621
- state_value["conversation_id"] = active_key
622
- context = state_value["conversation_contexts"][active_key]
623
-
624
- # Restore thinking state and settings from the selected conversation
625
- thinking_btn_state_value["enable_thinking"] = context.get("enable_thinking", True) # Default to True if missing
626
- restored_settings = context.get("settings", DEFAULT_SETTINGS.copy())
627
-
628
- # Ensure the model in settings is still valid/loaded
629
- if restored_settings.get('model') not in loaded_models:
630
- print(f"Warning: Model '{restored_settings.get('model')}' in selected conversation is no longer loaded. Resetting to default.")
631
- restored_settings['model'] = DEFAULT_SETTINGS.get('model') # Use current default
632
-
633
- # Update UI components
634
- return gr.update(active_key=active_key), \
635
- gr.update(value=context.get("history", [])), \
636
- gr.update(value=restored_settings), \
637
- gr.update(value=thinking_btn_state_value), \
638
- gr.update(value=state_value) # Update the main state
639
-
640
-
641
- @staticmethod
642
- def click_conversation_menu(state_value, e: gr.EventData):
643
- payload = e._data["payload"]
644
- if not payload or len(payload) < 2:
645
- print("Warning: Invalid payload for conversation menu click.")
646
- return gr.skip()
647
-
648
- conversation_id = payload[0].get("key")
649
- operation = payload[1].get("key")
650
-
651
- if not conversation_id or not operation:
652
- print("Warning: Missing key or operation in conversation menu click.")
653
- return gr.skip()
654
-
655
- if operation == "delete":
656
- print(f"Deleting conversation: {conversation_id}")
657
- if conversation_id in state_value.get("conversation_contexts", {}):
658
- del state_value["conversation_contexts"][conversation_id]
659
-
660
- state_value["conversations"] = [
661
- item for item in state_value.get("conversations", [])
662
- if item.get("key") != conversation_id
663
- ]
664
-
665
- # If the deleted conversation was the active one, clear the chat view
666
- if state_value.get("conversation_id") == conversation_id:
667
- state_value["conversation_id"] = ""
668
- # Prepare default settings for the now empty view
669
- new_chat_settings = DEFAULT_SETTINGS.copy()
670
- if AVAILABLE_MODEL_OPTIONS and not new_chat_settings.get('model'):
671
- new_chat_settings['model'] = AVAILABLE_MODEL_OPTIONS[0]['value']
672
-
673
- return gr.update(
674
- items=state_value["conversations"],
675
- active_key=None # No active key
676
- ), gr.update(value=None), gr.update(value=new_chat_settings), gr.update(value=state_value) # Added settings update
677
- else:
678
- # Just update the list of conversations, keep the current view
679
- return gr.update(
680
- items=state_value["conversations"]
681
- ), gr.skip(), gr.skip(), gr.update(value=state_value) # Skip chatbot/settings update
682
- # Add other operations like 'rename' here if needed
683
- # elif operation == "rename":
684
- # ... implementation ...
685
-
686
- return gr.skip() # Default skip if operation not handled
687
-
688
- @staticmethod
689
- def toggle_settings_header(settings_header_state_value):
690
- settings_header_state_value[
691
- "open"] = not settings_header_state_value.get("open", False) # Default to False if key missing
692
- return gr.update(value=settings_header_state_value)
693
-
694
- @staticmethod
695
- def clear_conversation_history(state_value):
696
- conversation_id = state_value.get("conversation_id")
697
- if not conversation_id or conversation_id not in state_value.get("conversation_contexts", {}):
698
- print("Skipping clear history: No active or valid conversation.")
699
- return gr.skip() # No active conversation
700
-
701
- print(f"Clearing history for conversation: {conversation_id}")
702
- state_value["conversation_contexts"][conversation_id]["history"] = []
703
-
704
- # Update chatbot display and the state
705
- return gr.update(value=None), gr.update(value=state_value)
706
-
707
- @staticmethod
708
- def update_browser_state(state_value):
709
- # Only save the necessary parts to browser state
710
- return gr.update(value=dict(
711
- conversations=state_value.get("conversations", []),
712
- conversation_contexts=state_value.get("conversation_contexts", {})
713
- # Do not save the active conversation_id itself, it's transient UI state
714
- ))
715
-
716
- @staticmethod
717
- def apply_browser_state(browser_state_value, state_value):
718
- if not browser_state_value: # Handle initial load where state might be null/empty
719
- print("No browser state found to apply.")
720
- # Initialize state if empty
721
- if not state_value.get("conversations"):
722
- state_value["conversations"] = []
723
- if not state_value.get("conversation_contexts"):
724
- state_value["conversation_contexts"] = {}
725
- state_value["conversation_id"] = "" # Ensure no active conversation on fresh load
726
- # Prepare default settings for the initial view
727
- initial_settings = DEFAULT_SETTINGS.copy()
728
- if AVAILABLE_MODEL_OPTIONS and not initial_settings.get('model'):
729
- initial_settings['model'] = AVAILABLE_MODEL_OPTIONS[0]['value']
730
-
731
- return gr.update(items=[]), gr.update(value=None), gr.update(value=initial_settings), gr.update(value=state_value)
732
-
733
-
734
- print("Applying browser state...")
735
- # Basic validation: check if keys exist and have expected types (list/dict)
736
- loaded_conversations = browser_state_value.get("conversations")
737
- loaded_contexts = browser_state_value.get("conversation_contexts")
738
-
739
- if isinstance(loaded_conversations, list) and isinstance(loaded_contexts, dict):
740
- state_value["conversations"] = loaded_conversations
741
- state_value["conversation_contexts"] = loaded_contexts
742
- state_value["conversation_id"] = "" # Reset active conversation on load
743
-
744
- # Prepare default settings for the initial view after loading state
745
- initial_settings = DEFAULT_SETTINGS.copy()
746
- if AVAILABLE_MODEL_OPTIONS and not initial_settings.get('model'):
747
- initial_settings['model'] = AVAILABLE_MODEL_OPTIONS[0]['value']
748
-
749
-
750
- # Update UI based on loaded state
751
- return gr.update(items=loaded_conversations, active_key=None), \
752
- gr.update(value=None), \
753
- gr.update(value=initial_settings), \
754
- gr.update(value=state_value)
755
- else:
756
- print("Warning: Invalid browser state format. Ignoring.")
757
- # Initialize state as if no browser state was found
758
- state_value["conversations"] = []
759
- state_value["conversation_contexts"] = {}
760
- state_value["conversation_id"] = ""
761
- initial_settings = DEFAULT_SETTINGS.copy()
762
- if AVAILABLE_MODEL_OPTIONS and not initial_settings.get('model'):
763
- initial_settings['model'] = AVAILABLE_MODEL_OPTIONS[0]['value']
764
-
765
- return gr.update(items=[]), gr.update(value=None), gr.update(value=initial_settings), gr.update(value=state_value)
766
-
767
-
768
- # --- UI Definition ---
769
- css = """
770
- /* ... (keep existing CSS) ... */
771
- .gradio-container {
772
- padding: 0 !important;
773
- }
774
- .gradio-container > main.fillable {
775
- padding: 0 !important;
776
- }
777
- #chatbot {
778
- height: calc(100vh - 21px - 16px); /* Adjust if header/footer height changes */
779
- max-height: 1500px;
780
- }
781
- #chatbot .chatbot-conversations {
782
- height: 100vh; /* Full height */
783
- background-color: var(--ms-gr-ant-color-bg-layout);
784
- padding-left: 4px;
785
- padding-right: 4px;
786
- display: flex; /* Use flexbox for vertical layout */
787
- flex-direction: column; /* Stack children vertically */
788
- }
789
- #chatbot .chatbot-conversations .chatbot-conversations-list {
790
- padding-left: 0;
791
- padding-right: 0;
792
- flex-grow: 1; /* Allow list to take remaining space */
793
- overflow-y: auto; /* Add scroll if list is long */
794
- }
795
- #chatbot .chatbot-chat {
796
- padding: 32px;
797
- padding-bottom: 0;
798
- height: 100%;
799
- display: flex; /* Use flexbox */
800
- flex-direction: column; /* Stack chat messages and input vertically */
801
- }
802
- @media (max-width: 768px) {
803
- #chatbot .chatbot-chat {
804
- padding: 16px; /* Add some padding on mobile */
805
- padding-bottom: 0;
806
- }
807
- #chatbot .chatbot-conversations {
808
- /* Consider hiding conversation list or making it a drawer on mobile */
809
- }
810
- }
811
- #chatbot .chatbot-chat .chatbot-chat-messages {
812
- flex: 1; /* Allow chat messages to take available space */
813
- overflow-y: auto; /* Add scroll to messages */
814
- }
815
- #chatbot .setting-form-thinking-budget {
816
- /* Keep or remove based on whether thinking budget is still relevant */
817
- /* display: none; /* Example: Hide if not used */
818
- }
819
- /* Style for disabled input */
820
- #input-sender textarea:disabled {
821
- background-color: var(--ms-gr-ant-color-bg-container-disabled);
822
- cursor: not-allowed;
823
- }
824
- """
825
-
826
- # Removed model_options_map_json and the JS function, as options are handled in Python now
827
-
828
- with gr.Blocks(css=css, fill_width=True) as demo: # Removed js=js
829
- # Initial state structure
830
- state = gr.State({
831
- "conversation_contexts": {},
832
- "conversations": [],
833
- "conversation_id": "",
834
- })
835
-
836
- with ms.Application(), antdx.XProvider(
837
- theme=DEFAULT_THEME, locale=DEFAULT_LOCALE), ms.AutoLoading():
838
- with antd.Row(gutter=[0, 0], wrap=False, elem_id="chatbot"): # Use gutter 0 for closer columns
839
- # Left Column
840
- with antd.Col(md=dict(flex="0 0 260px", span=0), # Hide on smaller screens (md breakpoint)
841
- xs=dict(span=0), # Explicitly hide on extra small
842
- sm=dict(span=24, order=1, flex="0 0 260px"), # Show on small screens, potentially adjust layout/order
843
- # Consider using a collapsible drawer for mobile instead
844
- elem_classes="chatbot-conversations-col" # Add class for potential styling
845
- ):
846
- with ms.Div(elem_classes="chatbot-conversations"): # This div now uses flex column from CSS
847
- with antd.Flex(vertical=True,
848
- gap="small",
849
- # Removed elem_style=dict(height="100%") - parent div controls height
850
- ):
851
- # Logo
852
- Logo()
853
-
854
- # New Conversation Button
855
- with antd.Button(value=None,
856
- color="primary",
857
- variant="filled",
858
- block=True) as add_conversation_btn:
859
- ms.Text(get_text("New Conversation", "新建对话"))
860
- with ms.Slot("icon"):
861
- antd.Icon("PlusOutlined")
862
-
863
- # Conversations List
864
- with antdx.Conversations(
865
- elem_classes="chatbot-conversations-list", # Takes remaining space
866
- active_key="", # Start with no active key
867
- items=[] # Initial items empty, loaded by state
868
- ) as conversations:
869
- # Keep menu items definition
870
- with ms.Slot('menu.items'):
871
- with antd.Menu.Item(
872
- label="Delete", key="delete",
873
- danger=True
874
- ) as conversation_delete_menu_item:
875
- with ms.Slot("icon"):
876
- antd.Icon("DeleteOutlined")
877
- # Right Column
878
- with antd.Col(flex=1, # Takes remaining horizontal space
879
- elem_style=dict(height="100%"), # Ensure it fills vertically
880
- md=dict(span=24, order=0), # Adjust order for mobile if left col shown
881
- xs=dict(span=24, order=0),
882
- sm=dict(order=0)
883
- ):
884
- with antd.Flex(vertical=True,
885
- gap="small", # Gap between chatbot and sender
886
- elem_classes="chatbot-chat"): # This flex controls vertical layout of chat+input
887
- # Chatbot Display Area
888
- chatbot = pro.Chatbot(elem_classes="chatbot-chat-messages", # Takes flexible space
889
- # height=0, # Let flexbox control height
890
- value = None, # Initial value empty, loaded by state
891
- welcome_config=welcome_config(),
892
- user_config=user_config(),
893
- bot_config=bot_config())
894
-
895
- # Input Area (Sender)
896
- with antdx.Suggestion(
897
- items=DEFAULT_SUGGESTIONS,
898
- should_trigger="""(e, { onTrigger, onKeyDown }) => {
899
- // Keep existing JS logic for suggestions
900
- switch(e.key) {
901
- case '/': onTrigger(); break;
902
- case 'ArrowRight': case 'ArrowLeft': case 'ArrowUp': case 'ArrowDown': break;
903
- default: onTrigger(false);
904
- }
905
- onKeyDown(e);
906
- }""") as suggestion:
907
- with ms.Slot("children"):
908
- # Use elem_id for easier targeting if needed
909
- with antdx.Sender(elem_id="input-sender",
910
- placeholder=get_text(
911
- "Enter \"/\" to get suggestions, Shift+Enter for newline",
912
- "输入 \"/\" 获取提示,Shift+Enter 换行"),
913
- # interactive=True # Default is True
914
- ) as input:
915
- with ms.Slot("header"):
916
- # Pass AVAILABLE_MODEL_OPTIONS to SettingsHeader
917
- settings_header_state, settings_form = SettingsHeader(
918
- model_options=AVAILABLE_MODEL_OPTIONS, # Pass available options
919
- default_settings=DEFAULT_SETTINGS # Pass defaults
920
- )
921
- with ms.Slot("prefix"):
922
- with antd.Flex(
923
- gap=4,
924
- wrap=True, # Allow wrapping on small screens
925
- elem_style=dict(maxWidth='80vw') # Adjust max width
926
- ):
927
- with antd.Button(
928
- value=None, type="text"
929
- ) as setting_btn:
930
- with ms.Slot("icon"): antd.Icon("SettingOutlined")
931
- with antd.Button(
932
- value=None, type="text"
933
- ) as clear_btn:
934
- with ms.Slot("icon"): antd.Icon("ClearOutlined")
935
- # Keep ThinkingButton if UI uses it, otherwise remove
936
- thinking_btn_state = ThinkingButton()
937
-
938
- # --- Event Handlers ---
939
-
940
- # Browser State Handler (if enabled)
941
- if save_history:
942
- browser_state = gr.BrowserState(
943
- # Define the structure expected from the browser
944
- value={ "conversations": [], "conversation_contexts": {} },
945
- storage_key="vi_legal_chat_demo_storage" # Use a unique key
946
  )
947
- # When Python state changes, update the browser state
948
- state.change(fn=Gradio_Events.update_browser_state,
949
- inputs=[state],
950
- outputs=[browser_state],
951
- queue=False) # Run immediately
952
-
953
- # On page load, apply browser state to Python state and UI
954
- # Note: Ensure outputs match what apply_browser_state returns
955
- demo.load(fn=Gradio_Events.apply_browser_state,
956
- inputs=[browser_state, state],
957
- outputs=[conversations, chatbot, settings_form, state], # Outputs to update UI
958
- queue=False) # Run immediately on load
959
- elif not loaded_models:
960
- # If history saving is off AND no models loaded, show a message
961
- def show_no_model_warning():
962
- gr.Warning("No models were loaded successfully. The application functionality will be limited.")
963
- # You could also update a specific Gradio component to show the error
964
- demo.load(fn=show_no_model_warning, inputs=[], outputs=[])
965
-
966
 
967
- # Conversations Handler
968
- add_conversation_btn.click(fn=Gradio_Events.new_chat,
969
- inputs=[thinking_btn_state, state],
970
- outputs=[ # Match return order of new_chat
971
- conversations, chatbot, settings_form,
972
- thinking_btn_state, state
973
- ])
974
- conversations.active_change(fn=Gradio_Events.select_conversation,
975
- inputs=[thinking_btn_state, state],
976
- outputs=[ # Match return order of select_conversation
977
- conversations, chatbot, settings_form,
978
- thinking_btn_state, state
979
- ])
980
- conversations.menu_click(fn=Gradio_Events.click_conversation_menu,
981
- inputs=[state],
982
- outputs=[ # Match return order of click_conversation_menu
983
- conversations, chatbot, settings_form, state
984
- ],
985
- ) # queue=False ? Might be okay
986
-
987
- # Chatbot Handler
988
- chatbot.welcome_prompt_select(fn=Gradio_Events.apply_prompt,
989
- outputs=[input]) # Update input field
990
-
991
- # Use _js counterpart for direct manipulation if needed, otherwise rely on state change
992
- chatbot.delete(fn=Gradio_Events.delete_message,
993
- inputs=[state],
994
- outputs=[state]) # Only update state, UI will react
995
- chatbot.edit(fn=Gradio_Events.edit_message,
996
- inputs=[state, chatbot], # Pass chatbot value for content
997
- outputs=[state]) # Only update state
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
998
 
999
- # Regenerate uses the standard submit flow after trimming history
1000
- regenerating_event = chatbot.retry(
1001
- fn=Gradio_Events.regenerate_message,
1002
- inputs=[settings_form, thinking_btn_state, state],
1003
- outputs=[ # Outputs from preprocess, submit, and postprocess combined
1004
- input, conversations, add_conversation_btn, clear_btn,
1005
- conversation_delete_menu_item, setting_btn, chatbot, state
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1006
  ],
1007
- # Ensure outputs match the combined yields of the handler chain
1008
- )
1009
-
1010
 
1011
- # Input Handler
1012
- submit_event = input.submit(
1013
- fn=Gradio_Events.add_message,
1014
- inputs=[input,
1015
- settings_form, thinking_btn_state, state],
1016
- outputs=[ # Outputs from preprocess, submit, and postprocess combined
1017
- input, conversations, add_conversation_btn, clear_btn,
1018
- conversation_delete_menu_item, setting_btn, chatbot, state
1019
- ]) # Ensure outputs match yields
1020
 
1021
- # Cancel needs to unlock UI elements modified by preprocess
1022
- input.cancel(fn=Gradio_Events.cancel,
1023
- inputs=[state],
1024
- outputs=[ # Outputs matching postprocess_submit return dict keys
1025
- input, conversation_delete_menu_item, clear_btn,
1026
- conversations, add_conversation_btn, setting_btn, chatbot, state
1027
- ],
1028
- cancels=[submit_event, regenerating_event], # Cancel ongoing submit/regen
1029
- queue=False) # Run immediately
1030
-
1031
- # Input Actions Handler
1032
- setting_btn.click(fn=Gradio_Events.toggle_settings_header,
1033
- inputs=[settings_header_state],
1034
- outputs=[settings_header_state])
1035
- clear_btn.click(fn=Gradio_Events.clear_conversation_history,
1036
- inputs=[state],
1037
- outputs=[chatbot, state]) # Update chatbot display and state
1038
- suggestion.select(fn=Gradio_Events.select_suggestion,
1039
- inputs=[input],
1040
- outputs=[input]) # Update input field
1041
-
1042
- # --- Launch ---
1043
  if __name__ == "__main__":
1044
- if not loaded_models:
1045
- print("\nWARNING: No models loaded. Gradio app will launch but may not be functional.\n")
1046
- # Optionally prevent launch entirely:
1047
- # print("Exiting because no models were loaded.")
1048
- # exit()
1049
-
1050
- print("Launching Gradio Interface...")
1051
- demo.queue(default_concurrency_limit=10, # Adjust concurrency based on your GPU/CPU resources
1052
- max_size=20).launch(ssr_mode=False, # Consider True if SEO or initial load speed is critical
1053
- # share=True, # Uncomment to create a public link (use with caution)
1054
- # server_name="0.0.0.0" # Uncomment to allow access from network
1055
- max_threads=40 # Gradio default
1056
- )
 
 
 
 
 
1
  import torch
2
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
3
+ import gradio as gr
 
 
 
 
 
 
4
 
5
# Select the compute device: prefer the GPU when one is available.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# --- Model / tokenizer handles ---
# All slots start as None; the load blocks below fill them in on success,
# so downstream code can truth-test them to see which models are usable.
model_1 = None
tokenizer_1 = None
model_2 = None
tokenizer_2 = None
# model_3 / tokenizer_3 (ViLawT5_RL) are intentionally no longer loaded.
model_4 = None
tokenizer_4 = None
15
 
16
  # Sử dụng try-except để xử lý lỗi nếu không tải được mô hình
17
  try:
 
19
  print(f"Loading model: {model_name_1}...")
20
  model_1 = AutoModelForSeq2SeqLM.from_pretrained(model_name_1).to(device)
21
  tokenizer_1 = AutoTokenizer.from_pretrained(model_name_1)
 
22
  print(f"Model {model_name_1} loaded successfully.")
23
  except Exception as e:
24
  print(f"Error loading model {model_name_1}: {e}")
 
28
  print(f"Loading model: {model_name_2}...")
29
  model_2 = AutoModelForSeq2SeqLM.from_pretrained(model_name_2).to(device)
30
  tokenizer_2 = AutoTokenizer.from_pretrained(model_name_2)
 
31
  print(f"Model {model_name_2} loaded successfully.")
32
  except Exception as e:
33
  print(f"Error loading model {model_name_2}: {e}")
34
 
35
  # Bỏ qua việc tải model_3 (ViLawT5_RL)
36
+ # ... (phần code tải model_3 bị comment như cũ) ...
37
 
38
  try:
39
  model_name_4 = "sunbv56/V-LegalQA"
40
  print(f"Loading model: {model_name_4}...")
41
  model_4 = AutoModelForSeq2SeqLM.from_pretrained(model_name_4).to(device)
42
  tokenizer_4 = AutoTokenizer.from_pretrained(model_name_4)
 
43
  print(f"Model {model_name_4} loaded successfully.")
44
  except Exception as e:
45
  print(f"Error loading model {model_name_4}: {e}")
46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
 
48
# --- Response generation ---
def chatbot_response(question, model_choice, max_new_tokens, temperature, top_k, top_p, repetition_penalty, use_early_stopping, use_do_sample):
    """Generate an answer for ``question`` with the selected seq2seq model.

    Parameters mirror the Gradio controls (model name, generation knobs).
    Returns the decoded answer string, or a human-readable ``"Error: ..."``
    message when the model is unavailable or generation fails — the UI
    displays whatever string is returned.
    """
    # Map the UI choice to its (model, tokenizer) pair; entries are None when
    # loading failed at startup. ViLawT5_RL (model_3) is intentionally omitted.
    registry = {
        "ViLawT5": (model_1, tokenizer_1),
        "ViT5": (model_2, tokenizer_2),
        "V-LegalQA": (model_4, tokenizer_4),
    }
    model, tokenizer = registry.get(model_choice, (None, None))

    if model is None or tokenizer is None:
        # Build the list of models that actually loaded so the error is actionable.
        available_models = [name for name, (mdl, tok) in registry.items() if mdl and tok]
        if not available_models:
            return "Error: No models were loaded successfully. Please check the logs."
        if model_choice not in available_models:
            return f"Error: Model '{model_choice}' was not loaded successfully or is invalid. Available models: {', '.join(available_models)}"
        # Defensive fallback: valid choice but missing model/tokenizer (unexpected).
        return f"Error: An unexpected issue occurred with model '{model_choice}'. Please check the logs."

    print(f"Generating response using {model_choice} with params: max_new_tokens={max_new_tokens}, temp={temperature}, top_k={top_k}, top_p={top_p}, rep_penalty={repetition_penalty}, early_stop={use_early_stopping}, do_sample={use_do_sample}")

    # Prompt prefix the models appear to have been fine-tuned with —
    # TODO(review): confirm against the training code.
    input_text = f"câu_hỏi: {question}"
    try:
        data = tokenizer(
            input_text,
            return_tensors="pt",
            truncation=True,
            return_attention_mask=True,
            add_special_tokens=True,
            padding="max_length",
            max_length=256,  # consider raising if questions/contexts are longer
        )

        input_ids = data.input_ids.to(device)
        attention_mask = data.attention_mask.to(device)

        # FIX: only forward sampling parameters when sampling is enabled.
        # Under greedy search (do_sample=False) transformers ignores
        # temperature/top_k/top_p and emits warnings, which confuses users
        # into thinking the sliders had an effect.
        gen_kwargs = dict(
            attention_mask=attention_mask,
            max_new_tokens=int(max_new_tokens),
            early_stopping=use_early_stopping,  # only meaningful with beam search; harmless here
            do_sample=use_do_sample,
            repetition_penalty=float(repetition_penalty),
        )
        if use_do_sample:
            gen_kwargs.update(
                temperature=float(temperature),
                top_k=int(top_k),
                top_p=float(top_p),
            )

        # Inference only — no gradients needed.
        with torch.no_grad():
            outputs = model.generate(input_ids, **gen_kwargs)

        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        print(f"Raw output shape: {outputs[0].shape}")  # extra shape log for debugging
        print(f"Decoded response: {response}")
        return response
    except Exception as e:
        print(f"Error during generation: {e}")
        # Print the full traceback for debugging.
        import traceback
        traceback.print_exc()
        return f"An error occurred during response generation: {e}"
126
+
127
# --- List of models that loaded successfully (ViLawT5_RL excluded) ---
_model_candidates = [
    ("ViLawT5", model_1, tokenizer_1),
    ("ViT5", model_2, tokenizer_2),
    ("V-LegalQA", model_4, tokenizer_4),
]
loaded_models = [name for name, mdl, tok in _model_candidates if mdl and tok]

# Default dropdown selection: prefer V-LegalQA, fall back to the first
# loaded model, or a placeholder string when nothing loaded at all.
if "V-LegalQA" in loaded_models:
    default_model = "V-LegalQA"
elif loaded_models:
    default_model = loaded_models[0]
else:
    default_model = "No models available"
135
+
136
+ # --- Tạo giao diện với Gradio ---
137
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
138
+ gr.Markdown(
139
+ """
140
+ # 🤖 AI Chatbot Pháp luật Việt Nam (Demo)
141
+ Chọn mô hình và đặt câu hỏi liên quan đến pháp luật.
142
+ Nhấn **Shift + Enter** để gửi câu hỏi, **Enter** để xuống dòng.
143
+ """
144
+ )
145
+
146
+ with gr.Row():
147
+ model_choice = gr.Dropdown(
148
+ choices=loaded_models,
149
+ label="Chọn Mô hình AI",
150
+ value=default_model,
151
+ interactive=bool(loaded_models) # Chỉ cho phép tương tác nếu có model
152
+ )
153
 
154
+ # Đảm bảo 'lines' >= 2 để Shift+Enter tác dụng rõ ràng
155
+ question_input = gr.Textbox(
156
+ label="Nhập câu hỏi của bạn (Shift+Enter để gửi)",
157
+ placeholder="Ví dụ: Thế nào là tội cố ý gây thương tích?",
158
+ lines=3, # Giữ nguyên hoặc tăng nếu muốn ô nhập cao hơn
159
+ # scale=7 # Ví dụ: làm cho ô nhập rộng hơn nếu cần
160
+ )
161
+
162
+ # --- Cập nhật giá trị mặc định trong Accordion ---
163
+ with gr.Accordion("Tùy chọn Nâng cao (Generation Parameters)", open=False):
164
+ with gr.Row():
165
+ early_stopping_checkbox = gr.Checkbox(label="Enable Early Stopping", value=False, info="Dừng sớm khi gặp token EOS.")
166
+ do_sample_checkbox = gr.Checkbox(label="Enable Sampling (do_sample)", value=False, info="Sử dụng sampling (cần thiết cho temperature, top_k, top_p). Tắt nếu muốn greedy search.")
167
+ with gr.Row():
168
+ max_new_tokens_slider = gr.Slider(minimum=10, maximum=1024, value=512, step=10, label="Max New Tokens", info="Số lượng token tối đa được sinh ra.")
169
+ temperature_slider = gr.Slider(minimum=0.1, maximum=2.0, value=1.0, step=0.1, label="Temperature", info="Độ 'sáng tạo' của câu trả lời (thấp hơn = bảo thủ hơn). Cần bật do_sample.")
170
+ with gr.Row():
171
+ top_k_slider = gr.Slider(minimum=1, maximum=200, value=50, step=1, label="Top-K", info="Chỉ xem xét K token có xác suất cao nhất. Cần bật do_sample.")
172
+ top_p_slider = gr.Slider(minimum=0.0, maximum=1.0, value=1.0, step=0.01, label="Top-P (Nucleus Sampling)", info="Chỉ xem xét các token có tổng xác suất >= P. Cần bật do_sample.")
173
+ repetition_penalty_slider = gr.Slider(minimum=1.0, maximum=3.0, value=1.0, step=0.1, label="Repetition Penalty", info="Phạt các token đã xuất hiện (cao hơn = ít lặp lại hơn).")
174
+
175
+
176
+ response_output = gr.Textbox(label="Phản hồi của Chatbot", lines=5, interactive=False)
177
+
178
+ # Nút gửi vẫn giữ lại phòng trường hợp người dùng thích click hơn
179
+ submit_btn = gr.Button("Gửi câu hỏi", variant="primary")
180
+
181
+ # --- THAY ĐỔI QUAN TRỌNG ---
182
+ # Tạo một list các inputs để dùng chung cho cả nút bấm và nhấn Enter
183
+ chatbot_inputs = [
184
+ question_input,
185
+ model_choice,
186
+ max_new_tokens_slider,
187
+ temperature_slider,
188
+ top_k_slider,
189
+ top_p_slider,
190
+ repetition_penalty_slider,
191
+ early_stopping_checkbox,
192
+ do_sample_checkbox
193
+ ]
194
+
195
+ # 1. Gửi khi nhấn nút
196
+ submit_btn.click(
197
+ fn=chatbot_response,
198
+ inputs=chatbot_inputs,
199
+ outputs=response_output
200
+ )
201
+
202
+ # 2. Gửi khi nhấn Enter trong Textbox question_input
203
+ # Shift+Enter sẽ tự động xuống dòng (hành vi mặc định khi lines > 1)
204
+ question_input.submit(
205
+ fn=chatbot_response,
206
+ inputs=chatbot_inputs,
207
+ outputs=response_output
208
+ )
209
+ # -----------------------------
210
+
211
+ gr.Examples(
212
+ examples=[
213
+ ["Hợp đồng vô hiệu khi nào?", "V-LegalQA"],
214
+ ["Quyền và nghĩa vụ của người lao động là gì?", "ViT5"],
215
+ ["Người dưới 18 tuổi có được ký hợp đồng lao động không?\nThời gian làm việc tối đa là bao lâu?", "V-LegalQA"] # Ví dụ multi-line
216
  ],
217
+ inputs=[question_input, model_choice]
218
+ )
 
219
 
 
 
 
 
 
 
 
 
 
220
 
221
# --- Run Gradio ---
def _main():
    """Entry point: warn when no model loaded, then start the Gradio app."""
    if not loaded_models:
        print("WARNING: No models were loaded successfully. The application might not function correctly.")
        # Consider also surfacing this inside the Blocks UI, e.g. gr.Info(...).
    # Set share=True to create a temporary public link.
    demo.launch(debug=True, share=False)


if __name__ == "__main__":
    _main()