Spaces:
Sleeping
Sleeping
Commit ·
4cfe4fa
1
Parent(s): 2eebc40
Deploy CQL Chatbot (without large files)
Browse files
- .gitignore +53 -0
- Agent_Diffusion/inference_v3.py +198 -0
- Agent_Diffusion/stable_diffusion.py +174 -0
- Agent_Diffusion/train_unet.py +331 -0
- Agent_Diffusion/train_vae.py +159 -0
- Agent_Diffusion/unet-mini.safetensors +3 -0
- Agent_Diffusion/vae-finetuned.safetensors +3 -0
- Conservative Q-learning/Agen1_training.py +114 -0
- Conservative Q-learning/cql_agent.py +242 -0
- Conservative Q-learning/cql_utils.py +82 -0
- Conservative Q-learning/saved_agent_1/cql_model.pth +3 -0
- Conservative Q-learning/saved_agent_1/normalizer.pkl +3 -0
- app.py +432 -0
- chatbot_engine.py +238 -0
- communication_agent.py +96 -0
- config.py +51 -0
- drawing_agent.py +208 -0
- memory_manager.py +172 -0
- requirements.txt +12 -0
.gitignore
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Python cache
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
*.so
|
| 6 |
+
.Python
|
| 7 |
+
*.egg-info/
|
| 8 |
+
dist/
|
| 9 |
+
build/
|
| 10 |
+
|
| 11 |
+
# Virtual Environment
|
| 12 |
+
.venv/
|
| 13 |
+
venv/
|
| 14 |
+
ENV/
|
| 15 |
+
|
| 16 |
+
# IDE
|
| 17 |
+
.vscode/
|
| 18 |
+
.idea/
|
| 19 |
+
*.swp
|
| 20 |
+
*.swo
|
| 21 |
+
|
| 22 |
+
# Environment variables (IMPORTANT: Don't upload API keys!)
|
| 23 |
+
.env
|
| 24 |
+
|
| 25 |
+
# Conversation history (user data)
|
| 26 |
+
conversation_history/
|
| 27 |
+
|
| 28 |
+
# Generated images (will be created on server)
|
| 29 |
+
generated_images/
|
| 30 |
+
|
| 31 |
+
# Model cache
|
| 32 |
+
.cache/
|
| 33 |
+
|
| 34 |
+
# Logs
|
| 35 |
+
*.log
|
| 36 |
+
|
| 37 |
+
# OS
|
| 38 |
+
.DS_Store
|
| 39 |
+
Thumbs.db
|
| 40 |
+
|
| 41 |
+
# Large model files (upload separately to HF)
|
| 42 |
+
# Uncomment if models are too large for git
|
| 43 |
+
# Agent_Diffusion/*.safetensors
|
| 44 |
+
# Conservative Q-learning/saved_agent_1/*.pth
|
| 45 |
+
# Conservative Q-learning/saved_agent_1/*.pkl
|
| 46 |
+
|
| 47 |
+
# Backup files
|
| 48 |
+
*_backup.py
|
| 49 |
+
*_gpt2.py
|
| 50 |
+
communication_agent_gemini_backup.py
|
| 51 |
+
|
| 52 |
+
# Antigravity artifacts
|
| 53 |
+
.gemini/
|
Agent_Diffusion/inference_v3.py
ADDED
|
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
import torch
from transformers import Swin2SRImageProcessor, Swin2SRForImageSuperResolution
from diffusers import StableDiffusionPipeline
import numpy as np
from PIL import Image, ImageEnhance, ImageOps
import random
# from safetensors.torch import load_file

from stable_diffusion import MiniDiffusionPipeline

# --- Configuration ---
# Previously used prompts, kept for reference (trailing numbers are seeds / run ids):
#PROMPT = "beautiful woman with long braided hair, wearing a scarf, soft smile, looking down, detailed shading" #725562173
#PROMPT = "attractive woman, big lips, mouth slightly open, heavy makeup" #v5
#PROMPT = "The man is young and has sharp jawline, narrow eyes, thick eyebrows, and short black hair." #10, 11
#PROMPT = "She is elderly with deep smile lines, small eyes, and short curly gray hair." #13
#PROMPT = "This man is old and smiling, with gray beard and big nose"


PROMPT = "a baby"            # text prompt fed to the diffusion pipeline
SAVE_IMAGE_PATH = "./15.png" # where the upscaled result is written

# NOTE(review): despite the .safetensors extension, these files are read with
# torch.load() in main() — that only works if they were written by torch.save().
# Confirm, or switch to safetensors.torch.load_file (import commented out above).
UNET_SAFE_PATH = "./unet-mini.safetensors"
VAE_SAFE_PATH = "./vae-finetuned.safetensors"

BASE_MODEL_ID = "runwayml/stable-diffusion-v1-5"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


# Override of MiniDiffusionPipeline's default UNet widths; must match the
# architecture the checkpoints above were trained with.
TINY_UNET_CONFIG = {
    "unet_block_out_channels": (128, 256, 512),
}

# Swin2SR x4 super-resolution model, used as a post-processing upscaler.
MODEL_ID = "caidas/swin2SR-classical-sr-x4-64"
print(f"Đang load model {MODEL_ID} từ Hugging Face...")
processor = Swin2SRImageProcessor.from_pretrained(MODEL_ID)
model = Swin2SRForImageSuperResolution.from_pretrained(MODEL_ID)
print("Load model thành công!")

model = model.to(DEVICE)
|
| 41 |
+
|
| 42 |
+
def upscale_image_pipeline(pil_image, contrast=1.3, sharpen=1.5, target_size=(512, 512)):
    """Upscale a sketch with Swin2SR using a "canvas isolation" strategy.

    The image is pasted into the centre of a large white canvas before
    super-resolution so that the model's border artifacts land at the canvas
    edge — far away from the actual drawing — and are cropped away afterwards.

    Args:
        pil_image: source PIL image (grayscale or RGB).
        contrast: contrast-enhancement factor applied after upscaling.
        sharpen: sharpness-enhancement factor applied after upscaling.
        target_size: (height, width) of the returned image.

    Returns:
        The upscaled, post-processed PIL image.
    """
    # Graceful fallback when the super-resolution model is unavailable.
    if model is None or processor is None:
        return pil_image.resize(target_size)

    # Normalise the input to a 3-channel RGB numpy array.
    pixels = np.array(pil_image)
    if pixels.ndim == 2:
        pixels = cv2.cvtColor(pixels, cv2.COLOR_GRAY2RGB)

    src_h, src_w = pixels.shape[:2]

    # Build a large white canvas and paste the image dead-centre, pushing
    # the model's edge/reflection artifacts away from the real content.
    canvas_size = 256
    canvas = np.full((canvas_size, canvas_size, 3), 255, dtype=np.uint8)
    top = (canvas_size - src_h) // 2
    left = (canvas_size - src_w) // 2
    canvas[top:top + src_h, left:left + src_w] = pixels

    # Run the whole canvas through the x4 super-resolution model; any
    # boundary artifacts now sit at the canvas rim, not on the drawing.
    inputs = processor(Image.fromarray(canvas), return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        outputs = model(**inputs)

    upscaled = outputs.reconstruction.data.squeeze().float().cpu().clamp_(0, 1).numpy()
    upscaled = np.moveaxis(upscaled, 0, -1)
    upscaled = (upscaled * 255.0).round().astype(np.uint8)

    # Crop the real image back out of the enlarged canvas: input canvas 256
    # -> output canvas 1024, so every coordinate scales by 4.
    scale = 4
    y0 = top * scale
    x0 = left * scale
    out_h = src_h * scale
    out_w = src_w * scale
    result = upscaled[y0:y0 + out_h, x0:x0 + out_w]

    # Hard fix: force the outermost bottom row / right column to white to
    # erase any residual blur at the crop boundary (content is a sketch on
    # a white background, so this never touches real detail).
    result[-1:, :, :] = 255
    result[:, -1:, :] = 255

    # Guarantee the requested output size.
    if result.shape[:2] != target_size:
        result = cv2.resize(result, (target_size[1], target_size[0]), interpolation=cv2.INTER_LANCZOS4)

    # Post-processing: boost contrast first, then sharpness.
    boosted = ImageEnhance.Contrast(Image.fromarray(result)).enhance(contrast)
    return ImageEnhance.Sharpness(boosted).enhance(sharpen)
|
| 114 |
+
|
| 115 |
+
@torch.no_grad()
def main():
    """Load the fine-tuned UNet/VAE weights, generate one image for PROMPT,
    upscale it, and save both the upscaled and the raw output."""
    print("--- Bắt đầu quá trình Inference (từ Safetensors) ---")

    # --- Build the MiniDiffusionPipeline container (tokenizer, frozen text
    # encoder, randomly-initialised mini UNet, VAE, scheduler) ---
    print(f"Đang tải pipeline gốc từ {BASE_MODEL_ID}...")
    container = MiniDiffusionPipeline(
        base_model_id=BASE_MODEL_ID,
        device=DEVICE,
        config_overrides=TINY_UNET_CONFIG  # must match the trained checkpoint's architecture
    )

    # --- Load the trained weights ---

    # UNet weights.
    # NOTE(review): the file has a .safetensors extension but is read with
    # torch.load(); this only works if it was written by torch.save() —
    # confirm, or use safetensors.torch.load_file (import commented out at
    # the top of this file).
    print(f"Đang tải trọng số UNet từ {UNET_SAFE_PATH}...")
    try:
        unet_weights = torch.load(UNET_SAFE_PATH, map_location=DEVICE)
        container.unet.load_state_dict(unet_weights)
    except Exception as e:
        # A size mismatch here usually means the UNet config differs from training.
        print(f"LỖI: Không thể tải UNet state dict: {e}")
        print("Kiểm tra xem bạn đã bỏ chú thích 'config_overrides=TINY_UNET_CONFIG' chưa?")
        return

    # VAE weights (same torch.load caveat as above).
    print(f"Đang tải trọng số VAE từ {VAE_SAFE_PATH}...")
    try:
        vae_weights = torch.load(VAE_SAFE_PATH, map_location=DEVICE)
        container.vae.load_state_dict(vae_weights)
    except Exception as e:
        print(f"LỖI: Không thể tải VAE state dict: {e}")
        return

    # --- Assemble a diffusers StableDiffusionPipeline from the parts ---
    # fp16 on GPU, fp32 on CPU.
    torch_dtype = torch.float16 if DEVICE.startswith("cuda") else torch.float32

    print("Đang tạo pipeline inference...")
    inference_pipeline = StableDiffusionPipeline(
        unet=container.unet,
        vae=container.vae,
        text_encoder=container.text_encoder,
        tokenizer=container.tokenizer,
        scheduler=container.noise_scheduler,
        safety_checker=None,   # research use: NSFW filter disabled
        feature_extractor=None,
    ).to(DEVICE)

    if DEVICE.startswith("cuda"):
        inference_pipeline.to(dtype=torch_dtype)

    inference_pipeline.set_progress_bar_config(disable=False)

    # --- Generate ---
    print(f"\nĐang tạo ảnh cho prompt: '{PROMPT}'")
    current_seed = random.randint(0, 2**32 - 1)
    print(f"Seed hiện tại: {current_seed}")

    # Seeds of previously kept runs are listed in the trailing comment.
    generator = torch.Generator(device=DEVICE).manual_seed(current_seed) #725562173, 4169604779, 725562172, 3884820838, 1794046812, 1379970385

    image = inference_pipeline(
        prompt=PROMPT,
        num_inference_steps=50,
        generator=generator,
        guidance_scale=7.5
    ).images[0]

    # Upscale + post-process, then save.
    final_image = upscale_image_pipeline(image)
    final_image.save(SAVE_IMAGE_PATH)

    # --- Also save the raw (un-upscaled) output alongside ---
    image.save(SAVE_IMAGE_PATH.replace(".png", "_original.png"))

    # # --- Save image ---
    # image.save(SAVE_IMAGE_PATH)
    print(f"\n--- Hoàn thành! ---")
    print(f"Đã lưu ảnh tại: {SAVE_IMAGE_PATH}")

    # Best-effort preview; silently skipped on headless machines.
    try:
        image.show()
    except Exception:
        pass

if __name__ == "__main__":
    main()
|
Agent_Diffusion/stable_diffusion.py
ADDED
|
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from diffusers.models import UNet2DConditionModel, AutoencoderKL
|
| 3 |
+
from diffusers.schedulers import DDPMScheduler
|
| 4 |
+
from transformers import CLIPTextModel, CLIPTokenizer
|
| 5 |
+
from typing import Dict, Any, Optional
|
| 6 |
+
|
| 7 |
+
class MiniDiffusionPipeline:
    """Container assembling the parts of a small Stable-Diffusion-style model:
    a frozen CLIP tokenizer/text encoder, a fine-tunable VAE, a
    randomly-initialised "mini" UNet, and a DDPM noise scheduler.
    """

    # Default config: scheduler hyper-parameters, training knobs, and the
    # mini-UNet architecture. Any key can be overridden per-instance via the
    # `config_overrides` constructor argument.
    DEFAULT_CONFIG: Dict[str, Any] = {
        "beta_schedule": "scaled_linear",
        "beta_start": 0.00085,
        "beta_end": 0.0120,
        "num_train_timesteps": 1000,
        "prediction_type": "epsilon",
        "variance_type": "fixed_small",
        "clip_sample": False,
        "rescale_betas_zero_snr": False,
        "timestep_spacing": "leading",
        "lr": 1e-4,
        "optimizer": "AdamW",
        "scheduler": "cosine",
        "ema_decay": 0.9999,
        "latent_scale": 0.18215,
        "text_embed_dim": 768,
        "latent_channels": 4,
        "latent_downscale_factor": 8,

        # --- UNet-mini architecture ---
        "image_size": 128,
        "unet_block_out_channels": (256, 512, 1024),
        "unet_layers_per_block": 1,
        "unet_down_block_types": (
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "DownBlock2D",
        ),
        "unet_up_block_types": (
            "UpBlock2D",
            "CrossAttnUpBlock2D",
            "CrossAttnUpBlock2D",
        ),
        "unet_mid_block_type": "UNetMidBlock2DCrossAttn",
        "unet_attention_head_dim": 8,
    }

    def __init__(
        self,
        # FIX: default corrected from "stabilityai/stable-diffusion-v1-5"
        # (not a valid hub repo id — constructing with the default would fail
        # at download time) to "runwayml/stable-diffusion-v1-5", the id every
        # caller in this project passes explicitly.
        base_model_id: str = "runwayml/stable-diffusion-v1-5",
        vae_model_id: Optional[str] = None,
        device: str = "cpu",
        config_overrides: Optional[Dict[str, Any]] = None
    ):
        """Build and move all sub-modules to *device*.

        Args:
            base_model_id: HF hub id providing the tokenizer, text encoder
                and (by default) the VAE.
            vae_model_id: optional standalone VAE repo id (e.g.
                "stabilityai/sd-vae-ft-mse"); when None the base model's
                "vae" subfolder is used instead.
            device: torch device string ("cpu" / "cuda").
            config_overrides: partial dict merged over DEFAULT_CONFIG.
        """
        self.device = torch.device(device)

        # Overrides win over the class defaults.
        self.config = {**self.DEFAULT_CONFIG, **(config_overrides or {})}

        print(f"Đang tải Tokenizer và Text Encoder (đã đóng băng) từ {base_model_id}...")
        self.tokenizer = self._load_tokenizer(base_model_id)
        self.text_encoder = self._load_text_encoder(base_model_id)

        # Standalone VAE repos have no "vae" subfolder; full SD repos do.
        _vae_id = vae_model_id or base_model_id
        _vae_subfolder = "vae" if vae_model_id is None else None
        print(f"Đang tải VAE (để fine-tune) từ {_vae_id}...")
        self.vae = self._load_vae(_vae_id, _vae_subfolder)

        print("Khởi tạo UNet-mini (với trọng số ngẫu nhiên)...")
        self.unet = self._load_mini_unet()

        print("Khởi tạo Noise Scheduler...")
        self.noise_scheduler = self._load_noise_scheduler()

        print("\n--- MiniDiffusionPipeline đã sẵn sàng! ---")
        self.print_model_stats()

    def _load_tokenizer(self, model_id: str) -> CLIPTokenizer:
        """Load the CLIP tokenizer from the SD repo's "tokenizer" subfolder."""
        return CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")

    def _load_text_encoder(self, model_id: str) -> CLIPTextModel:
        """Load the CLIP text encoder, frozen (gradients disabled)."""
        model = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")
        model.to(self.device)
        model.requires_grad_(False)
        return model

    def _load_vae(self, model_id: str, subfolder: Optional[str]) -> AutoencoderKL:
        """Load the VAE; *subfolder* is "vae" for full SD repos, None for standalone VAEs."""
        if subfolder:
            model = AutoencoderKL.from_pretrained(model_id, subfolder=subfolder)
        else:
            model = AutoencoderKL.from_pretrained(model_id)
        model.to(self.device)
        return model

    def _load_mini_unet(self) -> UNet2DConditionModel:
        """Create the small UNet (random weights) from the instance config."""
        # Latent resolution = pixel resolution / VAE downscale factor.
        latent_size = self.config["image_size"] // self.config["latent_downscale_factor"]

        unet_config = {
            "sample_size": latent_size,
            "in_channels": self.config["latent_channels"],
            "out_channels": self.config["latent_channels"],
            "block_out_channels": self.config["unet_block_out_channels"],
            "layers_per_block": self.config["unet_layers_per_block"],
            "down_block_types": self.config["unet_down_block_types"],
            "up_block_types": self.config["unet_up_block_types"],
            "mid_block_type": self.config["unet_mid_block_type"],
            "cross_attention_dim": self.config["text_embed_dim"],
            "attention_head_dim": self.config["unet_attention_head_dim"],
        }

        model = UNet2DConditionModel(**unet_config)
        model.to(self.device)
        return model

    def _load_noise_scheduler(self) -> DDPMScheduler:
        """Build the DDPM scheduler; from_config ignores non-scheduler keys."""
        return DDPMScheduler.from_config(self.config)

    def print_model_stats(self):
        """Print trainable-parameter counts for the UNet and the VAE."""
        unet_params = sum(p.numel() for p in self.unet.parameters() if p.requires_grad)
        vae_params = sum(p.numel() for p in self.vae.parameters() if p.requires_grad)
        print(f" UNet-mini (để train): {unet_params / 1_000_000:.2f} triệu tham số")
        print(f" VAE (để fine-tune): {vae_params / 1_000_000:.2f} triệu tham số")

    def get_trainable_parameters(self) -> Dict[str, Any]:
        """Return parameter iterators of the trainable sub-modules."""
        return {
            "unet": self.unet.parameters(),
            "vae": self.vae.parameters()
        }
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
# --- SMOKE TEST ---
|
| 132 |
+
def _run_smoke_test():
    """Construct the pipeline three different ways (default VAE, custom
    MSE-fine-tuned VAE, tiny-UNet config override) to verify that every
    construction path works end to end."""
    print("--- Bắt đầu kiểm thử MiniDiffusionPipeline ---")

    # Pick the compute device; warn when falling back to CPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if device == "cpu":
        print("CẢNH BÁO: Không tìm thấy CUDA. Chạy trên CPU (sẽ chậm).")

    # --- Default construction (uses SD 1.5's own VAE) ---
    print("\n--- Tải mặc định ---")
    pipeline_1 = MiniDiffusionPipeline(
        base_model_id="runwayml/stable-diffusion-v1-5",
        device=device,
    )

    # --- Custom MSE-fine-tuned VAE ---
    print("\n--- Tải VAE-MSE tùy chỉnh ---")
    pipeline_2 = MiniDiffusionPipeline(
        base_model_id="runwayml/stable-diffusion-v1-5",
        vae_model_id="stabilityai/sd-vae-ft-mse",
        device=device,
    )

    # --- Config override: extra-small UNet plus a lower learning rate ---
    print("\n--- Ghi đè config (UNet siêu nhỏ) ---")
    pipeline_3 = MiniDiffusionPipeline(
        base_model_id="runwayml/stable-diffusion-v1-5",
        device=device,
        config_overrides={
            "unet_block_out_channels": (128, 256, 512),
            "lr": 5e-5,
        },
    )

    print("\n--- Kiểm thử thành công ---")
    print(f"Config LR của Pipeline 1: {pipeline_1.config['lr']}")
    print(f"Config LR của Pipeline 3: {pipeline_3.config['lr']}")


if __name__ == "__main__":
    _run_smoke_test()
|
Agent_Diffusion/train_unet.py
ADDED
|
@@ -0,0 +1,331 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
from torch.utils.data import DataLoader
|
| 4 |
+
from torch.optim import AdamW
|
| 5 |
+
from diffusers.optimization import get_scheduler
|
| 6 |
+
from diffusers import DDPMScheduler, StableDiffusionPipeline
|
| 7 |
+
from PIL import Image
|
| 8 |
+
import os
|
| 9 |
+
import time
|
| 10 |
+
import matplotlib.pyplot as plt
|
| 11 |
+
from tqdm import tqdm
|
| 12 |
+
import random
|
| 13 |
+
|
| 14 |
+
from stable_diffusion import MiniDiffusionPipeline
|
| 15 |
+
from dataset import SketchDataset
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
from torchmetrics.image.fid import FrechetInceptionDistance
|
| 19 |
+
from torchmetrics.multimodal.clip_score import CLIPScore
|
| 20 |
+
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
# --- Configuration ---
TRAIN_DATA_DIR = r"C:\Users\Admin\Desktop\scientific research\dataset\train"  # training sketches
VAL_DATA_DIR = r"C:\Users\Admin\Desktop\scientific research\dataset\val"      # validation sketches
VAE_PATH = "./vae-finetuned.safetensors"  # weights produced by train_vae.py (read with torch.load below)
IMAGE_SIZE = 128                          # pixel resolution fed to the VAE
EPOCHS = 101
BATCH_SIZE = 16 * 5                       # base batch of 16, scaled up 5x
LEARNING_RATE = 1e-4 * 5                  # LR scaled together with the batch size
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SAVE_UNET_PATH = "./unet-mini.safetensors"  # final UNet weights output path

CHECKPOINT_PATH = "./unet_latest_checkpoint.pth"  # rolling resume checkpoint
NUM_INFERENCE_STEPS = 50                  # sampling steps used during evaluation

# Override of MiniDiffusionPipeline's default UNet widths; must match the
# config used by inference (inference_v3.py).
TINY_UNET_CONFIG = {
    "unet_block_out_channels": (128, 256, 512),
}
|
| 40 |
+
|
| 41 |
+
def plot_metrics(history, filename="unet_metrics_plot_v1.png"):
    """Render the training curves to a 2x2 figure and save it to *filename*.

    Args:
        history: dict of per-epoch lists with keys 'train_loss', 'val_loss',
            'fid', 'lpips', 'clip_score'.
        filename: output image path.
    """
    plt.rcParams.update({'font.size': 17})

    fig, axs = plt.subplots(2, 2, figsize=(15, 12))

    # Top-left: train vs validation MSE loss.
    axs[0, 0].plot(history['train_loss'], label="Train Loss")
    axs[0, 0].plot(history['val_loss'], label="Validation Loss")
    axs[0, 0].set_title("Train vs Validation Loss")
    axs[0, 0].set_xlabel("Epoch")
    axs[0, 0].set_ylabel("MSE Loss")
    axs[0, 0].grid()
    axs[0, 0].legend()

    # Top-right: FID (lower is better).
    axs[0, 1].plot(history['fid'], label="FID", color='green')
    axs[0, 1].set_title("Fréchet Inception Distance (FID)")
    axs[0, 1].set_xlabel("Epoch")
    axs[0, 1].set_ylabel("FID (lower is better)")
    axs[0, 1].grid()
    axs[0, 1].legend()

    # Bottom-left: LPIPS (lower is better).
    axs[1, 0].plot(history['lpips'], label="LPIPS", color='red')
    axs[1, 0].set_title("Learned Perceptual Image Patch Similarity (LPIPS)")
    axs[1, 0].set_xlabel("Epoch")
    axs[1, 0].set_ylabel("LPIPS (lower is better)")
    axs[1, 0].grid()
    axs[1, 0].legend()

    # Bottom-right: CLIP score (higher is better).
    axs[1, 1].plot(history['clip_score'], label="CLIP Score", color='purple')
    axs[1, 1].set_title("CLIP Score")
    axs[1, 1].set_xlabel("Epoch")
    axs[1, 1].set_ylabel("CLIP Score (higher is better)")
    axs[1, 1].grid()
    axs[1, 1].legend()

    plt.tight_layout()
    plt.savefig(filename)
    # FIX: close the figure so repeated calls (once per epoch) don't
    # accumulate open figures and leak memory.
    plt.close(fig)
    # FIX: the message previously printed the literal text "(unknown)"
    # instead of the actual output path.
    print(f"Đã lưu biểu đồ metrics tại {filename}")
|
| 78 |
+
|
| 79 |
+
def evaluate(
    eval_pipeline, gen, val_loader, metrics,
    unet, vae, text_encoder, scheduler,
    vae_scale_factor, num_inference_steps
):
    """Run one validation pass: diffusion MSE loss plus image-quality
    metrics (FID, LPIPS, CLIP score) over *val_loader*.

    Args:
        eval_pipeline: StableDiffusionPipeline used to sample images.
        gen: torch.Generator for reproducible sampling.
        val_loader: DataLoader yielding {"pixel_values", "input_ids"} batches.
        metrics: dict with "fid", "lpips", "clip_score" torchmetrics objects.
        unet, vae, text_encoder, scheduler: model components under evaluation.
        vae_scale_factor: latent scaling constant (config 'latent_scale').
        num_inference_steps: sampling steps per generated image.

    Returns:
        dict with keys "val_loss", "fid", "lpips", "clip_score".

    NOTE(review): puts the UNet in eval() mode but never restores train();
    the caller is expected to do that before resuming training — confirm.
    """
    unet.eval()
    total_val_loss = 0.0

    # Metrics accumulate across batches; start each evaluation fresh.
    for metric in metrics.values():
        metric.reset()

    def to_uint8(images):
        # [-1, 1] float -> [0, 255] uint8, as expected by FID / CLIPScore.
        images = (images.clamp(-1, 1) + 1) / 2
        images = (images * 255).type(torch.uint8)
        return images

    def to_lpips_format(images):
        # LPIPS expects floats clamped to [-1, 1].
        return images.clamp(-1, 1)

    pbar = tqdm(val_loader, desc="[Validation & Evaluation]")
    for batch in pbar:
        # Assumes pixel_values are normalised to [-1, 1] by the dataset —
        # TODO confirm against SketchDataset.
        images = batch["pixel_values"].to(DEVICE)
        input_ids = batch["input_ids"].to(DEVICE)

        with torch.no_grad():
            # --- Validation loss: same noise-prediction objective as training ---
            latents = vae.encode(images).latent_dist.mean * vae_scale_factor
            noise = torch.randn_like(latents)
            timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (latents.shape[0],), device=DEVICE)
            noisy_latents = scheduler.add_noise(latents, noise, timesteps)
            text_embeds = text_encoder(input_ids)[0]

            noise_pred = unet(noisy_latents, timesteps, text_embeds).sample
            val_loss = F.mse_loss(noise_pred, noise)
            total_val_loss += val_loss.item()

            # --- Generate images from the batch's own captions (via eval_pipeline) ---
            prompts = eval_pipeline.tokenizer.batch_decode(input_ids, skip_special_tokens=True)

            generated_output = eval_pipeline(
                prompt=prompts,
                num_inference_steps=num_inference_steps,
                output_type="pt",  # tensors in [0, 1]
                generator=gen
            )

            generated_images = generated_output.images
            # Map [0, 1] -> [-1, 1] to match the ground-truth normalisation.
            generated_images_norm = (generated_images * 2) - 1

            # --- Update metrics ---
            gt_images_uint8 = to_uint8(images)
            gt_images_lpips = to_lpips_format(images)
            gen_images_uint8 = to_uint8(generated_images_norm)
            gen_images_lpips = to_lpips_format(generated_images_norm)

            metrics["fid"].update(gt_images_uint8, real=True)
            metrics["fid"].update(gen_images_uint8, real=False)
            metrics["lpips"].update(gt_images_lpips, gen_images_lpips)
            metrics["clip_score"].update(gen_images_uint8, prompts)

    # --- Aggregate and return ---
    results = {
        "val_loss": total_val_loss / len(val_loader),
        "fid": metrics["fid"].compute().item(),
        "lpips": metrics["lpips"].compute().item(),
        "clip_score": metrics["clip_score"].compute().item()
    }

    return results
|
| 149 |
+
|
| 150 |
+
def main():
    """Stage 2: train the mini UNet for text-to-sketch generation.

    The stage-1 fine-tuned VAE and the text encoder are frozen; only the
    UNet is optimized with an MSE noise-prediction loss. After every epoch
    the model is evaluated (val loss + FID/LPIPS/CLIP on generated images),
    the UNet with the best CLIP score is saved, and a resumable checkpoint
    is written.
    """
    print("--- Giai đoạn 2: Huấn luyện UNet-mini ---")
    start_time_total = time.time()

    # Build the pipeline with the tiny UNet configuration.
    pipeline = MiniDiffusionPipeline(
        base_model_id="runwayml/stable-diffusion-v1-5",
        device=DEVICE,
        config_overrides=TINY_UNET_CONFIG
    )

    # Load the stage-1 fine-tuned VAE; abort early if it is missing.
    try:
        pipeline.vae.load_state_dict(torch.load(VAE_PATH, map_location=DEVICE))
        print(f"Tải VAE đã fine-tune thành công từ {VAE_PATH}")
    except Exception as e:
        print(f"Lỗi: Không thể tải VAE từ {VAE_PATH}. {e}")
        print("Vui lòng chạy train_vae.py trước!")
        return

    # Only the UNet is trained.
    pipeline.vae.requires_grad_(False)
    pipeline.text_encoder.requires_grad_(False)

    unet = pipeline.unet
    vae = pipeline.vae
    text_encoder = pipeline.text_encoder
    tokenizer = pipeline.tokenizer
    noise_scheduler = pipeline.noise_scheduler
    vae_scale_factor = pipeline.config['latent_scale']

    # Data loaders.
    train_dataset = SketchDataset(TRAIN_DATA_DIR, tokenizer, IMAGE_SIZE)
    val_dataset = SketchDataset(VAL_DATA_DIR, tokenizer, IMAGE_SIZE)

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)

    print(f"Đã tải {len(train_dataset)} ảnh train và {len(val_dataset)} ảnh val.")

    # One reusable pipeline for generating evaluation images.
    print("Khởi tạo Evaluation Pipeline (một lần)...")
    eval_pipeline = StableDiffusionPipeline(
        unet=unet,
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        scheduler=noise_scheduler,
        safety_checker=None,
        feature_extractor=None,
    ).to(DEVICE)

    eval_pipeline.set_progress_bar_config(disable=True)

    eval_pipeline.unet.eval()
    eval_pipeline.vae.eval()
    eval_pipeline.text_encoder.eval()

    # Fixed seed so evaluation images are comparable across epochs.
    gen = torch.Generator(device=DEVICE).manual_seed(42)

    optimizer = AdamW(unet.parameters(), lr=LEARNING_RATE)
    lr_scheduler = get_scheduler(
        name=pipeline.config['scheduler'],
        optimizer=optimizer,
        num_warmup_steps=500,
        num_training_steps=(len(train_loader) * EPOCHS),
    )

    metrics = {
        "fid": FrechetInceptionDistance(feature=64).to(DEVICE),
        "lpips": LearnedPerceptualImagePatchSimilarity(net_type='vgg').to(DEVICE),
        "clip_score": CLIPScore(model_name_or_path="openai/clip-vit-base-patch32").to(DEVICE)
    }

    start_epoch = 0
    history = {
        "train_loss": [], "val_loss": [],
        "fid": [], "lpips": [], "clip_score": []
    }
    best_clip_score = 0.0

    # Resume from the latest checkpoint when one exists.
    if os.path.exists(CHECKPOINT_PATH):
        print(f"Phát hiện checkpoint. Đang tải từ {CHECKPOINT_PATH}...")
        try:
            checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
            unet.load_state_dict(checkpoint['unet_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler_state_dict'])
            start_epoch = checkpoint['epoch']
            history = checkpoint['history']
            best_clip_score = checkpoint['best_clip_score']
            print(f"Resume training từ epoch {start_epoch}")
        except Exception as e:
            print(f"Lỗi khi tải checkpoint: {e}. Bắt đầu lại từ đầu.")
            start_epoch = 0
            history = {k: [] for k in history}
            best_clip_score = 0.0
    else:
        print("Không tìm thấy checkpoint. Bắt đầu training từ đầu.")

    for epoch in range(start_epoch, EPOCHS):
        start_time_epoch = time.time()
        unet.train()
        epoch_train_loss = 0.0

        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS} [Train]")
        for batch in pbar:
            images = batch["pixel_values"].to(DEVICE)
            input_ids = batch["input_ids"].to(DEVICE)

            # Frozen VAE / text encoder: no gradients required here.
            with torch.no_grad():
                latents = vae.encode(images).latent_dist.mean * vae_scale_factor
                text_embeds = text_encoder(input_ids)[0]

            noise = torch.randn_like(latents)
            timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (latents.shape[0],), device=DEVICE)
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            noise_pred = unet(noisy_latents, timesteps, text_embeds).sample
            loss = F.mse_loss(noise_pred, noise)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            lr_scheduler.step()

            epoch_train_loss += loss.item()
            pbar.set_postfix({"Loss": loss.item()})

        avg_train_loss = epoch_train_loss / len(train_loader)
        history["train_loss"].append(avg_train_loss)

        # --- Evaluation ---
        # FIX: torchmetrics objects accumulate state across update() calls;
        # without a reset each epoch's FID/LPIPS/CLIP would mix in every
        # previous epoch's images. Reset before each evaluation pass.
        for metric in metrics.values():
            metric.reset()
        # FIX: unet.train() above flips the shared UNet back into train
        # mode; evaluation/generation must run in eval mode.
        unet.eval()

        eval_results = evaluate(
            eval_pipeline, gen, val_loader, metrics,
            unet, vae, text_encoder, noise_scheduler,
            vae_scale_factor, NUM_INFERENCE_STEPS
        )

        history["val_loss"].append(eval_results["val_loss"])
        history["fid"].append(eval_results["fid"])
        history["lpips"].append(eval_results["lpips"])
        history["clip_score"].append(eval_results["clip_score"])

        epoch_time_min = (time.time() - start_time_epoch) / 60

        print(f"\n--- Epoch {epoch+1}/{EPOCHS} Results (Thời gian: {epoch_time_min:.2f} phút) ---")
        print(f" Train Loss: {avg_train_loss:.6f}")
        print(f" Val Loss: {eval_results['val_loss']:.6f}")
        print(f" LPIPS: {eval_results['lpips']:.4f} (↓)")
        print(f" FID: {eval_results['fid']:.4f} (↓)")
        print(f" CLIP Score: {eval_results['clip_score']:.4f} (↑)")

        # Keep the UNet with the best CLIP score seen so far.
        if eval_results['clip_score'] > best_clip_score:
            best_clip_score = eval_results['clip_score']
            torch.save(unet.state_dict(), SAVE_UNET_PATH)
            print(f"Đã lưu UNet *tốt nhất* mới tại {SAVE_UNET_PATH} (CLIP Score: {best_clip_score:.4f})")

        # Always write a resumable checkpoint for the latest epoch.
        print(f"Đang lưu checkpoint cuối cùng tại {CHECKPOINT_PATH}...")
        checkpoint = {
            'epoch': epoch + 1,
            'unet_state_dict': unet.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'lr_scheduler_state_dict': lr_scheduler.state_dict(),
            'history': history,
            'best_clip_score': best_clip_score
        }
        torch.save(checkpoint, CHECKPOINT_PATH)

    total_time_min = (time.time() - start_time_total) / 60
    print(f"\n--- Hoàn thành Giai đoạn 2 ---")
    print(f"Tổng thời gian chạy (phiên này): {total_time_min:.2f} phút")
    print(f"UNet đã train (tốt nhất) được lưu tại: {SAVE_UNET_PATH}")

    if history['train_loss']:
        plot_metrics(history, "unet_metrics_plot_v1.png")
    else:
        print("Không có dữ liệu history để vẽ biểu đồ.")

if __name__ == "__main__":
    main()
|
Agent_Diffusion/train_vae.py
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
from torch.utils.data import DataLoader
|
| 4 |
+
from torchvision import transforms
|
| 5 |
+
from torch.optim import AdamW
|
| 6 |
+
from PIL import Image
|
| 7 |
+
import os
|
| 8 |
+
import time
|
| 9 |
+
import matplotlib.pyplot as plt
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
from stable_diffusion import MiniDiffusionPipeline
|
| 14 |
+
from dataset import SketchDataset
|
| 15 |
+
|
| 16 |
+
# --- Configuration ---
TRAIN_DATA_DIR = r"C:\Users\Admin\Desktop\scientific research\dataset\train"
VAL_DATA_DIR = r"C:\Users\Admin\Desktop\scientific research\dataset\val"
IMAGE_SIZE = 128  # square resolution fed to the VAE
EPOCHS = 36
BATCH_SIZE = 16
LEARNING_RATE = 1e-5
SAVE_PATH = "vae-finetuned.safetensors"        # best-model weights
CHECKPOINT_PATH = "vae_latest_checkpoint.pth"  # resumable training state
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 26 |
+
|
| 27 |
+
def plot_losses(train_losses, val_losses, filename="vae_loss_plot_v1.png"):
    """Plot train/validation MSE curves and save the figure to *filename*.

    Args:
        train_losses: per-epoch average training MSE values.
        val_losses: per-epoch average validation MSE values.
        filename: output path for the saved PNG.
    """
    plt.figure(figsize=(10, 5))
    plt.plot(train_losses, label="Train Loss")
    plt.plot(val_losses, label="Validation Loss")
    plt.title("VAE Fine-tuning Loss")
    plt.xlabel("Epoch")
    plt.ylabel("MSE Loss")
    plt.grid()
    # Fixed y-range chosen for this dataset's typical loss band.
    plt.ylim(0.085, 0.12)
    plt.legend()
    plt.savefig(filename)
    # FIX: the message printed the literal text "(unknown)" instead of the
    # actual output path — interpolate the filename.
    print(f"Đã lưu biểu đồ loss tại {filename}")
    # Close the figure so repeated calls don't accumulate open figures.
    plt.close()
|
| 39 |
+
|
| 40 |
+
def main():
    """Stage 1: fine-tune the Stable Diffusion VAE on the sketch dataset.

    Trains with a plain reconstruction MSE on the posterior mean (no KL
    term, no sampling), keeps the weights with the lowest validation loss,
    and writes a resumable checkpoint after every epoch.
    """
    print("--- Giai đoạn 1: Fine-tuning VAE ---")
    pipeline = MiniDiffusionPipeline(
        base_model_id="runwayml/stable-diffusion-v1-5",
        vae_model_id="stabilityai/sd-vae-ft-mse",
        device=DEVICE
    )
    vae = pipeline.vae
    tokenizer = pipeline.tokenizer
    vae_scale_factor = pipeline.config['latent_scale']

    train_dataset = SketchDataset(TRAIN_DATA_DIR, tokenizer, IMAGE_SIZE)
    val_dataset = SketchDataset(VAL_DATA_DIR, tokenizer, IMAGE_SIZE)

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)

    print(f"Đã tải {len(train_dataset)} ảnh train và {len(val_dataset)} ảnh val.")

    optimizer = AdamW(vae.parameters(), lr=LEARNING_RATE)

    start_epoch = 0
    train_losses, val_losses = [], []
    best_val_loss = float('inf')

    # Resume from the most recent checkpoint if one is present.
    if os.path.exists(CHECKPOINT_PATH):
        print(f"Phát hiện checkpoint. Đang tải từ {CHECKPOINT_PATH}...")
        try:
            checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE)

            vae.load_state_dict(checkpoint['vae_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            start_epoch = checkpoint['epoch']  # index of the *next* epoch to run
            train_losses = checkpoint['train_losses']
            val_losses = checkpoint['val_losses']
            best_val_loss = checkpoint['best_val_loss']

            print(f"Resume training từ epoch {start_epoch}")
        except Exception as e:
            print(f"Lỗi khi tải checkpoint: {e}. Bắt đầu lại từ đầu.")
            start_epoch = 0
            train_losses, val_losses = [], []
            best_val_loss = float('inf')
    else:
        print("Không tìm thấy checkpoint. Bắt đầu training từ đầu.")

    start_time = time.time()

    for epoch in range(start_epoch, EPOCHS):
        # ---- training pass ----
        vae.train()
        running_train = 0.0

        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS} [Train]")
        for batch in pbar:
            images = batch["pixel_values"].to(DEVICE)

            # Deterministic encode (posterior mean), then decode. Scaling
            # and unscaling by vae_scale_factor cancels numerically; it is
            # kept to mirror how latents are produced for the UNet stage.
            posterior = vae.encode(images).latent_dist
            latents = posterior.mean * vae_scale_factor
            reconstructions = vae.decode(latents / vae_scale_factor).sample

            loss = F.mse_loss(reconstructions, images)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_train += loss.item()
            pbar.set_postfix({"Loss": loss.item()})

        avg_train_loss = running_train / len(train_loader)
        train_losses.append(avg_train_loss)

        # ---- validation pass ----
        vae.eval()
        running_val = 0.0
        with torch.no_grad():
            pbar_val = tqdm(val_loader, desc=f"Epoch {epoch+1}/{EPOCHS} [Val]")
            for batch in pbar_val:
                images = batch["pixel_values"].to(DEVICE)
                posterior = vae.encode(images).latent_dist
                latents = posterior.mean * vae_scale_factor
                reconstructions = vae.decode(latents / vae_scale_factor).sample
                running_val += F.mse_loss(reconstructions, images).item()

        avg_val_loss = running_val / len(val_loader)
        val_losses.append(avg_val_loss)

        print(f"Epoch {epoch+1}/{EPOCHS} - Train Loss: {avg_train_loss:.6f} - Val Loss: {avg_val_loss:.6f}")

        # Keep only the best-performing weights.
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(vae.state_dict(), SAVE_PATH)
            print(f"Đã lưu VAE *tốt nhất* mới tại {SAVE_PATH} (Val Loss: {best_val_loss:.6f})")

        # Always persist the latest state for resuming.
        print(f"Đang lưu checkpoint cuối cùng tại {CHECKPOINT_PATH}...")
        torch.save({
            'epoch': epoch + 1,
            'vae_state_dict': vae.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'train_losses': train_losses,
            'val_losses': val_losses,
            'best_val_loss': best_val_loss
        }, CHECKPOINT_PATH)

    total_time_min = (time.time() - start_time) / 60
    print(f"\n--- Hoàn thành Giai đoạn 1 ---")
    print(f"Tổng thời gian chạy (phiên này): {total_time_min:.2f} phút")
    print(f"VAE đã fine-tune (tốt nhất) được lưu tại: {SAVE_PATH}")

    if train_losses and val_losses:
        plot_losses(train_losses, val_losses, "vae_loss_plot.png")

if __name__ == "__main__":
    main()
|
Agent_Diffusion/unet-mini.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:291a53f1a95a6893a7c06d9e8cbe06667b9e2bd53bcbaa87636b509ba4529821
|
| 3 |
+
size 208784907
|
Agent_Diffusion/vae-finetuned.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:bd73806bd9362c4d5ad89e5017178db324d5e49d2efe9e0dd04c1a2871395b32
|
| 3 |
+
size 334713859
|
Conservative Q-learning/Agen1_training.py
ADDED
|
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import json
|
| 3 |
+
import numpy as np
|
| 4 |
+
from torch.utils.data import DataLoader, Dataset
|
| 5 |
+
from transformers import T5Tokenizer, T5EncoderModel
|
| 6 |
+
|
| 7 |
+
# Import từ các file bạn đã cung cấp
|
| 8 |
+
from cql_agent import CQLAgent
|
| 9 |
+
from cql_utils import DataNormalizer
|
| 10 |
+
|
| 11 |
+
# --- 1. ĐỊNH NGHĨA DATASET ---
|
| 12 |
+
class SketchOfflineDataset(Dataset):
    """Offline-RL transition dataset read from a JSON-lines file.

    Every line must be a JSON object with the keys ``state``, ``action``,
    ``reward``, ``next_state`` and ``done``. States are returned as
    float32 numpy arrays; the remaining fields are passed through as-is.
    """

    def __init__(self, file_path):
        with open(file_path, 'r', encoding='utf-8') as fh:
            self.data = [json.loads(row) for row in fh]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        record = self.data[idx]
        state = np.array(record['state'], dtype=np.float32)
        next_state = np.array(record['next_state'], dtype=np.float32)
        return state, record['action'], record['reward'], next_state, record['done']
|
| 31 |
+
|
| 32 |
+
# --- 2. HÀM HUẤN LUYỆN ---
|
| 33 |
+
def train_agent_1(epoch):
    """Train the routing agent (Agent 1) with discrete CQL on offline data.

    Args:
        epoch: number of training epochs to run.

    Returns:
        The trained CQLAgent (also persisted under ``saved_agent_1``).
    """
    print("🚀 Khởi động quá trình huấn luyện Agent Chính...")

    # Configuration
    FILE_PATH = "massive_diverse_sketch_dataset.json"
    STATE_DIM = 768   # matches the t5-base encoder output size
    ACTION_DIM = 3    # 0: Chat, 1: Sketch, 2: Reject
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

    # 80/20 split of the offline transitions; only the train part is used.
    full_dataset = SketchOfflineDataset(FILE_PATH)
    train_size = int(0.8 * len(full_dataset))
    test_size = len(full_dataset) - train_size
    train_dataset, _ = torch.utils.data.random_split(full_dataset, [train_size, test_size])
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

    # Discrete CQL: the three routing choices above are the action space.
    agent = CQLAgent(state_dim=STATE_DIM, action_dim=ACTION_DIM, is_continuous=False, device=DEVICE)

    # Fit the state normalizer on the complete dataset.
    # NOTE(review): select_action() normalizes states but train_step()
    # appears to consume raw states — confirm whether the training data is
    # already normalized upstream.
    all_states = np.array([item['state'] for item in full_dataset.data])
    agent.normalizer.fit(all_states)

    num_epochs = epoch
    for ep in range(num_epochs):
        total_loss = 0
        for batch in train_loader:
            # batch layout: (states, actions, rewards, next_states, dones)
            metrics = agent.train_step(batch)
            total_loss += metrics['critic_loss']

        print(f"Epoch {ep+1}/{num_epochs} | Critic Loss: {total_loss/len(train_loader):.4f}")

    # Persist the trained agent.
    agent.save_model("saved_agent_1")
    return agent
|
| 70 |
+
|
| 71 |
+
# --- 3. HÀM KIỂM THỬ (TEST) ---
|
| 72 |
+
class Agent1Inference:
    """Inference wrapper: T5 embeds the prompt, the CQL agent routes it."""

    def __init__(self, model_path="saved_agent_1"):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # T5 turns the user prompt into a 768-d embedding.
        self.tokenizer = T5Tokenizer.from_pretrained("t5-base")
        self.encoder = T5EncoderModel.from_pretrained("t5-base").to(self.device)

        # CQL agent mapping the embedding to one of three routing actions.
        self.agent = CQLAgent(state_dim=768, action_dim=3, is_continuous=False, device=self.device)
        self.agent.load_model(model_path)

    def get_action(self, text_prompt):
        """Return a human-readable routing decision for *text_prompt*."""
        # Step 1: mean-pooled T5 encoder embedding of the prompt.
        inputs = self.tokenizer(text_prompt, return_tensors="pt", padding=True).to(self.device)
        with torch.no_grad():
            pooled = self.encoder(**inputs).last_hidden_state.mean(dim=1)
        embedding = pooled.cpu().numpy().flatten()

        # Step 2: greedy action from the CQL critic.
        action_idx = self.agent.select_action(embedding)

        mapping = {0: "Kích hoạt Agent Giao tiếp", 1: "Kích hoạt Agent Vẽ ảnh (Sketch)", 2: "Từ chối/Yêu cầu làm rõ"}
        return mapping[action_idx]
|
| 95 |
+
|
| 96 |
+
# --- CHẠY CHƯƠNG TRÌNH ---
|
| 97 |
+
if __name__ == "__main__":
    # 1. Train the routing agent.
    trained_agent = train_agent_1(1000)

    # 2. Smoke-test the saved model on a few representative prompts.
    print("\n--- BẮT ĐẦU TEST AGENT CHÍNH ---")
    tester = Agent1Inference("saved_agent_1")

    test_prompts = [
        "Vẽ cho mình một bức chân dung cụ ông sketch",
        "Chào bot, hôm nay bạn thế nào?",
        "Hãy vẽ một bông hoa hồng bằng màu dầu rực rỡ",
        "Kí họa nhanh khuôn mặt cô gái đang cười",
    ]

    for p in test_prompts:
        action = tester.get_action(p)
        print(f"User: {p} \n=> Agent 1 quyết định: {action}\n")
|
Conservative Q-learning/cql_agent.py
ADDED
|
@@ -0,0 +1,242 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.optim as optim
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import numpy as np
|
| 5 |
+
import copy
|
| 6 |
+
import os
|
| 7 |
+
from cql_utils import MLP, TanhGaussianPolicy, DataNormalizer
|
| 8 |
+
|
| 9 |
+
class CQLAgent:
|
| 10 |
+
def __init__(
|
| 11 |
+
self,
|
| 12 |
+
state_dim,
|
| 13 |
+
action_dim,
|
| 14 |
+
device='cuda' if torch.cuda.is_available() else 'cpu',
|
| 15 |
+
is_continuous=True,
|
| 16 |
+
hidden_dim=256,
|
| 17 |
+
lr=3e-4,
|
| 18 |
+
cql_weight=1.0,
|
| 19 |
+
temp=1.0,
|
| 20 |
+
gamma=0.99,
|
| 21 |
+
tau=0.005
|
| 22 |
+
):
|
| 23 |
+
self.state_dim = state_dim
|
| 24 |
+
self.action_dim = action_dim
|
| 25 |
+
self.device = torch.device(device)
|
| 26 |
+
self.is_continuous = is_continuous
|
| 27 |
+
self.cql_weight = cql_weight
|
| 28 |
+
self.temp = temp # logsumexp
|
| 29 |
+
self.gamma = gamma
|
| 30 |
+
self.tau = tau # update coefficient
|
| 31 |
+
|
| 32 |
+
self.normalizer = DataNormalizer(state_dim)
|
| 33 |
+
|
| 34 |
+
if self.is_continuous:
|
| 35 |
+
self.actor = TanhGaussianPolicy(state_dim, action_dim, hidden_dim).to(self.device)
|
| 36 |
+
self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=lr)
|
| 37 |
+
|
| 38 |
+
# Critic (Continuous: input = State + action)
|
| 39 |
+
# using 2 Critics to reduce overstimation (Double Q-learning)
|
| 40 |
+
self.critic_1 = MLP(state_dim + action_dim, 1, hidden_dim).to(self.device)
|
| 41 |
+
self.critic_2 = MLP(state_dim + action_dim, 1, hidden_dim).to(self.device)
|
| 42 |
+
self.target_critic_1 = copy.deepcopy(self.critic_1)
|
| 43 |
+
self.target_critic_2 = copy.deepcopy(self.critic_2)
|
| 44 |
+
|
| 45 |
+
else:
|
| 46 |
+
# Discrete: Actor is argmax Q, Critic is Q-network (Input = state, output = Q of actions)
|
| 47 |
+
self.critic_1 = MLP(state_dim, action_dim, hidden_dim).to(self.device)
|
| 48 |
+
self.target_critic_1 = copy.deepcopy(self.critic_1)
|
| 49 |
+
# we alo can use 2 network to more robust but now we just use 1 network to check and we can update it later
|
| 50 |
+
|
| 51 |
+
params = list(self.critic_1.parameters())
|
| 52 |
+
if self.is_continuous:
|
| 53 |
+
params += list(self.critic_2.parameters())
|
| 54 |
+
self.critic_optimizer = optim.Adam(params, lr=lr)
|
| 55 |
+
|
| 56 |
+
# auto fine-tune Alpha (Entropy) for SAC (just Continuous)
|
| 57 |
+
if self.is_continuous:
|
| 58 |
+
self.target_entropy = -action_dim
|
| 59 |
+
self.log_alpha = torch.zeros(1, requires_grad=True, device=self.device)
|
| 60 |
+
self.alpha_optimizer = optim.Adam([self.log_alpha], lr=lr)
|
| 61 |
+
|
| 62 |
+
def select_action(self, state, evaluate=True):
|
| 63 |
+
if not isinstance(state, np.ndarray):
|
| 64 |
+
state = np.array(state)
|
| 65 |
+
|
| 66 |
+
state = self.normalizer.normalize(state)
|
| 67 |
+
state = torch.FloatTensor(state).unsqueeze(0).to(self.device)
|
| 68 |
+
|
| 69 |
+
with torch.no_grad():
|
| 70 |
+
if self.is_continuous:
|
| 71 |
+
action, _ = self.actor(state)
|
| 72 |
+
return action.cpu().data.numpy().flatten()
|
| 73 |
+
else:
|
| 74 |
+
q_values = self.critic_1(state)
|
| 75 |
+
return q_values.argmax(dim=1).cpu().data.numpy().item()
|
| 76 |
+
|
| 77 |
+
def _compute_cql_loss(self, q_values_pred, states, actions):
|
| 78 |
+
if self.is_continuous:
|
| 79 |
+
batch_size = states.shape[0]
|
| 80 |
+
num_samples = 10
|
| 81 |
+
|
| 82 |
+
# Taking random samples
|
| 83 |
+
random_actions = torch.FloatTensor(batch_size * num_samples, self.action_dim).uniform_(-1, 1).to(self.device)
|
| 84 |
+
|
| 85 |
+
# Taking actions from current Policy
|
| 86 |
+
curr_actions, curr_log_pi = self.actor(states.repeat_interleave(num_samples, dim=0))
|
| 87 |
+
|
| 88 |
+
# Calculating Q for each these samples
|
| 89 |
+
states_repeated = states.repeat_interleave(num_samples, dim=0)
|
| 90 |
+
|
| 91 |
+
# grafting states and random actions and current actions
|
| 92 |
+
q1_rand = self.critic_1(torch.cat([states_repeated, random_actions], dim=1))
|
| 93 |
+
q1_curr = self.critic_1(torch.cat([states_repeated, curr_actions], dim=1))
|
| 94 |
+
|
| 95 |
+
# gathering: [batch, num_samples, 1]
|
| 96 |
+
q1_rand = q1_rand.view(batch_size, num_samples, 1)
|
| 97 |
+
q1_curr = q1_curr.view(batch_size, num_samples, 1)
|
| 98 |
+
|
| 99 |
+
# merging to calculating LogSumExp
|
| 100 |
+
cat_q1 = torch.cat([q1_rand, q1_curr], dim=1)
|
| 101 |
+
|
| 102 |
+
# CQL Loss 1: log(sum(exp(Q_ood)))
|
| 103 |
+
cql_loss_1 = torch.logsumexp(cat_q1 / self.temp, dim=1).mean() * self.temp
|
| 104 |
+
# CQL Loss 2: - Q_data (Maximizing Q-value of data samples)
|
| 105 |
+
cql_loss_2 = q_values_pred.mean()
|
| 106 |
+
|
| 107 |
+
return (cql_loss_1 - cql_loss_2) * self.cql_weight
|
| 108 |
+
|
| 109 |
+
else:
|
| 110 |
+
# q_values_pred shape: [batch, action_dim] (calculated for all action space)
|
| 111 |
+
|
| 112 |
+
# CQL Loss 1: log(sum(exp(Q_all))) - calculate for all action space
|
| 113 |
+
cql_loss_1 = torch.logsumexp(q_values_pred / self.temp, dim=1).mean() * self.temp
|
| 114 |
+
|
| 115 |
+
# CQL Loss 2: - Q_data (takeing Q-value at real action in batch)
|
| 116 |
+
# actions shape: [batch, 1]
|
| 117 |
+
q_data = q_values_pred.gather(1, actions.long())
|
| 118 |
+
cql_loss_2 = q_data.mean()
|
| 119 |
+
|
| 120 |
+
return (cql_loss_1 - cql_loss_2) * self.cql_weight
|
| 121 |
+
|
| 122 |
+
def train_step(self, batch):
|
| 123 |
+
states, actions, rewards, next_states, dones = batch
|
| 124 |
+
|
| 125 |
+
states = torch.FloatTensor(states).to(self.device)
|
| 126 |
+
actions = torch.FloatTensor(actions).to(self.device)
|
| 127 |
+
rewards = torch.FloatTensor(rewards).unsqueeze(1).to(self.device)
|
| 128 |
+
next_states = torch.FloatTensor(next_states).to(self.device)
|
| 129 |
+
dones = torch.FloatTensor(dones).unsqueeze(1).to(self.device)
|
| 130 |
+
|
| 131 |
+
# Update Critic (Q-Functions)
|
| 132 |
+
with torch.no_grad():
|
| 133 |
+
if self.is_continuous:
|
| 134 |
+
next_actions, next_log_pi = self.actor(next_states)
|
| 135 |
+
q1_target = self.target_critic_1(torch.cat([next_states, next_actions], dim=1))
|
| 136 |
+
q2_target = self.target_critic_2(torch.cat([next_states, next_actions], dim=1))
|
| 137 |
+
min_q_target = torch.min(q1_target, q2_target)
|
| 138 |
+
|
| 139 |
+
# Soft Actor-Critic Target (Having entropy)
|
| 140 |
+
alpha = self.log_alpha.exp()
|
| 141 |
+
q_target = rewards + (1 - dones) * self.gamma * (min_q_target - alpha * next_log_pi)
|
| 142 |
+
else:
|
| 143 |
+
# DQN Target
|
| 144 |
+
# Double DQN logic
|
| 145 |
+
q_next = self.target_critic_1(next_states)
|
| 146 |
+
max_q_next, _ = torch.max(q_next, dim=1, keepdim=True)
|
| 147 |
+
q_target = rewards + (1 - dones) * self.gamma * max_q_next
|
| 148 |
+
|
| 149 |
+
# calculating Q hiện tại
|
| 150 |
+
if self.is_continuous:
|
| 151 |
+
q1_pred = self.critic_1(torch.cat([states, actions], dim=1))
|
| 152 |
+
q2_pred = self.critic_2(torch.cat([states, actions], dim=1))
|
| 153 |
+
mse_loss = F.mse_loss(q1_pred, q_target) + F.mse_loss(q2_pred, q_target)
|
| 154 |
+
|
| 155 |
+
# adding CQL Loss
|
| 156 |
+
cql_loss = self._compute_cql_loss(q1_pred, states, actions) + \
|
| 157 |
+
self._compute_cql_loss(q2_pred, states, actions)
|
| 158 |
+
else:
|
| 159 |
+
q_all = self.critic_1(states)
|
| 160 |
+
q_pred = q_all.gather(1, actions.long())
|
| 161 |
+
mse_loss = F.mse_loss(q_pred, q_target)
|
| 162 |
+
|
| 163 |
+
# Thêm CQL Loss
|
| 164 |
+
cql_loss = self._compute_cql_loss(q_all, states, actions)
|
| 165 |
+
|
| 166 |
+
total_critic_loss = mse_loss + cql_loss
|
| 167 |
+
|
| 168 |
+
self.critic_optimizer.zero_grad()
|
| 169 |
+
total_critic_loss.backward()
|
| 170 |
+
self.critic_optimizer.step()
|
| 171 |
+
|
| 172 |
+
# Update Actor (just Continuous)
|
| 173 |
+
actor_loss_val = 0
|
| 174 |
+
if self.is_continuous:
|
| 175 |
+
new_actions, log_pi = self.actor(states)
|
| 176 |
+
q1_new = self.critic_1(torch.cat([states, new_actions], dim=1))
|
| 177 |
+
q2_new = self.critic_2(torch.cat([states, new_actions], dim=1))
|
| 178 |
+
min_q_new = torch.min(q1_new, q2_new)
|
| 179 |
+
|
| 180 |
+
# SAC Actor Loss: Maximize (Q - alpha * log_prob) -> Minimize (alpha * log_prob - Q)
|
| 181 |
+
actor_loss = (alpha * log_pi - min_q_new).mean()
|
| 182 |
+
|
| 183 |
+
self.actor_optimizer.zero_grad()
|
| 184 |
+
actor_loss.backward()
|
| 185 |
+
self.actor_optimizer.step()
|
| 186 |
+
|
| 187 |
+
# Update Alpha (Temperature)
|
| 188 |
+
alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean()
|
| 189 |
+
self.alpha_optimizer.zero_grad()
|
| 190 |
+
alpha_loss.backward()
|
| 191 |
+
self.alpha_optimizer.step()
|
| 192 |
+
|
| 193 |
+
actor_loss_val = actor_loss.item()
|
| 194 |
+
|
| 195 |
+
# Soft Update Target Networks ===
|
| 196 |
+
self._soft_update(self.critic_1, self.target_critic_1)
|
| 197 |
+
if self.is_continuous:
|
| 198 |
+
self._soft_update(self.critic_2, self.target_critic_2)
|
| 199 |
+
|
| 200 |
+
return {
|
| 201 |
+
"critic_loss": total_critic_loss.item(),
|
| 202 |
+
"cql_loss": cql_loss.item(),
|
| 203 |
+
"actor_loss": actor_loss_val
|
| 204 |
+
}
|
| 205 |
+
|
| 206 |
+
def _soft_update(self, local_model, target_model):
|
| 207 |
+
for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
|
| 208 |
+
target_param.data.copy_(self.tau * local_param.data + (1.0 - self.tau) * target_param.data)
|
| 209 |
+
|
| 210 |
+
def save_model(self, path):
    """Persist the agent's networks and normalizer statistics under *path*.

    Writes ``cql_model.pth`` (torch checkpoint) and ``normalizer.pkl``.
    The discrete critic is always stored; the continuous (SAC-style)
    components — second critic, actor, temperature — are stored only
    when ``self.is_continuous`` is set.
    """
    os.makedirs(path, exist_ok=True)

    # Base checkpoint shared by both discrete and continuous agents.
    checkpoint = {
        'is_continuous': self.is_continuous,
        'critic_1': self.critic_1.state_dict(),
    }
    if self.is_continuous:
        checkpoint['critic_2'] = self.critic_2.state_dict()
        checkpoint['actor'] = self.actor.state_dict()
        checkpoint['log_alpha'] = self.log_alpha

    torch.save(checkpoint, os.path.join(path, "cql_model.pth"))
    self.normalizer.save(os.path.join(path, "normalizer.pkl"))
    print(f"Model saved to {path}")
|
| 226 |
+
|
| 227 |
+
def load_model(self, path):
    """Restore networks previously written by ``save_model()`` from *path*.

    Prints a notice and returns early when no checkpoint file exists.
    Continuous components are restored only when both this agent and the
    checkpoint were configured as continuous.
    """
    checkpoint_file = os.path.join(path, "cql_model.pth")
    if not os.path.exists(checkpoint_file):
        print("No model found!")
        return

    # NOTE(review): torch.load deserializes arbitrary objects on older
    # torch versions — only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_file, map_location=self.device)
    self.critic_1.load_state_dict(checkpoint['critic_1'])

    restore_continuous = self.is_continuous and checkpoint['is_continuous']
    if restore_continuous:
        self.critic_2.load_state_dict(checkpoint['critic_2'])
        self.actor.load_state_dict(checkpoint['actor'])
        self.log_alpha = checkpoint['log_alpha']

    self.normalizer.load(os.path.join(path, "normalizer.pkl"))
    print(f"Model loaded from {path}")
|
Conservative Q-learning/cql_utils.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from torch.distributions import Normal
|
| 5 |
+
import numpy as np
|
| 6 |
+
import pickle
|
| 7 |
+
import os
|
| 8 |
+
|
| 9 |
+
class DataNormalizer:
    """Per-feature z-score normalizer for state vectors.

    Stores the mean/std of a dataset and applies ``(x - mean) / std``.
    Zero-variance features get their std forced to 1.0 so normalization
    never divides by zero. Before ``fit()`` is called the transform is
    the identity (mean 0, std 1).
    """

    def __init__(self, state_dim):
        # Identity transform until fit(): std starts at 1, not 0, so
        # normalize() is safe even on an unfitted instance.
        self.mean = np.zeros(state_dim)
        self.std = np.ones(state_dim)

    def fit(self, states):
        """Estimate statistics from an array of shape (n_samples, state_dim).

        BUG FIX: the zero-variance guard previously ran only in
        __init__ (where std was all zeros anyway), so fit() could leave
        std == 0 for constant features and normalize() divided by zero.
        The guard now runs after every fit.
        """
        self.mean = np.mean(states, axis=0)
        self.std = np.std(states, axis=0)
        self.std[self.std < 1e-6] = 1.0

    def normalize(self, states):
        """Map raw states to zero-mean, unit-variance space."""
        return (states - self.mean) / self.std

    def denormalize(self, states):
        """Inverse of normalize()."""
        return states * self.std + self.mean

    def save(self, path):
        """Pickle the fitted statistics to *path*."""
        with open(path, 'wb') as f:
            pickle.dump({'mean': self.mean, 'std': self.std}, f)

    def load(self, path):
        """Restore statistics previously written by save()."""
        with open(path, 'rb') as f:
            data = pickle.load(f)
        self.mean = data['mean']
        self.std = data['std']
|
| 34 |
+
|
| 35 |
+
class MLP(nn.Module):
    """Plain fully-connected network.

    Architecture: ``n_layers`` Linear+ReLU stacks followed by a final
    bare Linear projection to ``output_dim`` (no output activation).
    """

    def __init__(self, input_dim, output_dim, hidden_dim=256, n_layers=2):
        super().__init__()
        modules = [nn.Linear(input_dim, hidden_dim), nn.ReLU()]
        for _ in range(n_layers - 1):
            modules += [nn.Linear(hidden_dim, hidden_dim), nn.ReLU()]
        modules.append(nn.Linear(hidden_dim, output_dim))
        self.net = nn.Sequential(*modules)

    def forward(self, x):
        """Apply the network to a batch of inputs."""
        return self.net(x)
|
| 50 |
+
|
| 51 |
+
class TanhGaussianPolicy(nn.Module):
    """SAC-style squashed Gaussian actor.

    Samples tanh-bounded actions in [-1, 1] via the reparameterization
    trick and returns their log-probability with the tanh
    change-of-variables correction applied.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=256):
        super().__init__()
        self.base = MLP(state_dim, hidden_dim, hidden_dim)
        self.mu_head = nn.Linear(hidden_dim, action_dim)
        self.log_std_head = nn.Linear(hidden_dim, action_dim)

    def forward(self, state):
        """Return (action, log_prob) for a batch of states.

        action has shape (batch, action_dim); log_prob has shape (batch, 1).
        """
        features = F.relu(self.base(state))

        mu = self.mu_head(features)
        # Clamp log-std to keep the distribution numerically sane.
        log_std = torch.clamp(self.log_std_head(features), -20, 2)
        dist = Normal(mu, log_std.exp())

        # Reparameterization trick: a = mu + std * epsilon, then squash.
        pre_tanh = dist.rsample()
        action = torch.tanh(pre_tanh)

        # Log-prob of the squashed sample, with the tanh Jacobian
        # correction (1e-6 avoids log(0) at the action bounds).
        log_prob = dist.log_prob(pre_tanh)
        log_prob = log_prob - torch.log(1 - action.pow(2) + 1e-6)
        log_prob = log_prob.sum(1, keepdim=True)

        return action, log_prob
|
| 81 |
+
|
| 82 |
+
|
Conservative Q-learning/saved_agent_1/cql_model.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c8ff7062d0d7b6d0ff647071ef0d21ccb7121770d6c51aeb39ed8ad10f882758
|
| 3 |
+
size 1056877
|
Conservative Q-learning/saved_agent_1/normalizer.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ac608983b3f70024d8ec8b7efd1e58239d70a56f0c926d091a6a5267b56adcb3
|
| 3 |
+
size 12491
|
app.py
ADDED
|
@@ -0,0 +1,432 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Streamlit App - Beautiful UI for CQL Chatbot System
|
| 3 |
+
"""
|
| 4 |
+
import streamlit as st
|
| 5 |
+
import plotly.graph_objects as go
|
| 6 |
+
from datetime import datetime
|
| 7 |
+
import config
|
| 8 |
+
from chatbot_engine import CQLChatbot
|
| 9 |
+
from memory_manager import MemoryManager
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
# Page configuration
|
| 13 |
+
st.set_page_config(
|
| 14 |
+
page_title=config.APP_TITLE,
|
| 15 |
+
page_icon=config.APP_ICON,
|
| 16 |
+
layout="wide",
|
| 17 |
+
initial_sidebar_state="expanded"
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
# Custom CSS for beautiful Dark Galaxy UI
|
| 21 |
+
st.markdown("""
|
| 22 |
+
<style>
|
| 23 |
+
/* Main container - Dark Galaxy Background */
|
| 24 |
+
.main {
|
| 25 |
+
background: linear-gradient(135deg, #1a1a2e 0%, #16213e 50%, #0f3460 100%);
|
| 26 |
+
}
|
| 27 |
+
|
| 28 |
+
/* Chat container */
|
| 29 |
+
.stChatMessage {
|
| 30 |
+
background-color: rgba(30, 30, 46, 0.8);
|
| 31 |
+
border-radius: 16px;
|
| 32 |
+
padding: 18px;
|
| 33 |
+
margin: 12px 0;
|
| 34 |
+
box-shadow: 0 4px 12px rgba(0, 0, 0, 0.3);
|
| 35 |
+
backdrop-filter: blur(10px);
|
| 36 |
+
}
|
| 37 |
+
|
| 38 |
+
/* User message - Purple gradient */
|
| 39 |
+
.stChatMessage[data-testid="user-message"] {
|
| 40 |
+
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
| 41 |
+
color: white;
|
| 42 |
+
border: 1px solid rgba(255, 255, 255, 0.1);
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
/* Assistant message - Dark with blue accent */
|
| 46 |
+
.stChatMessage[data-testid="assistant-message"] {
|
| 47 |
+
background: linear-gradient(135deg, #2d3748 0%, #1a202c 100%);
|
| 48 |
+
border-left: 4px solid #667eea;
|
| 49 |
+
color: #e2e8f0;
|
| 50 |
+
}
|
| 51 |
+
|
| 52 |
+
/* Sidebar - Dark theme */
|
| 53 |
+
section[data-testid="stSidebar"] {
|
| 54 |
+
background: linear-gradient(180deg, #1a1a2e 0%, #16213e 100%);
|
| 55 |
+
border-right: 1px solid rgba(102, 126, 234, 0.2);
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
section[data-testid="stSidebar"] * {
|
| 59 |
+
color: #e2e8f0 !important;
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
/* Headers - Light text */
|
| 63 |
+
h1, h2, h3, h4, h5, h6 {
|
| 64 |
+
color: #f7fafc !important;
|
| 65 |
+
text-shadow: 0 2px 4px rgba(0, 0, 0, 0.3);
|
| 66 |
+
}
|
| 67 |
+
|
| 68 |
+
/* Paragraph text */
|
| 69 |
+
p, span, div {
|
| 70 |
+
color: #cbd5e0;
|
| 71 |
+
}
|
| 72 |
+
|
| 73 |
+
/* Buttons - Purple gradient */
|
| 74 |
+
.stButton>button {
|
| 75 |
+
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
| 76 |
+
color: white;
|
| 77 |
+
border: none;
|
| 78 |
+
border-radius: 12px;
|
| 79 |
+
padding: 12px 24px;
|
| 80 |
+
font-weight: 600;
|
| 81 |
+
transition: all 0.3s ease;
|
| 82 |
+
box-shadow: 0 4px 8px rgba(102, 126, 234, 0.3);
|
| 83 |
+
}
|
| 84 |
+
|
| 85 |
+
.stButton>button:hover {
|
| 86 |
+
transform: translateY(-2px);
|
| 87 |
+
box-shadow: 0 6px 16px rgba(102, 126, 234, 0.5);
|
| 88 |
+
}
|
| 89 |
+
|
| 90 |
+
/* Input box - Dark with glow */
|
| 91 |
+
.stTextInput>div>div>input, .stChatInput>div>div>input {
|
| 92 |
+
border-radius: 12px;
|
| 93 |
+
border: 2px solid rgba(102, 126, 234, 0.3);
|
| 94 |
+
padding: 14px;
|
| 95 |
+
background-color: rgba(30, 30, 46, 0.6);
|
| 96 |
+
color: #e2e8f0;
|
| 97 |
+
transition: all 0.3s ease;
|
| 98 |
+
}
|
| 99 |
+
|
| 100 |
+
.stTextInput>div>div>input:focus, .stChatInput>div>div>input:focus {
|
| 101 |
+
border-color: #667eea;
|
| 102 |
+
box-shadow: 0 0 12px rgba(102, 126, 234, 0.4);
|
| 103 |
+
}
|
| 104 |
+
|
| 105 |
+
/* Slider */
|
| 106 |
+
.stSlider>div>div>div {
|
| 107 |
+
background-color: rgba(102, 126, 234, 0.3);
|
| 108 |
+
}
|
| 109 |
+
|
| 110 |
+
/* Metrics - Dark cards */
|
| 111 |
+
.stMetric {
|
| 112 |
+
background: linear-gradient(135deg, #2d3748 0%, #1a202c 100%);
|
| 113 |
+
padding: 18px;
|
| 114 |
+
border-radius: 12px;
|
| 115 |
+
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.3);
|
| 116 |
+
border: 1px solid rgba(102, 126, 234, 0.2);
|
| 117 |
+
}
|
| 118 |
+
|
| 119 |
+
.stMetric label {
|
| 120 |
+
color: #a0aec0 !important;
|
| 121 |
+
}
|
| 122 |
+
|
| 123 |
+
.stMetric [data-testid="stMetricValue"] {
|
| 124 |
+
color: #667eea !important;
|
| 125 |
+
}
|
| 126 |
+
|
| 127 |
+
/* Action badges - Glowing */
|
| 128 |
+
.action-badge {
|
| 129 |
+
display: inline-block;
|
| 130 |
+
padding: 6px 16px;
|
| 131 |
+
border-radius: 20px;
|
| 132 |
+
font-size: 0.85em;
|
| 133 |
+
font-weight: 600;
|
| 134 |
+
margin-left: 8px;
|
| 135 |
+
box-shadow: 0 2px 8px rgba(0, 0, 0, 0.3);
|
| 136 |
+
}
|
| 137 |
+
|
| 138 |
+
.action-0 {
|
| 139 |
+
background: linear-gradient(135deg, #4299e1 0%, #3182ce 100%);
|
| 140 |
+
color: white;
|
| 141 |
+
}
|
| 142 |
+
|
| 143 |
+
.action-1 {
|
| 144 |
+
background: linear-gradient(135deg, #9f7aea 0%, #805ad5 100%);
|
| 145 |
+
color: white;
|
| 146 |
+
}
|
| 147 |
+
|
| 148 |
+
.action-2 {
|
| 149 |
+
background: linear-gradient(135deg, #ed8936 0%, #dd6b20 100%);
|
| 150 |
+
color: white;
|
| 151 |
+
}
|
| 152 |
+
|
| 153 |
+
/* Divider */
|
| 154 |
+
hr {
|
| 155 |
+
border-color: rgba(102, 126, 234, 0.2);
|
| 156 |
+
}
|
| 157 |
+
|
| 158 |
+
/* Expander */
|
| 159 |
+
.streamlit-expanderHeader {
|
| 160 |
+
background-color: rgba(30, 30, 46, 0.6);
|
| 161 |
+
color: #e2e8f0 !important;
|
| 162 |
+
border-radius: 8px;
|
| 163 |
+
}
|
| 164 |
+
|
| 165 |
+
/* Caption text */
|
| 166 |
+
.stCaption {
|
| 167 |
+
color: #a0aec0 !important;
|
| 168 |
+
}
|
| 169 |
+
|
| 170 |
+
/* Success/Info/Warning boxes */
|
| 171 |
+
.stSuccess, .stInfo, .stWarning {
|
| 172 |
+
background-color: rgba(30, 30, 46, 0.8);
|
| 173 |
+
border-radius: 8px;
|
| 174 |
+
color: #e2e8f0;
|
| 175 |
+
}
|
| 176 |
+
</style>
|
| 177 |
+
""", unsafe_allow_html=True)
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
# Initialize session state
|
| 181 |
+
def init_session_state():
    """Populate st.session_state with the app's long-lived objects.

    Idempotent: safe to call on every rerun — each object is created only
    when its key is missing.
    """
    state = st.session_state

    # Conversation storage must exist before the chatbot is constructed,
    # because the chatbot is given a reference to it.
    if 'memory_manager' not in state:
        state.memory_manager = MemoryManager()
        state.memory_manager.create_new_session()

    # The CQL chatbot loads models, so build it once per session.
    if 'chatbot' not in state:
        with st.spinner('🚀 Đang khởi tạo CQL Chatbot System...'):
            state.chatbot = CQLChatbot(memory_manager=state.memory_manager)

    # Simple per-session collections.
    for key in ('messages', 'action_history'):
        if key not in state:
            state[key] = []
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
def display_action_badge(action: int, action_name: str):
    """Return the HTML snippet for a color-coded action badge.

    NOTE(review): *action_name* is currently unused — the visible label
    comes from config.ACTION_DESCRIPTIONS[action]; the parameter is kept
    for interface compatibility with callers.
    """
    label = config.ACTION_DESCRIPTIONS[action]
    return f'<span class="action-badge action-{action}">{label}</span>'
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def create_action_distribution_chart(distribution: dict, chart_id: str = "action_dist"):
    """Build a donut chart of the action distribution, or None if empty.

    Returns None when the distribution is missing or all counts are zero,
    so callers can skip rendering.
    """
    if not distribution or sum(distribution.values()) == 0:
        return None

    pie = go.Pie(
        labels=list(distribution.keys()),
        values=list(distribution.values()),
        hole=0.4,
        marker=dict(colors=['#4A90E2', '#7b1fa2', '#f57c00']),
        textinfo='label+percent',
        textfont=dict(size=12),
    )
    fig = go.Figure(data=[pie])

    # Empty updatemenus/sliders keep the figure layout deterministic.
    fig.update_layout(
        title="Phân bố hành động",
        height=300,
        showlegend=True,
        margin=dict(l=20, r=20, t=40, b=20),
        updatemenus=[],
        sliders=[],
    )

    return fig
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
def create_q_values_chart(q_values: list, chart_id: str = "q_values"):
    """Build a bar chart of per-action Q-values.

    Bars are labeled via config.ACTION_DESCRIPTIONS and colored by value
    on the Viridis scale, with the numeric value printed on each bar.
    """
    action_labels = [config.ACTION_DESCRIPTIONS[i] for i in range(len(q_values))]

    bars = go.Bar(
        x=action_labels,
        y=q_values,
        marker=dict(color=q_values, colorscale='Viridis', showscale=True),
        text=[f'{v:.2f}' for v in q_values],
        textposition='auto',
    )
    fig = go.Figure(data=[bars])

    # Empty updatemenus/sliders keep the figure layout deterministic.
    fig.update_layout(
        title="Q-Values (Giá trị hành động)",
        xaxis_title="Hành động",
        yaxis_title="Q-Value",
        height=300,
        showlegend=False,
        margin=dict(l=20, r=20, t=40, b=20),
        updatemenus=[],
        sliders=[],
    )

    return fig
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
def main():
    """Main Streamlit app: sidebar (sessions/status/stats) + chat loop."""

    # Ensure all session-scoped objects exist before any widget renders.
    init_session_state()

    # Fixed LLM sampling temperature (not exposed to the user).
    FIXED_TEMPERATURE = 0.7

    # Header
    st.title(f"{config.APP_ICON} Chatbot Wisdom")
    st.markdown("### 🧠 Hệ thống Multi-Agent với Conservative Q-Learning")
    st.markdown("---")

    # Sidebar
    with st.sidebar:
        st.header(config.SIDEBAR_TITLE)

        # New chat button: resets UI state AND the chatbot's own history,
        # then starts a fresh persisted session.
        if st.button("🆕 Cuộc trò chuyện mới", use_container_width=True):
            st.session_state.messages = []
            st.session_state.action_history = []
            st.session_state.chatbot.clear_history()
            # Create new session in persistent memory
            st.session_state.memory_manager.create_new_session()
            st.success("✅ Đã tạo cuộc trò chuyện mới!")
            st.rerun()

        st.divider()

        # Session history (persisted conversations)
        st.subheader("💾 Lịch sử hội thoại")
        sessions = st.session_state.memory_manager.get_all_sessions()

        if sessions:
            st.caption(f"Tổng cộng: {len(sessions)} phiên")

            # Show only the 5 most recent sessions, each with a delete button.
            for session in sessions[:5]:
                session_id = session['session_id']
                created_at = datetime.fromisoformat(session['created_at']).strftime("%d/%m/%Y %H:%M")
                msg_count = session['message_count']

                col1, col2 = st.columns([3, 1])
                with col1:
                    st.caption(f"📝 {created_at} ({msg_count} tin nhắn)")
                with col2:
                    if st.button("🗑️", key=f"del_{session_id}"):
                        st.session_state.memory_manager.delete_session(session_id)
                        st.rerun()
        else:
            st.info("Chưa có lịch sử")

        st.divider()

        # Agent status (static indicators — not live health checks)
        st.subheader("🤖 Trạng thái Agent")
        st.success("✅ CQL Agent: Hoạt động")
        st.success("✅ Communication Agent: Hoạt động")
        st.success("✅ Drawing Agent: Hoạt động")

        st.divider()

        # Statistics
        st.subheader("📊 Thống kê")
        st.metric("Số tin nhắn", len(st.session_state.messages))

        # Action distribution pie chart; key changes with history length
        # so Streamlit re-renders the chart after each new action.
        if st.session_state.action_history:
            distribution = st.session_state.chatbot.get_action_distribution()
            fig = create_action_distribution_chart(distribution, chart_id=f"dist_{len(st.session_state.action_history)}")
            if fig:
                st.plotly_chart(fig, use_container_width=True, key=f"action_dist_{len(st.session_state.action_history)}")

        st.divider()

        # Info
        st.subheader("ℹ️ Thông tin")
        st.info(f"""
        **Model**: CQL Agent
        **State Dim**: {config.STATE_DIM}
        **Actions**: {config.ACTION_DIM}
        **Device**: {config.DEVICE}
        """)

    # Main chat area
    chat_container = st.container()

    # Replay the stored conversation on every rerun.
    with chat_container:
        for idx, message in enumerate(st.session_state.messages):
            with st.chat_message(message["role"]):
                st.markdown(message["content"])

                # Display action badge for assistant messages
                if message["role"] == "assistant" and "action" in message:
                    action = message["action"]
                    action_name = message["action_name"]
                    st.markdown(
                        display_action_badge(action, action_name),
                        unsafe_allow_html=True
                    )

                # Display Q-values chart (idx keeps widget keys unique)
                if "q_values" in message:
                    with st.expander("📊 Xem Q-Values"):
                        fig = create_q_values_chart(message["q_values"], chart_id=f"qval_{idx}")
                        st.plotly_chart(fig, use_container_width=True, key=f"q_values_{idx}")

                # Display image if available
                if message.get("image_path"):
                    st.image(message["image_path"])

    # Chat input — runs once per submitted prompt.
    if prompt := st.chat_input("💬 Nhập tin nhắn của bạn..."):
        # Add user message
        st.session_state.messages.append({"role": "user", "content": prompt})

        # Display user message
        with st.chat_message("user"):
            st.markdown(prompt)

        # Generate response with the fixed temperature
        with st.chat_message("assistant"):
            with st.spinner("🤔 Đang suy nghĩ..."):
                response_data = st.session_state.chatbot.chat(prompt, FIXED_TEMPERATURE)

            # Display response
            st.markdown(response_data['response'])

            # Display action badge
            st.markdown(
                display_action_badge(response_data['action'], response_data['action_name']),
                unsafe_allow_html=True
            )

            # Display Q-values
            with st.expander("📊 Xem Q-Values"):
                msg_idx = len(st.session_state.messages)
                fig = create_q_values_chart(response_data['q_values'], chart_id=f"qval_new_{msg_idx}")
                st.plotly_chart(fig, use_container_width=True, key=f"q_values_new_{msg_idx}")

            # Display image if available
            if response_data.get('image_path'):
                st.image(response_data['image_path'])

        # Add assistant message to history
        st.session_state.messages.append({
            "role": "assistant",
            "content": response_data['response'],
            "action": response_data['action'],
            "action_name": response_data['action_name'],
            "q_values": response_data['q_values'],
            "image_path": response_data.get('image_path')
        })

        # Update action history
        st.session_state.action_history.append(response_data['action'])

        # Rerun so the message just appended renders via the replay loop.
        st.rerun()
|
| 429 |
+
|
| 430 |
+
|
| 431 |
+
# Script entry point (Streamlit executes the module top-to-bottom).
if __name__ == "__main__":
    main()
|
chatbot_engine.py
ADDED
|
@@ -0,0 +1,238 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
CQL Chatbot Engine - Main chatbot logic integrating CQL agent with multi-agent system
|
| 3 |
+
"""
|
| 4 |
+
import sys
|
| 5 |
+
import os
|
| 6 |
+
sys.path.append(os.path.join(os.path.dirname(__file__), 'Conservative Q-learning'))
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import numpy as np
|
| 10 |
+
from transformers import T5Tokenizer, T5EncoderModel
|
| 11 |
+
from typing import Dict, List, Tuple
|
| 12 |
+
import config
|
| 13 |
+
from memory_manager import MemoryManager
|
| 14 |
+
from cql_agent import CQLAgent
|
| 15 |
+
from communication_agent import CommunicationAgent
|
| 16 |
+
from drawing_agent import DrawingAgent
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class CQLChatbot:
|
| 20 |
+
def __init__(self, model_path: str = None, memory_manager: MemoryManager = None):
|
| 21 |
+
"""
|
| 22 |
+
Initialize CQL Chatbot with all components
|
| 23 |
+
|
| 24 |
+
Args:
|
| 25 |
+
model_path: Path to saved CQL model
|
| 26 |
+
memory_manager: Memory manager instance for conversation storage
|
| 27 |
+
"""
|
| 28 |
+
print("🚀 Initializing CQL Chatbot System...")
|
| 29 |
+
|
| 30 |
+
# Set device
|
| 31 |
+
self.device = torch.device(config.DEVICE if torch.cuda.is_available() else 'cpu')
|
| 32 |
+
print(f"📱 Using device: {self.device}")
|
| 33 |
+
|
| 34 |
+
# Load T5 Encoder for text embedding
|
| 35 |
+
print("📚 Loading T5 encoder...")
|
| 36 |
+
self.tokenizer = T5Tokenizer.from_pretrained(config.T5_MODEL_NAME)
|
| 37 |
+
self.encoder = T5EncoderModel.from_pretrained(config.T5_MODEL_NAME).to(self.device)
|
| 38 |
+
self.encoder.eval() # Set to evaluation mode
|
| 39 |
+
print("✅ T5 encoder loaded")
|
| 40 |
+
|
| 41 |
+
# Load CQL Agent (Decision Maker)
|
| 42 |
+
print("🧠 Loading CQL agent...")
|
| 43 |
+
self.cql_agent = CQLAgent(
|
| 44 |
+
state_dim=config.STATE_DIM,
|
| 45 |
+
action_dim=config.ACTION_DIM,
|
| 46 |
+
is_continuous=False,
|
| 47 |
+
device=self.device
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
# Load trained model
|
| 51 |
+
model_path = model_path or str(config.MODEL_PATH)
|
| 52 |
+
self.cql_agent.load_model(model_path)
|
| 53 |
+
print("✅ CQL agent loaded")
|
| 54 |
+
|
| 55 |
+
# Initialize sub-agents
|
| 56 |
+
print("👥 Initializing sub-agents...")
|
| 57 |
+
self.communication_agent = CommunicationAgent()
|
| 58 |
+
self.drawing_agent = DrawingAgent()
|
| 59 |
+
print("✅ All agents initialized")
|
| 60 |
+
|
| 61 |
+
# Memory manager
|
| 62 |
+
self.memory_manager = memory_manager
|
| 63 |
+
if self.memory_manager:
|
| 64 |
+
print("💾 Memory manager enabled")
|
| 65 |
+
|
| 66 |
+
# Conversation history
|
| 67 |
+
self.conversation_history = []
|
| 68 |
+
|
| 69 |
+
print("🎉 CQL Chatbot System ready!\n")
|
| 70 |
+
|
| 71 |
+
def encode_text(self, text: str) -> np.ndarray:
|
| 72 |
+
"""
|
| 73 |
+
Encode text into T5 embedding
|
| 74 |
+
|
| 75 |
+
Args:
|
| 76 |
+
text: Input text
|
| 77 |
+
|
| 78 |
+
Returns:
|
| 79 |
+
Embedding vector (768-dim)
|
| 80 |
+
"""
|
| 81 |
+
inputs = self.tokenizer(
|
| 82 |
+
text,
|
| 83 |
+
return_tensors="pt",
|
| 84 |
+
padding=True,
|
| 85 |
+
truncation=True,
|
| 86 |
+
max_length=512
|
| 87 |
+
).to(self.device)
|
| 88 |
+
|
| 89 |
+
with torch.no_grad():
|
| 90 |
+
outputs = self.encoder(**inputs)
|
| 91 |
+
# Use mean pooling over sequence
|
| 92 |
+
embedding = outputs.last_hidden_state.mean(dim=1).cpu().numpy().flatten()
|
| 93 |
+
|
| 94 |
+
return embedding
|
| 95 |
+
|
| 96 |
+
def get_action(self, text: str) -> Tuple[int, np.ndarray]:
|
| 97 |
+
"""
|
| 98 |
+
Get CQL agent's decision for the input text
|
| 99 |
+
|
| 100 |
+
Args:
|
| 101 |
+
text: User input text
|
| 102 |
+
|
| 103 |
+
Returns:
|
| 104 |
+
Tuple of (action_index, q_values)
|
| 105 |
+
"""
|
| 106 |
+
# Encode text to embedding
|
| 107 |
+
embedding = self.encode_text(text)
|
| 108 |
+
|
| 109 |
+
# Get action from CQL agent
|
| 110 |
+
action = self.cql_agent.select_action(embedding, evaluate=True)
|
| 111 |
+
|
| 112 |
+
# Get Q-values for all actions (for visualization)
|
| 113 |
+
state = self.cql_agent.normalizer.normalize(embedding)
|
| 114 |
+
state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
|
| 115 |
+
|
| 116 |
+
with torch.no_grad():
|
| 117 |
+
q_values = self.cql_agent.critic_1(state_tensor).cpu().numpy().flatten()
|
| 118 |
+
|
| 119 |
+
return action, q_values
|
| 120 |
+
|
| 121 |
+
def chat(self, user_message: str, temperature: float = 0.7) -> Dict:
|
| 122 |
+
"""
|
| 123 |
+
Main chat function - processes user message and generates response
|
| 124 |
+
|
| 125 |
+
Args:
|
| 126 |
+
user_message: User's input message
|
| 127 |
+
temperature: Response creativity
|
| 128 |
+
|
| 129 |
+
Returns:
|
| 130 |
+
Dictionary containing response and metadata
|
| 131 |
+
"""
|
| 132 |
+
# Get CQL agent's decision
|
| 133 |
+
action, q_values = self.get_action(user_message)
|
| 134 |
+
action_name = config.ACTION_MAPPING[action]
|
| 135 |
+
|
| 136 |
+
print(f"\n🤖 CQL Decision: {action_name} (Action {action})")
|
| 137 |
+
print(f"📊 Q-values: {q_values}")
|
| 138 |
+
|
| 139 |
+
# Initialize variables
|
| 140 |
+
response_text = ""
|
| 141 |
+
image_path = None
|
| 142 |
+
|
| 143 |
+
# IMPROVED LOGIC: Check for drawing keywords FIRST
|
| 144 |
+
# Override CQL decision if drawing keywords detected
|
| 145 |
+
drawing_keywords = ['vẽ', 'sketch', 'phác thảo', 'hình', 'ảnh', 'tranh', 'draw', 'paint', 'create image', 'generate']
|
| 146 |
+
is_drawing_request = any(keyword in user_message.lower() for keyword in drawing_keywords)
|
| 147 |
+
|
| 148 |
+
# Force Drawing Agent if keywords detected
|
| 149 |
+
if is_drawing_request:
|
| 150 |
+
print("🎨 Drawing keywords detected! Forcing Drawing Agent.")
|
| 151 |
+
action = 1
|
| 152 |
+
action_name = config.ACTION_MAPPING[1]
|
| 153 |
+
|
| 154 |
+
# Execute based on final action
|
| 155 |
+
if action == 0: # Communication Agent
|
| 156 |
+
response_text = self.communication_agent.generate_response(
|
| 157 |
+
user_message,
|
| 158 |
+
self.conversation_history,
|
| 159 |
+
temperature
|
| 160 |
+
)
|
| 161 |
+
|
| 162 |
+
elif action == 1: # Drawing Agent
|
| 163 |
+
response_text, image_path = self.drawing_agent.generate_sketch(user_message)
|
| 164 |
+
|
| 165 |
+
elif action == 2: # Clarification - fallback to Communication
|
| 166 |
+
print("⚠️ CQL suggested Clarification. Using Communication Agent.")
|
| 167 |
+
response_text = self.communication_agent.generate_response(
|
| 168 |
+
user_message,
|
| 169 |
+
self.conversation_history,
|
| 170 |
+
temperature
|
| 171 |
+
)
|
| 172 |
+
action = 0
|
| 173 |
+
action_name = config.ACTION_MAPPING[0]
|
| 174 |
+
|
| 175 |
+
# Update conversation history
|
| 176 |
+
self.conversation_history.append({
|
| 177 |
+
'role': 'user',
|
| 178 |
+
'content': user_message
|
| 179 |
+
})
|
| 180 |
+
self.conversation_history.append({
|
| 181 |
+
'role': 'assistant',
|
| 182 |
+
'content': response_text,
|
| 183 |
+
'action': action,
|
| 184 |
+
'action_name': action_name
|
| 185 |
+
})
|
| 186 |
+
|
| 187 |
+
# Limit history length
|
| 188 |
+
if len(self.conversation_history) > config.MAX_HISTORY_LENGTH:
|
| 189 |
+
self.conversation_history = self.conversation_history[-config.MAX_HISTORY_LENGTH:]
|
| 190 |
+
|
| 191 |
+
# Save to memory manager if available
|
| 192 |
+
if self.memory_manager:
|
| 193 |
+
self.memory_manager.save_message('user', user_message)
|
| 194 |
+
self.memory_manager.save_message(
|
| 195 |
+
'assistant',
|
| 196 |
+
response_text,
|
| 197 |
+
{
|
| 198 |
+
'action': action,
|
| 199 |
+
'action_name': action_name,
|
| 200 |
+
'q_values': q_values.tolist()
|
| 201 |
+
}
|
| 202 |
+
)
|
| 203 |
+
|
| 204 |
+
return {
|
| 205 |
+
'response': response_text,
|
| 206 |
+
'action': action,
|
| 207 |
+
'action_name': action_name,
|
| 208 |
+
'q_values': q_values.tolist(),
|
| 209 |
+
'image_path': image_path
|
| 210 |
+
}
|
| 211 |
+
|
| 212 |
+
def _generate_clarification_request(self, user_message: str) -> str:
|
| 213 |
+
"""Generate a clarification request when input is unclear"""
|
| 214 |
+
clarifications = [
|
| 215 |
+
f"Xin lỗi, tôi chưa hiểu rõ yêu cầu của bạn: '{user_message}'. Bạn có thể nói rõ hơn được không?",
|
| 216 |
+
f"Tôi cần thêm thông tin để hiểu câu hỏi của bạn. Bạn muốn tôi làm gì với: '{user_message}'?",
|
| 217 |
+
f"Câu hỏi của bạn chưa rõ ràng. Bạn có thể diễn đạt lại không?",
|
| 218 |
+
f"Hmm, tôi không chắc bạn đang hỏi gì. Bạn có thể cung cấp thêm chi tiết không?"
|
| 219 |
+
]
|
| 220 |
+
|
| 221 |
+
import random
|
| 222 |
+
return random.choice(clarifications)
|
| 223 |
+
|
| 224 |
+
    def clear_history(self):
        """Clear conversation history.

        Only the in-memory list of turns is reset; nothing on disk is
        touched by this method.
        """
        self.conversation_history = []
        print("🗑️ Conversation history cleared")
|
| 228 |
+
|
| 229 |
+
def get_action_distribution(self) -> Dict[str, int]:
|
| 230 |
+
"""Get distribution of actions taken in current conversation"""
|
| 231 |
+
distribution = {name: 0 for name in config.ACTION_MAPPING.values()}
|
| 232 |
+
|
| 233 |
+
for msg in self.conversation_history:
|
| 234 |
+
if msg.get('role') == 'assistant' and 'action_name' in msg:
|
| 235 |
+
action_name = msg['action_name']
|
| 236 |
+
distribution[action_name] = distribution.get(action_name, 0) + 1
|
| 237 |
+
|
| 238 |
+
return distribution
|
communication_agent.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Communication Agent - Gemini API Version (Stable)
|
| 3 |
+
Backup GPT-2 version available in communication_agent_gpt2.py
|
| 4 |
+
"""
|
| 5 |
+
import google.generativeai as genai
|
| 6 |
+
from typing import List, Dict
|
| 7 |
+
import os
|
| 8 |
+
import config
|
| 9 |
+
|
| 10 |
+
class CommunicationAgent:
    """Conversational agent backed by the Google Gemini API.

    Falls back to a disabled mode (returning a configuration warning)
    when no API key is available.
    """

    def __init__(self, api_key: str = None):
        """Configure the Gemini client.

        Args:
            api_key: Explicit API key; falls back to the GEMINI_API_KEY
                environment variable when omitted.
        """
        self.api_key = api_key or os.getenv("GEMINI_API_KEY", "")
        self.model = None
        self.enabled = False

        if self.api_key:
            genai.configure(api_key=self.api_key)
            self.model = genai.GenerativeModel(config.GEMINI_MODEL)
            self.enabled = True
            print("✅ Communication Agent (Gemini) ready!")
        else:
            print("⚠️ Warning: GEMINI_API_KEY not set. Please add it to .env file")

        # System context - clear instructions that are prepended to every prompt.
        self.system_context = " ".join([
            "You are a helpful, friendly AI assistant.",
            "Respond naturally and conversationally.",
            "Keep responses concise (2-3 sentences).",
            "Be warm and engaging.",
            "If you don't understand, ask for clarification politely.",
        ])

    def generate_response(
        self,
        user_message: str,
        conversation_history: List[Dict] = None,
        temperature: float = 0.7
    ) -> str:
        """Produce a conversational reply via Gemini.

        Args:
            user_message: User's input message.
            conversation_history: Previous turns used as context.
            temperature: Sampling temperature (0.0-1.0).

        Returns:
            The generated reply, or an error/warning string on failure.
        """
        if not self.enabled:
            return "⚠️ Gemini API not configured. Please add GEMINI_API_KEY to .env file."

        try:
            transcript = self._build_context(conversation_history)

            prompt = f"""{self.system_context}

{transcript}

User: {user_message}
Assistant:"""

            reply = self.model.generate_content(
                prompt,
                generation_config={
                    'temperature': temperature,
                    'top_p': 0.95,
                    'top_k': 40,
                    'max_output_tokens': 200,
                },
            )
            return reply.text.strip()

        except Exception as e:
            # Errors are surfaced to the user as text rather than raised.
            return f"Sorry, an error occurred: {str(e)}"

    def _build_context(self, conversation_history: List[Dict] = None) -> str:
        """Render up to the last five turns as a plain-text transcript."""
        if not conversation_history:
            return ""

        lines = [
            f"{'User' if turn.get('role') == 'user' else 'Assistant'}: {turn.get('content', '')}\n"
            for turn in conversation_history[-5:]
        ]
        return "Previous conversation:\n" + "".join(lines)
|
config.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Configuration file for CQL Chatbot System
|
| 3 |
+
"""
|
| 4 |
+
import os
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from dotenv import load_dotenv
|
| 7 |
+
|
| 8 |
+
# Load environment variables
|
| 9 |
+
load_dotenv()
|
| 10 |
+
|
| 11 |
+
# Project paths
|
| 12 |
+
PROJECT_ROOT = Path(__file__).parent
|
| 13 |
+
MODEL_PATH = PROJECT_ROOT / "Conservative Q-learning" / "saved_agent_1"
|
| 14 |
+
|
| 15 |
+
# Model configuration
|
| 16 |
+
STATE_DIM = 768 # T5-base embedding dimension
|
| 17 |
+
ACTION_DIM = 3 # 0: Chat, 1: Sketch, 2: Clarify
|
| 18 |
+
DEVICE = os.getenv("DEVICE", "cuda")
|
| 19 |
+
|
| 20 |
+
# Action mapping
|
| 21 |
+
ACTION_MAPPING = {
|
| 22 |
+
0: "Communication Agent",
|
| 23 |
+
1: "Drawing Agent",
|
| 24 |
+
2: "Clarification"
|
| 25 |
+
}
|
| 26 |
+
|
| 27 |
+
ACTION_DESCRIPTIONS = {
|
| 28 |
+
0: "💬 Trò chuyện thông thường",
|
| 29 |
+
1: "🎨 Vẽ sketch/hình ảnh",
|
| 30 |
+
2: "❓ Cần làm rõ thêm"
|
| 31 |
+
}
|
| 32 |
+
|
| 33 |
+
# T5 Model
|
| 34 |
+
T5_MODEL_NAME = "t5-base"
|
| 35 |
+
|
| 36 |
+
# Gemini API (Primary - Stable)
|
| 37 |
+
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY", "")
|
| 38 |
+
GEMINI_MODEL = "gemini-2.5-flash"
|
| 39 |
+
|
| 40 |
+
# GPT-2 Model (Backup - Local, available in communication_agent_gpt2.py)
|
| 41 |
+
GPT2_MODEL_NAME = "gpt2" # Can use "gpt2-medium" or "gpt2-large" for better quality
|
| 42 |
+
|
| 43 |
+
# Streamlit UI Configuration
|
| 44 |
+
APP_TITLE = "🤖 Chatbot Wisdom - CQL Multi-Agent System"
|
| 45 |
+
APP_ICON = "🧠"
|
| 46 |
+
SIDEBAR_TITLE = "⚙️ Cài đặt"
|
| 47 |
+
|
| 48 |
+
# Chat settings
|
| 49 |
+
MAX_HISTORY_LENGTH = 50
|
| 50 |
+
DEFAULT_TEMPERATURE = 0.7
|
| 51 |
+
DEFAULT_MAX_TOKENS = 512
|
drawing_agent.py
ADDED
|
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Drawing Agent - Integrated with Stable Diffusion Model
|
| 3 |
+
Uses pre-trained Diffusion model from Agent_Diffusion folder
|
| 4 |
+
"""
|
| 5 |
+
import sys
|
| 6 |
+
import os
|
| 7 |
+
sys.path.append(os.path.join(os.path.dirname(__file__), 'Agent_Diffusion'))
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
from typing import Optional, Tuple
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
import random
|
| 13 |
+
from PIL import Image
|
| 14 |
+
from diffusers import StableDiffusionPipeline
|
| 15 |
+
from stable_diffusion import MiniDiffusionPipeline
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class DrawingAgent:
    """Generates images from text prompts with a fine-tuned mini Stable Diffusion model.

    The heavy diffusion pipeline is loaded lazily on first use; when loading
    fails, the agent degrades to a text-only fallback response.
    """

    def __init__(self):
        """Initialize Drawing Agent (the diffusion model loads on first request)."""
        self.enabled = True
        self.output_dir = Path("generated_images")
        self.output_dir.mkdir(exist_ok=True)

        # Model paths
        self.base_model_id = "runwayml/stable-diffusion-v1-5"
        self.unet_path = Path("Agent_Diffusion/unet-mini.safetensors")
        self.vae_path = Path("Agent_Diffusion/vae-finetuned.safetensors")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Architecture overrides that shrink the UNet to the "mini" variant
        self.tiny_unet_config = {
            "unet_block_out_channels": (128, 256, 512),
        }

        # Lazy loading - only load when first needed
        self.pipeline = None
        self.model_loaded = False

        print("✅ Drawing Agent initialized (Diffusion model will load on first use)")

    def _load_model(self):
        """Load the diffusion pipeline once (lazy loading).

        On failure the agent is disabled and callers fall back to
        :meth:`_fallback_response`.
        """
        if self.model_loaded:
            return

        try:
            print("🎨 Loading Diffusion model...")

            # Build the mini UNet/VAE/text-encoder container
            container = MiniDiffusionPipeline(
                base_model_id=self.base_model_id,
                device=self.device,
                config_overrides=self.tiny_unet_config
            )

            # Load UNet weights
            # NOTE(review): these checkpoints are .safetensors files, but
            # torch.load expects pickle-format checkpoints and may fail here.
            # If it does, switch to safetensors.torch.load_file — confirm
            # against the training scripts in Agent_Diffusion.
            print(f"Loading UNet from {self.unet_path}...")
            unet_weights = torch.load(str(self.unet_path), map_location=self.device)
            container.unet.load_state_dict(unet_weights)

            # Load VAE weights
            print(f"Loading VAE from {self.vae_path}...")
            vae_weights = torch.load(str(self.vae_path), map_location=self.device)
            container.vae.load_state_dict(vae_weights)

            # Assemble an inference pipeline from the container's components
            torch_dtype = torch.float16 if self.device == "cuda" else torch.float32

            self.pipeline = StableDiffusionPipeline(
                unet=container.unet,
                vae=container.vae,
                text_encoder=container.text_encoder,
                tokenizer=container.tokenizer,
                scheduler=container.noise_scheduler,
                safety_checker=None,
                feature_extractor=None,
            ).to(self.device)

            if self.device == "cuda":
                # Half precision only on GPU; fp16 on CPU is unsupported/slow
                self.pipeline.to(dtype=torch_dtype)

            self.pipeline.set_progress_bar_config(disable=True)
            self.model_loaded = True

            print("✅ Diffusion model loaded successfully!")

        except Exception as e:
            print(f"❌ Error loading Diffusion model: {e}")
            self.enabled = False

    def generate_sketch(self, prompt: str) -> Tuple[str, Optional[str]]:
        """
        Generate a sketch based on the prompt using the diffusion model.

        Args:
            prompt: Description of what to draw

        Returns:
            Tuple of (response_text, image_path); image_path is None on failure.
        """
        try:
            # Load model if not loaded
            if not self.model_loaded:
                self._load_model()

            if not self.enabled or self.pipeline is None:
                return self._fallback_response(prompt)

            # Strip filler words so only the subject reaches the model
            clean_prompt = self.parse_sketch_request(prompt)

            print(f"🎨 Generating image for: '{clean_prompt}'")

            # Fresh random seed per request so repeated prompts vary
            current_seed = random.randint(0, 2**32 - 1)
            generator = torch.Generator(device=self.device).manual_seed(current_seed)

            with torch.no_grad():
                image = self.pipeline(
                    prompt=clean_prompt,
                    num_inference_steps=50,
                    generator=generator,
                    guidance_scale=7.5
                ).images[0]

            # Save under a seed-based name; the previous 4-digit random name
            # was misleadingly called "timestamp" and collided easily.
            filename = f"sketch_{current_seed}_{random.randint(1000, 9999)}.png"
            image_path = self.output_dir / filename
            image.save(str(image_path))

            print(f"✅ Image saved to: {image_path}")

            # Real newlines here; the previous literal "\\n" rendered as the
            # characters backslash-n in the chat UI.
            response_text = (
                "🎨 **Image Generated!**\n\n"
                f"Prompt: **{clean_prompt}**\n"
                f"Seed: {current_seed}\n\n"
                "Here's your generated image:"
            )

            return response_text, str(image_path)

        except Exception as e:
            error_msg = f"Sorry, I encountered an error while generating the image: {str(e)}"
            print(f"❌ Drawing error: {e}")
            return error_msg, None

    def _fallback_response(self, prompt: str) -> Tuple[str, None]:
        """Text-only fallback used when the diffusion model cannot load."""
        clean_prompt = self.parse_sketch_request(prompt)

        response_text = (
            "🎨 **Drawing Request Received**\n\n"
            f"I understand you want me to draw: **{clean_prompt}**\n\n"
            "⚠️ **Note**: Diffusion model failed to load. "
            "Please check that the model files exist in Agent_Diffusion folder."
        )

        return response_text, None

    def parse_sketch_request(self, user_message: str) -> str:
        """
        Extract the core drawing subject from a user request.

        Drawing/filler keywords are removed as whole words only; the previous
        substring-based str.replace mangled ordinary words (stripping every
        'a' turned "cat" into "ct").

        Args:
            user_message: User's request message

        Returns:
            Cleaned prompt for image generation
        """
        import re  # local import keeps the module's import surface unchanged

        drawing_keywords = [
            'vẽ', 'sketch', 'phác thảo', 'hình', 'ảnh', 'tranh',
            'draw', 'paint', 'create', 'make', 'generate',
            'cho tôi', 'cho mình', 'giúp tôi', 'help me', 'for me',
            'một', 'a', 'an', 'the'
        ]

        prompt = user_message.lower()

        # Remove keywords with word boundaries so embedded letters survive
        for keyword in drawing_keywords:
            prompt = re.sub(r'\b' + re.escape(keyword) + r'\b', ' ', prompt)

        # Collapse whitespace
        prompt = ' '.join(prompt.split()).strip()

        # If stripping removed everything useful, fall back to the original
        if not prompt or len(prompt) < 3:
            prompt = user_message

        return prompt

    def is_drawing_request(self, user_message: str) -> bool:
        """
        Check if the message is a drawing request.

        Args:
            user_message: User's message

        Returns:
            True if it's a drawing request
        """
        drawing_keywords = [
            'vẽ', 'sketch', 'phác thảo', 'draw', 'paint',
            'create image', 'generate image', 'make picture'
        ]

        message_lower = user_message.lower()
        return any(keyword in message_lower for keyword in drawing_keywords)
|
memory_manager.py
ADDED
|
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Memory Manager - Lưu trữ và quản lý lịch sử hội thoại
|
| 3 |
+
"""
|
| 4 |
+
import json
|
| 5 |
+
import os
|
| 6 |
+
from datetime import datetime
|
| 7 |
+
from typing import List, Dict, Optional
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class MemoryManager:
    """Stores and manages conversation history as JSON session files on disk."""

    def __init__(self, storage_dir: str = "conversation_history"):
        """
        Initialize Memory Manager.

        Args:
            storage_dir: Directory where conversation history files are stored
        """
        self.storage_dir = Path(storage_dir)
        self.storage_dir.mkdir(exist_ok=True)
        # Path of the active session file; created lazily on first save
        self.current_session_file = None

    def create_new_session(self) -> str:
        """
        Create a new session and make it the current one.

        Returns:
            Session ID (timestamp-based, e.g. "20240101_120000")
        """
        session_id = datetime.now().strftime("%Y%m%d_%H%M%S")
        self.current_session_file = self.storage_dir / f"session_{session_id}.json"

        # Create the new session file on disk immediately
        session_data = {
            "session_id": session_id,
            "created_at": datetime.now().isoformat(),
            "messages": []
        }

        self._save_session(session_data)
        return session_id

    def save_message(self, role: str, content: str, metadata: Dict = None):
        """
        Append a message to the current session (created on demand).

        Args:
            role: 'user' or 'assistant'
            content: Message content
            metadata: Extra fields merged into the message (action, q_values, ...)
        """
        if not self.current_session_file:
            self.create_new_session()

        session_data = self._load_session()

        message = {
            "role": role,
            "content": content,
            "timestamp": datetime.now().isoformat()
        }

        if metadata:
            # NOTE(review): metadata keys can overwrite role/content/timestamp
            message.update(metadata)

        session_data["messages"].append(message)
        self._save_session(session_data)

    def load_session(self, session_id: str) -> Optional[Dict]:
        """
        Load a session by ID.

        Args:
            session_id: ID of the session to load

        Returns:
            Session data, or None if not found
        """
        session_file = self.storage_dir / f"session_{session_id}.json"

        if not session_file.exists():
            return None

        with open(session_file, 'r', encoding='utf-8') as f:
            return json.load(f)

    def get_all_sessions(self) -> List[Dict]:
        """
        List all stored sessions.

        Returns:
            List of session summaries (session_id, created_at, message_count, file)
        """
        sessions = []

        for file in self.storage_dir.glob("session_*.json"):
            try:
                with open(file, 'r', encoding='utf-8') as f:
                    data = json.load(f)
                    sessions.append({
                        "session_id": data["session_id"],
                        "created_at": data["created_at"],
                        "message_count": len(data["messages"]),
                        "file": str(file)
                    })
            except Exception as e:
                # Skip unreadable/corrupt session files but keep going
                print(f"Error loading session {file}: {e}")

        # Sort by creation time (newest first)
        sessions.sort(key=lambda x: x["created_at"], reverse=True)
        return sessions

    def delete_session(self, session_id: str) -> bool:
        """
        Delete a session file.

        Args:
            session_id: ID of the session to delete

        Returns:
            True if the file existed and was deleted
        """
        session_file = self.storage_dir / f"session_{session_id}.json"

        if session_file.exists():
            session_file.unlink()
            return True

        return False

    def _load_session(self) -> Dict:
        """Load the current session; return an empty default if no file exists yet."""
        if not self.current_session_file or not self.current_session_file.exists():
            return {
                "session_id": "default",
                "created_at": datetime.now().isoformat(),
                "messages": []
            }

        with open(self.current_session_file, 'r', encoding='utf-8') as f:
            return json.load(f)

    def _save_session(self, session_data: Dict):
        """Write session data to the current session file (no-op when none is set)."""
        if not self.current_session_file:
            return

        with open(self.current_session_file, 'w', encoding='utf-8') as f:
            json.dump(session_data, f, ensure_ascii=False, indent=2)

    def get_current_messages(self) -> List[Dict]:
        """Return all messages of the current session."""
        session_data = self._load_session()
        return session_data.get("messages", [])

    def export_session(self, session_id: str, output_file: str):
        """
        Export a session to an external file.

        Args:
            session_id: Session ID
            output_file: Output file path

        Returns:
            True on success, False if the session does not exist
        """
        session_data = self.load_session(session_id)

        if not session_data:
            return False

        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(session_data, f, ensure_ascii=False, indent=2)

        return True
|
requirements.txt
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
streamlit>=1.31.0
|
| 2 |
+
torch>=2.0.0
|
| 3 |
+
transformers>=4.36.0
|
| 4 |
+
numpy>=1.24.0
|
| 5 |
+
pillow>=10.0.0
|
| 6 |
+
google-generativeai>=0.3.0
|
| 7 |
+
python-dotenv>=1.0.0
|
| 8 |
+
plotly>=5.18.0
|
| 9 |
+
sentencepiece>=0.1.99
|
| 10 |
+
diffusers>=0.25.0
|
| 11 |
+
accelerate>=0.25.0
|
| 12 |
+
safetensors>=0.4.0
|