NguyenThanh1405 committed on
Commit
4cfe4fa
·
1 Parent(s): 2eebc40

Deploy CQL Chatbot (without large files)

Browse files
.gitignore ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python cache
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+ *.so
6
+ .Python
7
+ *.egg-info/
8
+ dist/
9
+ build/
10
+
11
+ # Virtual Environment
12
+ .venv/
13
+ venv/
14
+ ENV/
15
+
16
+ # IDE
17
+ .vscode/
18
+ .idea/
19
+ *.swp
20
+ *.swo
21
+
22
+ # Environment variables (IMPORTANT: Don't upload API keys!)
23
+ .env
24
+
25
+ # Conversation history (user data)
26
+ conversation_history/
27
+
28
+ # Generated images (will be created on server)
29
+ generated_images/
30
+
31
+ # Model cache
32
+ .cache/
33
+
34
+ # Logs
35
+ *.log
36
+
37
+ # OS
38
+ .DS_Store
39
+ Thumbs.db
40
+
41
+ # Large model files (upload separately to HF)
42
+ # Uncomment if models are too large for git
43
+ # Agent_Diffusion/*.safetensors
44
+ # Conservative Q-learning/saved_agent_1/*.pth
45
+ # Conservative Q-learning/saved_agent_1/*.pkl
46
+
47
+ # Backup files
48
+ *_backup.py
49
+ *_gpt2.py
50
+ communication_agent_gemini_backup.py
51
+
52
+ # Antigravity artifacts
53
+ .gemini/
Agent_Diffusion/inference_v3.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import torch
3
+ from transformers import Swin2SRImageProcessor, Swin2SRForImageSuperResolution
4
+ from diffusers import StableDiffusionPipeline
5
+ import numpy as np
6
+ from PIL import Image, ImageEnhance, ImageOps
7
+ import random
8
+ # from safetensors.torch import load_file
9
+
10
+ from stable_diffusion import MiniDiffusionPipeline
11
+
12
+ # --- Cấu hình ---
13
+ #PROMPT = "beautiful woman with long braided hair, wearing a scarf, soft smile, looking down, detailed shading" #725562173
14
+ #PROMPT = "attractive woman, big lips, mouth slightly open, heavy makeup" #v5
15
+ #PROMPT = "The man is young and has sharp jawline, narrow eyes, thick eyebrows, and short black hair." #10, 11
16
+ #PROMPT = "She is elderly with deep smile lines, small eyes, and short curly gray hair." #13
17
+ #PROMPT = "This man is old and smiling, with gray beard and big nose"
18
+
19
+
20
+ PROMPT = "a baby"
21
+ SAVE_IMAGE_PATH = "./15.png"
22
+
23
+ UNET_SAFE_PATH = "./unet-mini.safetensors"
24
+ VAE_SAFE_PATH = "./vae-finetuned.safetensors"
25
+
26
+ BASE_MODEL_ID = "runwayml/stable-diffusion-v1-5"
27
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
28
+
29
+
30
+ TINY_UNET_CONFIG = {
31
+ "unet_block_out_channels": (128, 256, 512),
32
+ }
33
+
34
+ MODEL_ID = "caidas/swin2SR-classical-sr-x4-64"
35
+ print(f"Đang load model {MODEL_ID} từ Hugging Face...")
36
+ processor = Swin2SRImageProcessor.from_pretrained(MODEL_ID)
37
+ model = Swin2SRForImageSuperResolution.from_pretrained(MODEL_ID)
38
+ print("Load model thành công!")
39
+
40
+ model = model.to(DEVICE)
41
+
42
def upscale_image_pipeline(pil_image, contrast=1.3, sharpen=1.5, target_size=(512, 512)):
    """Upscale *pil_image* x4 with Swin2SR using a "canvas isolation" trick.

    The image is pasted into the centre of a large white canvas before
    super-resolution, so the model's border artifacts (mirrored edges) land
    on the throw-away white margin instead of on the real image.  The real
    region is then cropped back out, its last row/column is forced white to
    erase any residual seam, and the result is resized to *target_size* and
    contrast/sharpness enhanced.

    Falls back to a plain resize when the SR model is unavailable.
    """
    if model is None or processor is None:
        return pil_image.resize(target_size)

    # 1. Normalise input to an RGB numpy array.
    img_np = np.array(pil_image)
    if len(img_np.shape) == 2:
        img_np = cv2.cvtColor(img_np, cv2.COLOR_GRAY2RGB)

    h_orig, w_orig = img_np.shape[:2]

    # 2. Build the white canvas.  256 px covers the usual 128x128 inputs;
    #    for larger images grow the canvas so the paste offsets stay >= 0.
    #    (The original fixed 256 canvas silently corrupted bigger inputs:
    #    negative offsets produced wrong slices.)
    canvas_size = max(256, 2 * max(h_orig, w_orig))
    canvas = np.ones((canvas_size, canvas_size, 3), dtype=np.uint8) * 255

    # Centre the image on the canvas.
    y_offset = (canvas_size - h_orig) // 2
    x_offset = (canvas_size - w_orig) // 2
    canvas[y_offset:y_offset + h_orig, x_offset:x_offset + w_orig] = img_np

    # 3. Run x4 super-resolution on the whole canvas; edge artifacts now sit
    #    at the canvas border, far away from the real image.
    pil_canvas = Image.fromarray(canvas)
    inputs = processor(pil_canvas, return_tensors="pt").to(DEVICE)

    with torch.no_grad():
        outputs = model(**inputs)

    output_tensor = outputs.reconstruction.data.squeeze().float().cpu().clamp_(0, 1).numpy()
    output_tensor = np.moveaxis(output_tensor, 0, -1)  # CHW -> HWC
    output_canvas = (output_tensor * 255.0).round().astype(np.uint8)

    # 4. Crop the real image back out (all coordinates scale by 4).
    scale_factor = 4
    y_start = y_offset * scale_factor
    x_start = x_offset * scale_factor
    h_real = h_orig * scale_factor
    w_real = w_orig * scale_factor
    final_img = output_canvas[y_start:y_start + h_real, x_start:x_start + w_real]

    # 5. Hard fix: whiten the last row/column in case the model still leaves
    #    a faint seam there.  Harmless: these sketches sit on white paper.
    final_img[-1:, :, :] = 255
    final_img[:, -1:, :] = 255

    # 6. Guarantee the requested output size (cv2.resize takes (width, height)).
    if final_img.shape[:2] != target_size:
        final_img = cv2.resize(final_img, (target_size[1], target_size[0]), interpolation=cv2.INTER_LANCZOS4)

    # 7. Post-processing: contrast first, then sharpness.
    final_pil = Image.fromarray(final_img)
    enhancer = ImageEnhance.Contrast(final_pil)
    final_pil = enhancer.enhance(contrast)
    enhancer = ImageEnhance.Sharpness(final_pil)
    final_pil = enhancer.enhance(sharpen)

    return final_pil
114
+
115
@torch.no_grad()
def main():
    """Load the fine-tuned mini UNet + VAE, generate one image for PROMPT,
    upscale it with Swin2SR, and save both the upscaled and raw results."""
    print("--- Bắt đầu quá trình Inference (từ Safetensors) ---")

    # --- Build the component container ---
    print(f"Đang tải pipeline gốc từ {BASE_MODEL_ID}...")
    container = MiniDiffusionPipeline(
        base_model_id=BASE_MODEL_ID,
        device=DEVICE,
        config_overrides=TINY_UNET_CONFIG
    )

    # --- Load the trained weights ---
    # The checkpoints are plain state dicts written by torch.save(), so
    # weights_only=True is sufficient and safer (no arbitrary pickle code).
    print(f"Đang tải trọng số UNet từ {UNET_SAFE_PATH}...")
    try:
        unet_weights = torch.load(UNET_SAFE_PATH, map_location=DEVICE, weights_only=True)
        container.unet.load_state_dict(unet_weights)
    except Exception as e:
        print(f"LỖI: Không thể tải UNet state dict: {e}")
        print("Kiểm tra xem bạn đã bỏ chú thích 'config_overrides=TINY_UNET_CONFIG' chưa?")
        return

    print(f"Đang tải trọng số VAE từ {VAE_SAFE_PATH}...")
    try:
        vae_weights = torch.load(VAE_SAFE_PATH, map_location=DEVICE, weights_only=True)
        container.vae.load_state_dict(vae_weights)
    except Exception as e:
        print(f"LỖI: Không thể tải VAE state dict: {e}")
        return

    # --- Wrap our components in a standard StableDiffusionPipeline ---
    torch_dtype = torch.float16 if DEVICE.startswith("cuda") else torch.float32

    print("Đang tạo pipeline inference...")
    inference_pipeline = StableDiffusionPipeline(
        unet=container.unet,
        vae=container.vae,
        text_encoder=container.text_encoder,
        tokenizer=container.tokenizer,
        scheduler=container.noise_scheduler,
        safety_checker=None,        # research pipeline: no safety filtering
        feature_extractor=None,
    ).to(DEVICE)

    if DEVICE.startswith("cuda"):
        inference_pipeline.to(dtype=torch_dtype)  # fp16 on GPU for speed/VRAM

    inference_pipeline.set_progress_bar_config(disable=False)

    # --- Generate ---
    print(f"\nĐang tạo ảnh cho prompt: '{PROMPT}'")
    current_seed = random.randint(0, 2**32 - 1)
    print(f"Seed hiện tại: {current_seed}")

    # Seeded generator so a good result can be reproduced from the printed seed.
    generator = torch.Generator(device=DEVICE).manual_seed(current_seed)

    image = inference_pipeline(
        prompt=PROMPT,
        num_inference_steps=50,
        generator=generator,
        guidance_scale=7.5
    ).images[0]

    # Save the x4-upscaled image at SAVE_IMAGE_PATH and keep the raw pipeline
    # output next to it with an "_original" suffix.
    final_image = upscale_image_pipeline(image)
    final_image.save(SAVE_IMAGE_PATH)
    image.save(SAVE_IMAGE_PATH.replace(".png", "_original.png"))

    print(f"\n--- Hoàn thành! ---")
    print(f"Đã lưu ảnh tại: {SAVE_IMAGE_PATH}")

    try:
        image.show()  # best-effort preview; headless machines have no viewer
    except Exception:
        pass

if __name__ == "__main__":
    main()
Agent_Diffusion/stable_diffusion.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from diffusers.models import UNet2DConditionModel, AutoencoderKL
3
+ from diffusers.schedulers import DDPMScheduler
4
+ from transformers import CLIPTextModel, CLIPTokenizer
5
+ from typing import Dict, Any, Optional
6
+
7
class MiniDiffusionPipeline:
    """Container bundling a frozen CLIP text encoder, a fine-tunable VAE and
    a randomly initialised "mini" UNet, plus a DDPM noise scheduler.

    It is not a runnable diffusers pipeline itself; the training and
    inference scripts pull the components out (or wrap them in a
    ``StableDiffusionPipeline``).
    """

    # Default configuration: scheduler hyper-parameters, training knobs and
    # the mini-UNet architecture.  Individual keys can be overridden via
    # ``config_overrides`` in ``__init__``.
    DEFAULT_CONFIG: Dict[str, Any] = {
        "beta_schedule": "scaled_linear",
        "beta_start": 0.00085,
        "beta_end": 0.0120,
        "num_train_timesteps": 1000,
        "prediction_type": "epsilon",
        "variance_type": "fixed_small",
        "clip_sample": False,
        "rescale_betas_zero_snr": False,
        "timestep_spacing": "leading",
        "lr": 1e-4,
        "optimizer": "AdamW",
        "scheduler": "cosine",
        "ema_decay": 0.9999,
        "latent_scale": 0.18215,        # SD v1 latent scaling factor
        "text_embed_dim": 768,          # CLIP text-encoder hidden size
        "latent_channels": 4,
        "latent_downscale_factor": 8,   # VAE spatial downscale

        # --- Mini-UNet architecture ---
        "image_size": 128,
        "unet_block_out_channels": (256, 512, 1024),
        "unet_layers_per_block": 1,
        "unet_down_block_types": (
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "DownBlock2D",
        ),
        "unet_up_block_types": (
            "UpBlock2D",
            "CrossAttnUpBlock2D",
            "CrossAttnUpBlock2D",
        ),
        "unet_mid_block_type": "UNetMidBlock2DCrossAttn",
        "unet_attention_head_dim": 8,
    }

    def __init__(
        self,
        # Fixed default repo id: "stabilityai/stable-diffusion-v1-5" is not a
        # valid Hub repo; every caller in this project uses runwayml's repo.
        base_model_id: str = "runwayml/stable-diffusion-v1-5",
        vae_model_id: Optional[str] = None,
        device: str = "cpu",
        config_overrides: Optional[Dict[str, Any]] = None
    ):
        """Build all components.

        Args:
            base_model_id: Hub repo supplying tokenizer, text encoder and
                (by default) the VAE.
            vae_model_id: optional standalone VAE repo (e.g.
                "stabilityai/sd-vae-ft-mse"); when None, the base model's
                "vae" subfolder is used.
            device: torch device string ("cpu" / "cuda").
            config_overrides: per-key overrides merged over DEFAULT_CONFIG.
        """
        self.device = torch.device(device)

        self.config = {**self.DEFAULT_CONFIG, **(config_overrides or {})}

        print(f"Đang tải Tokenizer và Text Encoder (đã đóng băng) từ {base_model_id}...")
        self.tokenizer = self._load_tokenizer(base_model_id)
        self.text_encoder = self._load_text_encoder(base_model_id)

        # A standalone VAE repo has no "vae" subfolder; the base model does.
        _vae_id = vae_model_id or base_model_id
        _vae_subfolder = "vae" if vae_model_id is None else None
        print(f"Đang tải VAE (để fine-tune) từ {_vae_id}...")
        self.vae = self._load_vae(_vae_id, _vae_subfolder)

        print("Khởi tạo UNet-mini (với trọng số ngẫu nhiên)...")
        self.unet = self._load_mini_unet()

        print("Khởi tạo Noise Scheduler...")
        self.noise_scheduler = self._load_noise_scheduler()

        print("\n--- MiniDiffusionPipeline đã sẵn sàng! ---")
        self.print_model_stats()

    def _load_tokenizer(self, model_id: str) -> CLIPTokenizer:
        """Load the CLIP tokenizer from the base model repo."""
        return CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")

    def _load_text_encoder(self, model_id: str) -> CLIPTextModel:
        """Load the CLIP text encoder, move it to the device, and freeze it."""
        model = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")
        model.to(self.device)
        model.requires_grad_(False)  # the text encoder is never trained
        return model

    def _load_vae(self, model_id: str, subfolder: Optional[str]) -> AutoencoderKL:
        """Load the (trainable) VAE, optionally from a repo subfolder."""
        if subfolder:
            model = AutoencoderKL.from_pretrained(model_id, subfolder=subfolder)
        else:
            model = AutoencoderKL.from_pretrained(model_id)
        model.to(self.device)
        return model

    def _load_mini_unet(self) -> UNet2DConditionModel:
        """Construct the randomly initialised mini UNet from ``self.config``."""
        # Latent resolution = pixel resolution / VAE downscale (128 / 8 = 16).
        latent_size = self.config["image_size"] // self.config["latent_downscale_factor"]

        unet_config = {
            "sample_size": latent_size,
            "in_channels": self.config["latent_channels"],
            "out_channels": self.config["latent_channels"],
            "block_out_channels": self.config["unet_block_out_channels"],
            "layers_per_block": self.config["unet_layers_per_block"],
            "down_block_types": self.config["unet_down_block_types"],
            "up_block_types": self.config["unet_up_block_types"],
            "mid_block_type": self.config["unet_mid_block_type"],
            "cross_attention_dim": self.config["text_embed_dim"],
            "attention_head_dim": self.config["unet_attention_head_dim"],
        }

        model = UNet2DConditionModel(**unet_config)
        model.to(self.device)
        return model

    def _load_noise_scheduler(self) -> DDPMScheduler:
        """Build the DDPM scheduler; ``from_config`` ignores unrelated keys."""
        return DDPMScheduler.from_config(self.config)

    def print_model_stats(self):
        """Print trainable parameter counts (in millions) for UNet and VAE."""
        unet_params = sum(p.numel() for p in self.unet.parameters() if p.requires_grad)
        vae_params = sum(p.numel() for p in self.vae.parameters() if p.requires_grad)
        print(f" UNet-mini (để train): {unet_params / 1_000_000:.2f} triệu tham số")
        print(f" VAE (để fine-tune): {vae_params / 1_000_000:.2f} triệu tham số")

    def get_trainable_parameters(self) -> Dict[str, Any]:
        """Return the parameter iterators intended for optimisation."""
        return {
            "unet": self.unet.parameters(),
            "vae": self.vae.parameters()
        }
129
+
130
+
131
# --- SMOKE TEST BLOCK ---
def _run_smoke_test():
    """Exercise the three construction paths of MiniDiffusionPipeline."""
    print("--- Bắt đầu kiểm thử MiniDiffusionPipeline ---")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    if device == "cpu":
        print("CẢNH BÁO: Không tìm thấy CUDA. Chạy trên CPU (sẽ chậm).")

    # Path 1: defaults — uses the SD-1.5 VAE from the base repo.
    print("\n--- Tải mặc định ---")
    default_pipe = MiniDiffusionPipeline(
        base_model_id="runwayml/stable-diffusion-v1-5",
        device=device,
    )

    # Path 2: standalone MSE-finetuned VAE.
    print("\n--- Tải VAE-MSE tùy chỉnh ---")
    mse_pipe = MiniDiffusionPipeline(
        base_model_id="runwayml/stable-diffusion-v1-5",
        vae_model_id="stabilityai/sd-vae-ft-mse",
        device=device,
    )

    # Path 3: config overrides (extra-small UNet, lowered LR).
    print("\n--- Ghi đè config (UNet siêu nhỏ) ---")
    override_pipe = MiniDiffusionPipeline(
        base_model_id="runwayml/stable-diffusion-v1-5",
        device=device,
        config_overrides={
            "unet_block_out_channels": (128, 256, 512),
            "lr": 5e-5,
        },
    )

    print("\n--- Kiểm thử thành công ---")
    print(f"Config LR của Pipeline 1: {default_pipe.config['lr']}")
    print(f"Config LR của Pipeline 3: {override_pipe.config['lr']}")


if __name__ == "__main__":
    _run_smoke_test()
Agent_Diffusion/train_unet.py ADDED
@@ -0,0 +1,331 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from torch.utils.data import DataLoader
4
+ from torch.optim import AdamW
5
+ from diffusers.optimization import get_scheduler
6
+ from diffusers import DDPMScheduler, StableDiffusionPipeline
7
+ from PIL import Image
8
+ import os
9
+ import time
10
+ import matplotlib.pyplot as plt
11
+ from tqdm import tqdm
12
+ import random
13
+
14
+ from stable_diffusion import MiniDiffusionPipeline
15
+ from dataset import SketchDataset
16
+
17
+
18
+ from torchmetrics.image.fid import FrechetInceptionDistance
19
+ from torchmetrics.multimodal.clip_score import CLIPScore
20
+ from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
21
+
22
+
23
+ # --- Cấu hình ---
24
+ TRAIN_DATA_DIR = r"C:\Users\Admin\Desktop\scientific research\dataset\train"
25
+ VAL_DATA_DIR = r"C:\Users\Admin\Desktop\scientific research\dataset\val"
26
+ VAE_PATH = "./vae-finetuned.safetensors"
27
+ IMAGE_SIZE = 128
28
+ EPOCHS = 101
29
+ BATCH_SIZE = 16 * 5
30
+ LEARNING_RATE = 1e-4 * 5
31
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
32
+ SAVE_UNET_PATH = "./unet-mini.safetensors"
33
+
34
+ CHECKPOINT_PATH = "./unet_latest_checkpoint.pth"
35
+ NUM_INFERENCE_STEPS = 50
36
+
37
+ TINY_UNET_CONFIG = {
38
+ "unet_block_out_channels": (128, 256, 512),
39
+ }
40
+
41
def plot_metrics(history, filename="unet_metrics_plot_v1.png"):
    """Plot the training curves stored in *history* and save to *filename*.

    Expects lists under the keys 'train_loss', 'val_loss', 'fid', 'lpips'
    and 'clip_score'; draws them as a 2x2 grid (losses, FID, LPIPS, CLIP).
    """
    plt.rcParams.update({'font.size': 17})

    fig, axs = plt.subplots(2, 2, figsize=(15, 12))

    axs[0, 0].plot(history['train_loss'], label="Train Loss")
    axs[0, 0].plot(history['val_loss'], label="Validation Loss")
    axs[0, 0].set_title("Train vs Validation Loss")
    axs[0, 0].set_xlabel("Epoch")
    axs[0, 0].set_ylabel("MSE Loss")
    axs[0, 0].grid()
    axs[0, 0].legend()

    axs[0, 1].plot(history['fid'], label="FID", color='green')
    axs[0, 1].set_title("Fréchet Inception Distance (FID)")
    axs[0, 1].set_xlabel("Epoch")
    axs[0, 1].set_ylabel("FID (lower is better)")
    axs[0, 1].grid()
    axs[0, 1].legend()

    axs[1, 0].plot(history['lpips'], label="LPIPS", color='red')
    axs[1, 0].set_title("Learned Perceptual Image Patch Similarity (LPIPS)")
    axs[1, 0].set_xlabel("Epoch")
    axs[1, 0].set_ylabel("LPIPS (lower is better)")
    axs[1, 0].grid()
    axs[1, 0].legend()

    axs[1, 1].plot(history['clip_score'], label="CLIP Score", color='purple')
    axs[1, 1].set_title("CLIP Score")
    axs[1, 1].set_xlabel("Epoch")
    axs[1, 1].set_ylabel("CLIP Score (higher is better)")
    axs[1, 1].grid()
    axs[1, 1].legend()

    plt.tight_layout()
    plt.savefig(filename)
    plt.close(fig)  # release the figure so repeated calls don't leak memory
    # Bug fix: the message printed a literal "(unknown)" instead of the path.
    print(f"Đã lưu biểu đồ metrics tại {filename}")
78
+
79
def evaluate(
    eval_pipeline, gen, val_loader, metrics,
    unet, vae, text_encoder, scheduler,
    vae_scale_factor, num_inference_steps
):
    """Validate the UNet and score generated samples for one epoch.

    For every validation batch this (a) computes the DDPM noise-prediction
    MSE as the validation loss, and (b) generates images from the batch's
    prompts with ``eval_pipeline`` to update the FID / LPIPS / CLIP-score
    metrics.

    Args:
        eval_pipeline: StableDiffusionPipeline sharing the training modules.
        gen: seeded torch.Generator so samples are comparable across epochs.
        val_loader: DataLoader yielding "pixel_values" and "input_ids".
        metrics: dict of torchmetrics objects keyed "fid"/"lpips"/"clip_score".
        unet, vae, text_encoder, scheduler: the training components.
        vae_scale_factor: SD latent scaling factor (0.18215).
        num_inference_steps: sampling steps for image generation.

    Returns:
        dict with keys "val_loss", "fid", "lpips", "clip_score".
    """

    unet.eval()
    total_val_loss = 0.0

    # Metrics accumulate across calls; reset so each epoch stands alone.
    for metric in metrics.values():
        metric.reset()

    def to_uint8(images):
        # [-1, 1] floats -> [0, 255] uint8, as expected by FID / CLIPScore.
        images = (images.clamp(-1, 1) + 1) / 2
        images = (images * 255).type(torch.uint8)
        return images

    def to_lpips_format(images):
        # LPIPS expects floats clamped to [-1, 1].
        return images.clamp(-1, 1)

    pbar = tqdm(val_loader, desc="[Validation & Evaluation]")
    for batch in pbar:
        images = batch["pixel_values"].to(DEVICE)
        input_ids = batch["input_ids"].to(DEVICE)

        with torch.no_grad():
            # --- compute validation loss (noise-prediction MSE) ---
            latents = vae.encode(images).latent_dist.mean * vae_scale_factor
            noise = torch.randn_like(latents)
            timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (latents.shape[0],), device=DEVICE)
            noisy_latents = scheduler.add_noise(latents, noise, timesteps)
            text_embeds = text_encoder(input_ids)[0]

            noise_pred = unet(noisy_latents, timesteps, text_embeds).sample
            val_loss = F.mse_loss(noise_pred, noise)
            total_val_loss += val_loss.item()

            # --- generate images (via eval_pipeline) ---
            # Recover the text prompts from the tokenized batch.
            prompts = eval_pipeline.tokenizer.batch_decode(input_ids, skip_special_tokens=True)

            generated_output = eval_pipeline(
                prompt=prompts,
                num_inference_steps=num_inference_steps,
                output_type="pt",
                generator=gen
            )

            # Pipeline images come back in [0, 1]; renormalise to [-1, 1]
            # to match the ground-truth tensors.
            generated_images = generated_output.images
            generated_images_norm = (generated_images * 2) - 1

            # --- update metrics ---
            gt_images_uint8 = to_uint8(images)
            gt_images_lpips = to_lpips_format(images)
            gen_images_uint8 = to_uint8(generated_images_norm)
            gen_images_lpips = to_lpips_format(generated_images_norm)

            metrics["fid"].update(gt_images_uint8, real=True)
            metrics["fid"].update(gen_images_uint8, real=False)
            metrics["lpips"].update(gt_images_lpips, gen_images_lpips)
            metrics["clip_score"].update(gen_images_uint8, prompts)

    # --- aggregate and return results ---
    results = {
        "val_loss": total_val_loss / len(val_loader),
        "fid": metrics["fid"].compute().item(),
        "lpips": metrics["lpips"].compute().item(),
        "clip_score": metrics["clip_score"].compute().item()
    }

    return results
149
+
150
def main():
    """Stage 2: train the mini UNet on top of the frozen, fine-tuned VAE.

    Resumes from CHECKPOINT_PATH when present, evaluates (val loss, FID,
    LPIPS, CLIP score) after every epoch, keeps the best-by-CLIP-score UNet
    at SAVE_UNET_PATH, and always writes a rolling checkpoint.
    """
    print("--- Giai đoạn 2: Huấn luyện UNet-mini ---")
    start_time_total = time.time()

    # Build the component container (mini-UNet config must match inference).
    pipeline = MiniDiffusionPipeline(
        base_model_id="runwayml/stable-diffusion-v1-5",
        device=DEVICE,
        config_overrides=TINY_UNET_CONFIG
    )

    # Load the VAE fine-tuned in stage 1; abort if it is missing.
    try:
        pipeline.vae.load_state_dict(torch.load(VAE_PATH, map_location=DEVICE))
        print(f"Tải VAE đã fine-tune thành công từ {VAE_PATH}")
    except Exception as e:
        print(f"Lỗi: Không thể tải VAE từ {VAE_PATH}. {e}")
        print("Vui lòng chạy train_vae.py trước!")
        return

    # Only the UNet is trained; VAE and text encoder stay frozen.
    pipeline.vae.requires_grad_(False)
    pipeline.text_encoder.requires_grad_(False)

    unet = pipeline.unet
    vae = pipeline.vae
    text_encoder = pipeline.text_encoder
    tokenizer = pipeline.tokenizer
    noise_scheduler = pipeline.noise_scheduler
    vae_scale_factor = pipeline.config['latent_scale']

    # Data loading
    train_dataset = SketchDataset(TRAIN_DATA_DIR, tokenizer, IMAGE_SIZE)
    val_dataset = SketchDataset(VAL_DATA_DIR, tokenizer, IMAGE_SIZE)

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)

    print(f"Đã tải {len(train_dataset)} ảnh train và {len(val_dataset)} ảnh val.")

    # Build the evaluation pipeline once; it shares the training modules, so
    # it always reflects the latest weights without rebuilding per epoch.
    print("Khởi tạo Evaluation Pipeline (một lần)...")
    eval_pipeline = StableDiffusionPipeline(
        unet=unet,
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        scheduler=noise_scheduler,
        safety_checker=None,
        feature_extractor=None,
    ).to(DEVICE)

    eval_pipeline.set_progress_bar_config(disable=True)

    eval_pipeline.unet.eval()
    eval_pipeline.vae.eval()
    eval_pipeline.text_encoder.eval()

    # Fixed seed so evaluation samples are comparable across epochs.
    gen = torch.Generator(device=DEVICE).manual_seed(42)

    optimizer = AdamW(unet.parameters(), lr=LEARNING_RATE)
    lr_scheduler = get_scheduler(
        name=pipeline.config['scheduler'],
        optimizer=optimizer,
        num_warmup_steps=500,
        num_training_steps=(len(train_loader) * EPOCHS),
    )

    metrics = {
        "fid": FrechetInceptionDistance(feature=64).to(DEVICE),
        "lpips": LearnedPerceptualImagePatchSimilarity(net_type='vgg').to(DEVICE),
        "clip_score": CLIPScore(model_name_or_path="openai/clip-vit-base-patch32").to(DEVICE)
    }

    start_epoch = 0
    history = {
        "train_loss": [], "val_loss": [],
        "fid": [], "lpips": [], "clip_score": []
    }
    best_clip_score = 0.0

    # Resume support: the checkpoint stores model/optimizer/scheduler state,
    # the metric history, and the best score so far.
    if os.path.exists(CHECKPOINT_PATH):
        print(f"Phát hiện checkpoint. Đang tải từ {CHECKPOINT_PATH}...")
        try:
            checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
            unet.load_state_dict(checkpoint['unet_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler_state_dict'])
            start_epoch = checkpoint['epoch']  # the *next* epoch to run
            history = checkpoint['history']
            best_clip_score = checkpoint['best_clip_score']
            print(f"Resume training từ epoch {start_epoch}")
        except Exception as e:
            print(f"Lỗi khi tải checkpoint: {e}. Bắt đầu lại từ đầu.")
            start_epoch = 0
            history = {k: [] for k in history}
            best_clip_score = 0.0
    else:
        print("Không tìm thấy checkpoint. Bắt đầu training từ đầu.")


    for epoch in range(start_epoch, EPOCHS):
        start_time_epoch = time.time()
        unet.train()
        epoch_train_loss = 0.0

        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS} [Train]")
        for batch in pbar:
            images = batch["pixel_values"].to(DEVICE)
            input_ids = batch["input_ids"].to(DEVICE)

            # Frozen encoders: no gradients through VAE / text encoder.
            with torch.no_grad():
                latents = vae.encode(images).latent_dist.mean * vae_scale_factor
                text_embeds = text_encoder(input_ids)[0]

            # Standard DDPM objective: predict the noise added at a random t.
            noise = torch.randn_like(latents)
            timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (latents.shape[0],), device=DEVICE)
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            noise_pred = unet(noisy_latents, timesteps, text_embeds).sample
            loss = F.mse_loss(noise_pred, noise)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            lr_scheduler.step()

            epoch_train_loss += loss.item()
            pbar.set_postfix({"Loss": loss.item()})

        avg_train_loss = epoch_train_loss / len(train_loader)
        history["train_loss"].append(avg_train_loss)

        # --- per-epoch evaluation ---
        eval_results = evaluate(
            eval_pipeline, gen, val_loader, metrics,
            unet, vae, text_encoder, noise_scheduler,
            vae_scale_factor, NUM_INFERENCE_STEPS
        )

        history["val_loss"].append(eval_results["val_loss"])
        history["fid"].append(eval_results["fid"])
        history["lpips"].append(eval_results["lpips"])
        history["clip_score"].append(eval_results["clip_score"])

        epoch_time_min = (time.time() - start_time_epoch) / 60

        print(f"\n--- Epoch {epoch+1}/{EPOCHS} Results (Thời gian: {epoch_time_min:.2f} phút) ---")
        print(f" Train Loss: {avg_train_loss:.6f}")
        print(f" Val Loss: {eval_results['val_loss']:.6f}")
        print(f" LPIPS: {eval_results['lpips']:.4f} (↓)")
        print(f" FID: {eval_results['fid']:.4f} (↓)")
        print(f" CLIP Score: {eval_results['clip_score']:.4f} (↑)")

        # Keep the single best UNet by CLIP score.
        if eval_results['clip_score'] > best_clip_score:
            best_clip_score = eval_results['clip_score']
            torch.save(unet.state_dict(), SAVE_UNET_PATH)
            print(f"Đã lưu UNet *tốt nhất* mới tại {SAVE_UNET_PATH} (CLIP Score: {best_clip_score:.4f})")

        print(f"Đang lưu checkpoint cuối cùng tại {CHECKPOINT_PATH}...")
        checkpoint = {
            'epoch': epoch + 1,
            'unet_state_dict': unet.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'lr_scheduler_state_dict': lr_scheduler.state_dict(),
            'history': history,
            'best_clip_score': best_clip_score
        }
        torch.save(checkpoint, CHECKPOINT_PATH)


    total_time_min = (time.time() - start_time_total) / 60
    print(f"\n--- Hoàn thành Giai đoạn 2 ---")
    print(f"Tổng thời gian chạy (phiên này): {total_time_min:.2f} phút")
    print(f"UNet đã train (tốt nhất) được lưu tại: {SAVE_UNET_PATH}")

    if history['train_loss']:
        plot_metrics(history, "unet_metrics_plot_v1.png")
    else:
        print("Không có dữ liệu history để vẽ biểu đồ.")

if __name__ == "__main__":
    main()
Agent_Diffusion/train_vae.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from torch.utils.data import DataLoader
4
+ from torchvision import transforms
5
+ from torch.optim import AdamW
6
+ from PIL import Image
7
+ import os
8
+ import time
9
+ import matplotlib.pyplot as plt
10
+ from tqdm import tqdm
11
+
12
+
13
+ from stable_diffusion import MiniDiffusionPipeline
14
+ from dataset import SketchDataset
15
+
16
+ # --- Cấu hình ---
17
+ TRAIN_DATA_DIR = r"C:\Users\Admin\Desktop\scientific research\dataset\train"
18
+ VAL_DATA_DIR = r"C:\Users\Admin\Desktop\scientific research\dataset\val"
19
+ IMAGE_SIZE = 128
20
+ EPOCHS = 36
21
+ BATCH_SIZE = 16
22
+ LEARNING_RATE = 1e-5
23
+ SAVE_PATH = "vae-finetuned.safetensors"
24
+ CHECKPOINT_PATH = "vae_latest_checkpoint.pth"
25
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
26
+
27
def plot_losses(train_losses, val_losses, filename="vae_loss_plot_v1.png"):
    """Plot per-epoch train/validation loss curves and save to *filename*."""
    plt.figure(figsize=(10, 5))
    plt.plot(train_losses, label="Train Loss")
    plt.plot(val_losses, label="Validation Loss")
    plt.title("VAE Fine-tuning Loss")
    plt.xlabel("Epoch")
    plt.ylabel("MSE Loss")
    plt.grid()
    plt.ylim(0.085, 0.12)  # fixed window so successive runs are comparable
    plt.legend()
    plt.savefig(filename)
    plt.close()  # release the figure so repeated calls don't leak memory
    # Bug fix: the message printed a literal "(unknown)" instead of the path.
    print(f"Đã lưu biểu đồ loss tại {filename}")
39
+
40
def main():
    """Stage 1: fine-tune the SD VAE (MSE variant) on the sketch dataset.

    Pure reconstruction objective (pixel MSE through encode/decode).
    Resumes from CHECKPOINT_PATH when present, keeps the best-by-val-loss
    weights at SAVE_PATH, and always writes a rolling checkpoint per epoch.
    """
    print("--- Giai đoạn 1: Fine-tuning VAE ---")
    pipeline = MiniDiffusionPipeline(
        base_model_id="runwayml/stable-diffusion-v1-5",
        vae_model_id="stabilityai/sd-vae-ft-mse",
        device=DEVICE
    )
    vae = pipeline.vae
    tokenizer = pipeline.tokenizer
    vae_scale_factor = pipeline.config['latent_scale']

    train_dataset = SketchDataset(TRAIN_DATA_DIR, tokenizer, IMAGE_SIZE)
    val_dataset = SketchDataset(VAL_DATA_DIR, tokenizer, IMAGE_SIZE)

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)

    print(f"Đã tải {len(train_dataset)} ảnh train và {len(val_dataset)} ảnh val.")

    optimizer = AdamW(vae.parameters(), lr=LEARNING_RATE)

    start_epoch = 0
    train_losses = []
    val_losses = []
    best_val_loss = float('inf')

    # Resume support.
    if os.path.exists(CHECKPOINT_PATH):
        print(f"Phát hiện checkpoint. Đang tải từ {CHECKPOINT_PATH}...")
        try:
            checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE)

            vae.load_state_dict(checkpoint['vae_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            start_epoch = checkpoint['epoch']  # this is the *next* epoch to run
            train_losses = checkpoint['train_losses']
            val_losses = checkpoint['val_losses']
            best_val_loss = checkpoint['best_val_loss']

            print(f"Resume training từ epoch {start_epoch}")
        except Exception as e:
            print(f"Lỗi khi tải checkpoint: {e}. Bắt đầu lại từ đầu.")
            start_epoch = 0
            train_losses = []
            val_losses = []
            best_val_loss = float('inf')
    else:
        print("Không tìm thấy checkpoint. Bắt đầu training từ đầu.")

    start_time = time.time()

    for epoch in range(start_epoch, EPOCHS):
        vae.train()
        epoch_train_loss = 0.0

        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS} [Train]")
        for batch in pbar:
            images = batch["pixel_values"].to(DEVICE)

            # Deterministic encode (posterior mean, no sampling), then decode;
            # the scale factor cancels out but mirrors the diffusion pipeline.
            posterior = vae.encode(images).latent_dist
            latents = posterior.mean * vae_scale_factor

            reconstructions = vae.decode(latents / vae_scale_factor).sample

            loss = F.mse_loss(reconstructions, images)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_train_loss += loss.item()
            pbar.set_postfix({"Loss": loss.item()})

        avg_train_loss = epoch_train_loss / len(train_loader)
        train_losses.append(avg_train_loss)

        # Validation pass: same reconstruction loss, no gradients.
        vae.eval()
        epoch_val_loss = 0.0
        with torch.no_grad():
            pbar_val = tqdm(val_loader, desc=f"Epoch {epoch+1}/{EPOCHS} [Val]")
            for batch in pbar_val:
                images = batch["pixel_values"].to(DEVICE)
                posterior = vae.encode(images).latent_dist
                latents = posterior.mean * vae_scale_factor
                reconstructions = vae.decode(latents / vae_scale_factor).sample
                loss = F.mse_loss(reconstructions, images)
                epoch_val_loss += loss.item()

        avg_val_loss = epoch_val_loss / len(val_loader)
        val_losses.append(avg_val_loss)

        print(f"Epoch {epoch+1}/{EPOCHS} - Train Loss: {avg_train_loss:.6f} - Val Loss: {avg_val_loss:.6f}")

        # Keep the single best VAE by validation loss.
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(vae.state_dict(), SAVE_PATH)
            print(f"Đã lưu VAE *tốt nhất* mới tại {SAVE_PATH} (Val Loss: {best_val_loss:.6f})")

        print(f"Đang lưu checkpoint cuối cùng tại {CHECKPOINT_PATH}...")
        checkpoint = {
            'epoch': epoch + 1,
            'vae_state_dict': vae.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'train_losses': train_losses,
            'val_losses': val_losses,
            'best_val_loss': best_val_loss
        }
        torch.save(checkpoint, CHECKPOINT_PATH)


    end_time = time.time()
    total_time_min = (end_time - start_time) / 60
    print(f"\n--- Hoàn thành Giai đoạn 1 ---")
    print(f"Tổng thời gian chạy (phiên này): {total_time_min:.2f} phút")
    print(f"VAE đã fine-tune (tốt nhất) được lưu tại: {SAVE_PATH}")

    if train_losses and val_losses:
        plot_losses(train_losses, val_losses, "vae_loss_plot.png")

if __name__ == "__main__":
    main()
Agent_Diffusion/unet-mini.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:291a53f1a95a6893a7c06d9e8cbe06667b9e2bd53bcbaa87636b509ba4529821
3
+ size 208784907
Agent_Diffusion/vae-finetuned.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bd73806bd9362c4d5ad89e5017178db324d5e49d2efe9e0dd04c1a2871395b32
3
+ size 334713859
Conservative Q-learning/Agen1_training.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import json
3
+ import numpy as np
4
+ from torch.utils.data import DataLoader, Dataset
5
+ from transformers import T5Tokenizer, T5EncoderModel
6
+
7
+ # Import từ các file bạn đã cung cấp
8
+ from cql_agent import CQLAgent
9
+ from cql_utils import DataNormalizer
10
+
11
+ # --- 1. ĐỊNH NGHĨA DATASET ---
12
+ class SketchOfflineDataset(Dataset):
13
+ def __init__(self, file_path):
14
+ self.data = []
15
+ with open(file_path, 'r', encoding='utf-8') as f:
16
+ for line in f:
17
+ self.data.append(json.loads(line))
18
+
19
+ def __len__(self):
20
+ return len(self.data)
21
+
22
+ def __getitem__(self, idx):
23
+ item = self.data[idx]
24
+ return (
25
+ np.array(item['state'], dtype=np.float32),
26
+ item['action'],
27
+ item['reward'],
28
+ np.array(item['next_state'], dtype=np.float32),
29
+ item['done']
30
+ )
31
+
32
+ # --- 2. HÀM HUẤN LUYỆN ---
33
+ def train_agent_1(epoch):
34
+ print("🚀 Khởi động quá trình huấn luyện Agent Chính...")
35
+
36
+ # Cấu hình
37
+ FILE_PATH = "massive_diverse_sketch_dataset.json"
38
+ STATE_DIM = 768 # Tương ứng đầu ra t5-base
39
+ ACTION_DIM = 3 # 0: Chat, 1: Sketch, 2: Reject
40
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
41
+
42
+ # Load Dữ liệu
43
+ full_dataset = SketchOfflineDataset(FILE_PATH)
44
+ train_size = int(0.8 * len(full_dataset))
45
+ test_size = len(full_dataset) - train_size
46
+ train_dataset, _ = torch.utils.data.random_split(full_dataset, [train_size, test_size])
47
+ train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
48
+
49
+ # Khởi tạo Agent (Sử dụng Discrete CQL vì Action là 0, 1, 2)
50
+ agent = CQLAgent(state_dim=STATE_DIM, action_dim=ACTION_DIM, is_continuous=False, device=DEVICE)
51
+
52
+ # Chuẩn hóa dữ liệu (Normalization)
53
+ all_states = np.array([item['state'] for item in full_dataset.data])
54
+ agent.normalizer.fit(all_states)
55
+
56
+ # Vòng lặp huấn luyện
57
+ epochs = epoch
58
+ for epoch in range(epochs):
59
+ total_loss = 0
60
+ for batch in train_loader:
61
+ # Batch: states, actions, rewards, next_states, dones
62
+ metrics = agent.train_step(batch)
63
+ total_loss += metrics['critic_loss']
64
+
65
+ print(f"Epoch {epoch+1}/{epochs} | Critic Loss: {total_loss/len(train_loader):.4f}")
66
+
67
+ # Lưu mô hình
68
+ agent.save_model("saved_agent_1")
69
+ return agent
70
+
71
+ # --- 3. HÀM KIỂM THỬ (TEST) ---
72
+ class Agent1Inference:
73
+ def __init__(self, model_path="saved_agent_1"):
74
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
75
+
76
+ # Load T5 để phân tích prompt
77
+ self.tokenizer = T5Tokenizer.from_pretrained("t5-base")
78
+ self.encoder = T5EncoderModel.from_pretrained("t5-base").to(self.device)
79
+
80
+ # Load CQL để ra quyết định
81
+ self.agent = CQLAgent(state_dim=768, action_dim=3, is_continuous=False, device=self.device)
82
+ self.agent.load_model(model_path)
83
+
84
+ def get_action(self, text_prompt):
85
+ # Bước 1: T5 phân tích -> Embedding
86
+ inputs = self.tokenizer(text_prompt, return_tensors="pt", padding=True).to(self.device)
87
+ with torch.no_grad():
88
+ embedding = self.encoder(**inputs).last_hidden_state.mean(dim=1).cpu().numpy().flatten()
89
+
90
+ # Bước 2: CQL ra quyết định
91
+ action_idx = self.agent.select_action(embedding)
92
+
93
+ mapping = {0: "Kích hoạt Agent Giao tiếp", 1: "Kích hoạt Agent Vẽ ảnh (Sketch)", 2: "Từ chối/Yêu cầu làm rõ"}
94
+ return mapping[action_idx]
95
+
96
+ # --- CHẠY CHƯƠNG TRÌNH ---
97
+ if __name__ == "__main__":
98
+ # 1. Train
99
+ trained_agent = train_agent_1(1000)
100
+
101
+ # 2. Test thử mô hình đã lưu
102
+ print("\n--- BẮT ĐẦU TEST AGENT CHÍNH ---")
103
+ tester = Agent1Inference("saved_agent_1")
104
+
105
+ test_prompts = [
106
+ "Vẽ cho mình một bức chân dung cụ ông sketch",
107
+ "Chào bot, hôm nay bạn thế nào?",
108
+ "Hãy vẽ một bông hoa hồng bằng màu dầu rực rỡ",
109
+ "Kí họa nhanh khuôn mặt cô gái đang cười"
110
+ ]
111
+
112
+ for p in test_prompts:
113
+ action = tester.get_action(p)
114
+ print(f"User: {p} \n=> Agent 1 quyết định: {action}\n")
Conservative Q-learning/cql_agent.py ADDED
@@ -0,0 +1,242 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.optim as optim
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ import copy
6
+ import os
7
+ from cql_utils import MLP, TanhGaussianPolicy, DataNormalizer
8
+
9
+ class CQLAgent:
10
+ def __init__(
11
+ self,
12
+ state_dim,
13
+ action_dim,
14
+ device='cuda' if torch.cuda.is_available() else 'cpu',
15
+ is_continuous=True,
16
+ hidden_dim=256,
17
+ lr=3e-4,
18
+ cql_weight=1.0,
19
+ temp=1.0,
20
+ gamma=0.99,
21
+ tau=0.005
22
+ ):
23
+ self.state_dim = state_dim
24
+ self.action_dim = action_dim
25
+ self.device = torch.device(device)
26
+ self.is_continuous = is_continuous
27
+ self.cql_weight = cql_weight
28
+ self.temp = temp # logsumexp
29
+ self.gamma = gamma
30
+ self.tau = tau # update coefficient
31
+
32
+ self.normalizer = DataNormalizer(state_dim)
33
+
34
+ if self.is_continuous:
35
+ self.actor = TanhGaussianPolicy(state_dim, action_dim, hidden_dim).to(self.device)
36
+ self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=lr)
37
+
38
+ # Critic (Continuous: input = State + action)
39
+ # using 2 Critics to reduce overstimation (Double Q-learning)
40
+ self.critic_1 = MLP(state_dim + action_dim, 1, hidden_dim).to(self.device)
41
+ self.critic_2 = MLP(state_dim + action_dim, 1, hidden_dim).to(self.device)
42
+ self.target_critic_1 = copy.deepcopy(self.critic_1)
43
+ self.target_critic_2 = copy.deepcopy(self.critic_2)
44
+
45
+ else:
46
+ # Discrete: Actor is argmax Q, Critic is Q-network (Input = state, output = Q of actions)
47
+ self.critic_1 = MLP(state_dim, action_dim, hidden_dim).to(self.device)
48
+ self.target_critic_1 = copy.deepcopy(self.critic_1)
49
+ # we alo can use 2 network to more robust but now we just use 1 network to check and we can update it later
50
+
51
+ params = list(self.critic_1.parameters())
52
+ if self.is_continuous:
53
+ params += list(self.critic_2.parameters())
54
+ self.critic_optimizer = optim.Adam(params, lr=lr)
55
+
56
+ # auto fine-tune Alpha (Entropy) for SAC (just Continuous)
57
+ if self.is_continuous:
58
+ self.target_entropy = -action_dim
59
+ self.log_alpha = torch.zeros(1, requires_grad=True, device=self.device)
60
+ self.alpha_optimizer = optim.Adam([self.log_alpha], lr=lr)
61
+
62
+ def select_action(self, state, evaluate=True):
63
+ if not isinstance(state, np.ndarray):
64
+ state = np.array(state)
65
+
66
+ state = self.normalizer.normalize(state)
67
+ state = torch.FloatTensor(state).unsqueeze(0).to(self.device)
68
+
69
+ with torch.no_grad():
70
+ if self.is_continuous:
71
+ action, _ = self.actor(state)
72
+ return action.cpu().data.numpy().flatten()
73
+ else:
74
+ q_values = self.critic_1(state)
75
+ return q_values.argmax(dim=1).cpu().data.numpy().item()
76
+
77
+ def _compute_cql_loss(self, q_values_pred, states, actions):
78
+ if self.is_continuous:
79
+ batch_size = states.shape[0]
80
+ num_samples = 10
81
+
82
+ # Taking random samples
83
+ random_actions = torch.FloatTensor(batch_size * num_samples, self.action_dim).uniform_(-1, 1).to(self.device)
84
+
85
+ # Taking actions from current Policy
86
+ curr_actions, curr_log_pi = self.actor(states.repeat_interleave(num_samples, dim=0))
87
+
88
+ # Calculating Q for each these samples
89
+ states_repeated = states.repeat_interleave(num_samples, dim=0)
90
+
91
+ # grafting states and random actions and current actions
92
+ q1_rand = self.critic_1(torch.cat([states_repeated, random_actions], dim=1))
93
+ q1_curr = self.critic_1(torch.cat([states_repeated, curr_actions], dim=1))
94
+
95
+ # gathering: [batch, num_samples, 1]
96
+ q1_rand = q1_rand.view(batch_size, num_samples, 1)
97
+ q1_curr = q1_curr.view(batch_size, num_samples, 1)
98
+
99
+ # merging to calculating LogSumExp
100
+ cat_q1 = torch.cat([q1_rand, q1_curr], dim=1)
101
+
102
+ # CQL Loss 1: log(sum(exp(Q_ood)))
103
+ cql_loss_1 = torch.logsumexp(cat_q1 / self.temp, dim=1).mean() * self.temp
104
+ # CQL Loss 2: - Q_data (Maximizing Q-value of data samples)
105
+ cql_loss_2 = q_values_pred.mean()
106
+
107
+ return (cql_loss_1 - cql_loss_2) * self.cql_weight
108
+
109
+ else:
110
+ # q_values_pred shape: [batch, action_dim] (calculated for all action space)
111
+
112
+ # CQL Loss 1: log(sum(exp(Q_all))) - calculate for all action space
113
+ cql_loss_1 = torch.logsumexp(q_values_pred / self.temp, dim=1).mean() * self.temp
114
+
115
+ # CQL Loss 2: - Q_data (takeing Q-value at real action in batch)
116
+ # actions shape: [batch, 1]
117
+ q_data = q_values_pred.gather(1, actions.long())
118
+ cql_loss_2 = q_data.mean()
119
+
120
+ return (cql_loss_1 - cql_loss_2) * self.cql_weight
121
+
122
+ def train_step(self, batch):
123
+ states, actions, rewards, next_states, dones = batch
124
+
125
+ states = torch.FloatTensor(states).to(self.device)
126
+ actions = torch.FloatTensor(actions).to(self.device)
127
+ rewards = torch.FloatTensor(rewards).unsqueeze(1).to(self.device)
128
+ next_states = torch.FloatTensor(next_states).to(self.device)
129
+ dones = torch.FloatTensor(dones).unsqueeze(1).to(self.device)
130
+
131
+ # Update Critic (Q-Functions)
132
+ with torch.no_grad():
133
+ if self.is_continuous:
134
+ next_actions, next_log_pi = self.actor(next_states)
135
+ q1_target = self.target_critic_1(torch.cat([next_states, next_actions], dim=1))
136
+ q2_target = self.target_critic_2(torch.cat([next_states, next_actions], dim=1))
137
+ min_q_target = torch.min(q1_target, q2_target)
138
+
139
+ # Soft Actor-Critic Target (Having entropy)
140
+ alpha = self.log_alpha.exp()
141
+ q_target = rewards + (1 - dones) * self.gamma * (min_q_target - alpha * next_log_pi)
142
+ else:
143
+ # DQN Target
144
+ # Double DQN logic
145
+ q_next = self.target_critic_1(next_states)
146
+ max_q_next, _ = torch.max(q_next, dim=1, keepdim=True)
147
+ q_target = rewards + (1 - dones) * self.gamma * max_q_next
148
+
149
+ # calculating Q hiện tại
150
+ if self.is_continuous:
151
+ q1_pred = self.critic_1(torch.cat([states, actions], dim=1))
152
+ q2_pred = self.critic_2(torch.cat([states, actions], dim=1))
153
+ mse_loss = F.mse_loss(q1_pred, q_target) + F.mse_loss(q2_pred, q_target)
154
+
155
+ # adding CQL Loss
156
+ cql_loss = self._compute_cql_loss(q1_pred, states, actions) + \
157
+ self._compute_cql_loss(q2_pred, states, actions)
158
+ else:
159
+ q_all = self.critic_1(states)
160
+ q_pred = q_all.gather(1, actions.long())
161
+ mse_loss = F.mse_loss(q_pred, q_target)
162
+
163
+ # Thêm CQL Loss
164
+ cql_loss = self._compute_cql_loss(q_all, states, actions)
165
+
166
+ total_critic_loss = mse_loss + cql_loss
167
+
168
+ self.critic_optimizer.zero_grad()
169
+ total_critic_loss.backward()
170
+ self.critic_optimizer.step()
171
+
172
+ # Update Actor (just Continuous)
173
+ actor_loss_val = 0
174
+ if self.is_continuous:
175
+ new_actions, log_pi = self.actor(states)
176
+ q1_new = self.critic_1(torch.cat([states, new_actions], dim=1))
177
+ q2_new = self.critic_2(torch.cat([states, new_actions], dim=1))
178
+ min_q_new = torch.min(q1_new, q2_new)
179
+
180
+ # SAC Actor Loss: Maximize (Q - alpha * log_prob) -> Minimize (alpha * log_prob - Q)
181
+ actor_loss = (alpha * log_pi - min_q_new).mean()
182
+
183
+ self.actor_optimizer.zero_grad()
184
+ actor_loss.backward()
185
+ self.actor_optimizer.step()
186
+
187
+ # Update Alpha (Temperature)
188
+ alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean()
189
+ self.alpha_optimizer.zero_grad()
190
+ alpha_loss.backward()
191
+ self.alpha_optimizer.step()
192
+
193
+ actor_loss_val = actor_loss.item()
194
+
195
+ # Soft Update Target Networks ===
196
+ self._soft_update(self.critic_1, self.target_critic_1)
197
+ if self.is_continuous:
198
+ self._soft_update(self.critic_2, self.target_critic_2)
199
+
200
+ return {
201
+ "critic_loss": total_critic_loss.item(),
202
+ "cql_loss": cql_loss.item(),
203
+ "actor_loss": actor_loss_val
204
+ }
205
+
206
+ def _soft_update(self, local_model, target_model):
207
+ for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
208
+ target_param.data.copy_(self.tau * local_param.data + (1.0 - self.tau) * target_param.data)
209
+
210
+ def save_model(self, path):
211
+ os.makedirs(path, exist_ok=True)
212
+ state_dict = {
213
+ 'is_continuous': self.is_continuous,
214
+ 'critic_1': self.critic_1.state_dict(),
215
+ }
216
+ if self.is_continuous:
217
+ state_dict.update({
218
+ 'critic_2': self.critic_2.state_dict(),
219
+ 'actor': self.actor.state_dict(),
220
+ 'log_alpha': self.log_alpha
221
+ })
222
+
223
+ torch.save(state_dict, os.path.join(path, "cql_model.pth"))
224
+ self.normalizer.save(os.path.join(path, "normalizer.pkl"))
225
+ print(f"Model saved to {path}")
226
+
227
+ def load_model(self, path):
228
+ model_path = os.path.join(path, "cql_model.pth")
229
+ if not os.path.exists(model_path):
230
+ print("No model found!")
231
+ return
232
+
233
+ checkpoint = torch.load(model_path, map_location=self.device)
234
+ self.critic_1.load_state_dict(checkpoint['critic_1'])
235
+
236
+ if self.is_continuous and checkpoint['is_continuous']:
237
+ self.critic_2.load_state_dict(checkpoint['critic_2'])
238
+ self.actor.load_state_dict(checkpoint['actor'])
239
+ self.log_alpha = checkpoint['log_alpha']
240
+
241
+ self.normalizer.load(os.path.join(path, "normalizer.pkl"))
242
+ print(f"Model loaded from {path}")
Conservative Q-learning/cql_utils.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch.distributions import Normal
5
+ import numpy as np
6
+ import pickle
7
+ import os
8
+
9
+ class DataNormalizer:
10
+ def __init__(self, state_dim):
11
+ self.mean = np.zeros(state_dim)
12
+ self.std = np.zeros(state_dim)
13
+ self.std[self.std < 1e-6] = 1.0
14
+
15
+ def fit(self, states):
16
+ self.mean = np.mean(states, axis=0)
17
+ self.std = np.std(states, axis=0)
18
+
19
+ def normalize(self, states):
20
+ return (states - self.mean) / self.std
21
+
22
+ def denormalize(self, states):
23
+ return states * self.std + self.mean
24
+
25
+ def save(self, path):
26
+ with open(path, 'wb') as f:
27
+ pickle.dump({'mean': self.mean, 'std': self.std}, f)
28
+
29
+ def load(self, path):
30
+ with open(path, 'rb') as f:
31
+ data = pickle.load(f)
32
+ self.mean = data['mean']
33
+ self.std = data['std']
34
+
35
+ class MLP(nn.Module):
36
+ def __init__(self, input_dim, output_dim, hidden_dim=256, n_layers=2):
37
+ super().__init__()
38
+ layers = []
39
+ layers.append(nn.Linear(input_dim, hidden_dim))
40
+ layers.append(nn.ReLU())
41
+ for _ in range(n_layers - 1):
42
+ layers.append(nn.Linear(hidden_dim, hidden_dim))
43
+ layers.append(nn.ReLU())
44
+
45
+ layers.append(nn.Linear(hidden_dim, output_dim))
46
+ self.net = nn.Sequential(*layers)
47
+
48
+ def forward(self, x):
49
+ return self.net(x)
50
+
51
+ class TanhGaussianPolicy(nn.Module):
52
+ def __init__(self, state_dim, action_dim, hidden_dim=256):
53
+ super().__init__()
54
+ self.base = MLP(state_dim, hidden_dim, hidden_dim)
55
+ self.mu_head = nn.Linear(hidden_dim, action_dim)
56
+ self.log_std_head = nn.Linear(hidden_dim, action_dim)
57
+
58
+ def forward(self, state):
59
+ # x = self.base.net[:-1](state) # getting feature from MLP base
60
+ x = self.base(state)
61
+ x = F.relu(x)
62
+
63
+ mu = self.mu_head(x)
64
+ log_std = self.log_std_head(x)
65
+ log_std = torch.clamp(log_std, -20, 2)
66
+
67
+ std = torch.exp(log_std)
68
+ dist = Normal(mu, std)
69
+
70
+ # Reparameterization trick: a = mu + std * epsilon
71
+ x_t = dist.rsample()
72
+ action = torch.tanh(x_t) # force to [-1, 1]
73
+
74
+ #calculating log probability
75
+ log_prob = dist.log_prob(x_t)
76
+
77
+ log_prob -= torch.log(1 - action.pow(2) + 1e-6)
78
+ log_prob = log_prob.sum(1, keepdim=True)
79
+
80
+ return action, log_prob
81
+
82
+
Conservative Q-learning/saved_agent_1/cql_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c8ff7062d0d7b6d0ff647071ef0d21ccb7121770d6c51aeb39ed8ad10f882758
3
+ size 1056877
Conservative Q-learning/saved_agent_1/normalizer.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ac608983b3f70024d8ec8b7efd1e58239d70a56f0c926d091a6a5267b56adcb3
3
+ size 12491
app.py ADDED
@@ -0,0 +1,432 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Streamlit App - Beautiful UI for CQL Chatbot System
3
+ """
4
+ import streamlit as st
5
+ import plotly.graph_objects as go
6
+ from datetime import datetime
7
+ import config
8
+ from chatbot_engine import CQLChatbot
9
+ from memory_manager import MemoryManager
10
+
11
+
12
+ # Page configuration
13
+ st.set_page_config(
14
+ page_title=config.APP_TITLE,
15
+ page_icon=config.APP_ICON,
16
+ layout="wide",
17
+ initial_sidebar_state="expanded"
18
+ )
19
+
20
+ # Custom CSS for beautiful Dark Galaxy UI
21
+ st.markdown("""
22
+ <style>
23
+ /* Main container - Dark Galaxy Background */
24
+ .main {
25
+ background: linear-gradient(135deg, #1a1a2e 0%, #16213e 50%, #0f3460 100%);
26
+ }
27
+
28
+ /* Chat container */
29
+ .stChatMessage {
30
+ background-color: rgba(30, 30, 46, 0.8);
31
+ border-radius: 16px;
32
+ padding: 18px;
33
+ margin: 12px 0;
34
+ box-shadow: 0 4px 12px rgba(0, 0, 0, 0.3);
35
+ backdrop-filter: blur(10px);
36
+ }
37
+
38
+ /* User message - Purple gradient */
39
+ .stChatMessage[data-testid="user-message"] {
40
+ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
41
+ color: white;
42
+ border: 1px solid rgba(255, 255, 255, 0.1);
43
+ }
44
+
45
+ /* Assistant message - Dark with blue accent */
46
+ .stChatMessage[data-testid="assistant-message"] {
47
+ background: linear-gradient(135deg, #2d3748 0%, #1a202c 100%);
48
+ border-left: 4px solid #667eea;
49
+ color: #e2e8f0;
50
+ }
51
+
52
+ /* Sidebar - Dark theme */
53
+ section[data-testid="stSidebar"] {
54
+ background: linear-gradient(180deg, #1a1a2e 0%, #16213e 100%);
55
+ border-right: 1px solid rgba(102, 126, 234, 0.2);
56
+ }
57
+
58
+ section[data-testid="stSidebar"] * {
59
+ color: #e2e8f0 !important;
60
+ }
61
+
62
+ /* Headers - Light text */
63
+ h1, h2, h3, h4, h5, h6 {
64
+ color: #f7fafc !important;
65
+ text-shadow: 0 2px 4px rgba(0, 0, 0, 0.3);
66
+ }
67
+
68
+ /* Paragraph text */
69
+ p, span, div {
70
+ color: #cbd5e0;
71
+ }
72
+
73
+ /* Buttons - Purple gradient */
74
+ .stButton>button {
75
+ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
76
+ color: white;
77
+ border: none;
78
+ border-radius: 12px;
79
+ padding: 12px 24px;
80
+ font-weight: 600;
81
+ transition: all 0.3s ease;
82
+ box-shadow: 0 4px 8px rgba(102, 126, 234, 0.3);
83
+ }
84
+
85
+ .stButton>button:hover {
86
+ transform: translateY(-2px);
87
+ box-shadow: 0 6px 16px rgba(102, 126, 234, 0.5);
88
+ }
89
+
90
+ /* Input box - Dark with glow */
91
+ .stTextInput>div>div>input, .stChatInput>div>div>input {
92
+ border-radius: 12px;
93
+ border: 2px solid rgba(102, 126, 234, 0.3);
94
+ padding: 14px;
95
+ background-color: rgba(30, 30, 46, 0.6);
96
+ color: #e2e8f0;
97
+ transition: all 0.3s ease;
98
+ }
99
+
100
+ .stTextInput>div>div>input:focus, .stChatInput>div>div>input:focus {
101
+ border-color: #667eea;
102
+ box-shadow: 0 0 12px rgba(102, 126, 234, 0.4);
103
+ }
104
+
105
+ /* Slider */
106
+ .stSlider>div>div>div {
107
+ background-color: rgba(102, 126, 234, 0.3);
108
+ }
109
+
110
+ /* Metrics - Dark cards */
111
+ .stMetric {
112
+ background: linear-gradient(135deg, #2d3748 0%, #1a202c 100%);
113
+ padding: 18px;
114
+ border-radius: 12px;
115
+ box-shadow: 0 4px 8px rgba(0, 0, 0, 0.3);
116
+ border: 1px solid rgba(102, 126, 234, 0.2);
117
+ }
118
+
119
+ .stMetric label {
120
+ color: #a0aec0 !important;
121
+ }
122
+
123
+ .stMetric [data-testid="stMetricValue"] {
124
+ color: #667eea !important;
125
+ }
126
+
127
+ /* Action badges - Glowing */
128
+ .action-badge {
129
+ display: inline-block;
130
+ padding: 6px 16px;
131
+ border-radius: 20px;
132
+ font-size: 0.85em;
133
+ font-weight: 600;
134
+ margin-left: 8px;
135
+ box-shadow: 0 2px 8px rgba(0, 0, 0, 0.3);
136
+ }
137
+
138
+ .action-0 {
139
+ background: linear-gradient(135deg, #4299e1 0%, #3182ce 100%);
140
+ color: white;
141
+ }
142
+
143
+ .action-1 {
144
+ background: linear-gradient(135deg, #9f7aea 0%, #805ad5 100%);
145
+ color: white;
146
+ }
147
+
148
+ .action-2 {
149
+ background: linear-gradient(135deg, #ed8936 0%, #dd6b20 100%);
150
+ color: white;
151
+ }
152
+
153
+ /* Divider */
154
+ hr {
155
+ border-color: rgba(102, 126, 234, 0.2);
156
+ }
157
+
158
+ /* Expander */
159
+ .streamlit-expanderHeader {
160
+ background-color: rgba(30, 30, 46, 0.6);
161
+ color: #e2e8f0 !important;
162
+ border-radius: 8px;
163
+ }
164
+
165
+ /* Caption text */
166
+ .stCaption {
167
+ color: #a0aec0 !important;
168
+ }
169
+
170
+ /* Success/Info/Warning boxes */
171
+ .stSuccess, .stInfo, .stWarning {
172
+ background-color: rgba(30, 30, 46, 0.8);
173
+ border-radius: 8px;
174
+ color: #e2e8f0;
175
+ }
176
+ </style>
177
+ """, unsafe_allow_html=True)
178
+
179
+
180
+ # Initialize session state
181
+ def init_session_state():
182
+ """Initialize Streamlit session state"""
183
+ # Initialize memory manager
184
+ if 'memory_manager' not in st.session_state:
185
+ st.session_state.memory_manager = MemoryManager()
186
+ st.session_state.memory_manager.create_new_session()
187
+
188
+ if 'chatbot' not in st.session_state:
189
+ with st.spinner('🚀 Đang khởi tạo CQL Chatbot System...'):
190
+ st.session_state.chatbot = CQLChatbot(
191
+ memory_manager=st.session_state.memory_manager
192
+ )
193
+
194
+ if 'messages' not in st.session_state:
195
+ st.session_state.messages = []
196
+
197
+ if 'action_history' not in st.session_state:
198
+ st.session_state.action_history = []
199
+
200
+
201
+ def display_action_badge(action: int, action_name: str):
202
+ """Display action badge with color coding"""
203
+ badge_html = f'<span class="action-badge action-{action}">{config.ACTION_DESCRIPTIONS[action]}</span>'
204
+ return badge_html
205
+
206
+
207
+ def create_action_distribution_chart(distribution: dict, chart_id: str = "action_dist"):
208
+ """Create a pie chart for action distribution"""
209
+ if not distribution or sum(distribution.values()) == 0:
210
+ return None
211
+
212
+ labels = list(distribution.keys())
213
+ values = list(distribution.values())
214
+
215
+ fig = go.Figure(data=[go.Pie(
216
+ labels=labels,
217
+ values=values,
218
+ hole=0.4,
219
+ marker=dict(colors=['#4A90E2', '#7b1fa2', '#f57c00']),
220
+ textinfo='label+percent',
221
+ textfont=dict(size=12)
222
+ )])
223
+
224
+ fig.update_layout(
225
+ title="Phân bố hành động",
226
+ height=300,
227
+ showlegend=True,
228
+ margin=dict(l=20, r=20, t=40, b=20),
229
+ # Add unique identifier
230
+ updatemenus=[],
231
+ sliders=[]
232
+ )
233
+
234
+ return fig
235
+
236
+
237
+ def create_q_values_chart(q_values: list, chart_id: str = "q_values"):
238
+ """Create a bar chart for Q-values"""
239
+ actions = [config.ACTION_DESCRIPTIONS[i] for i in range(len(q_values))]
240
+
241
+ fig = go.Figure(data=[go.Bar(
242
+ x=actions,
243
+ y=q_values,
244
+ marker=dict(
245
+ color=q_values,
246
+ colorscale='Viridis',
247
+ showscale=True
248
+ ),
249
+ text=[f'{v:.2f}' for v in q_values],
250
+ textposition='auto',
251
+ )])
252
+
253
+ fig.update_layout(
254
+ title="Q-Values (Giá trị hành động)",
255
+ xaxis_title="Hành động",
256
+ yaxis_title="Q-Value",
257
+ height=300,
258
+ showlegend=False,
259
+ margin=dict(l=20, r=20, t=40, b=20),
260
+ # Add unique identifier
261
+ updatemenus=[],
262
+ sliders=[]
263
+ )
264
+
265
+ return fig
266
+
267
+
268
+ def main():
269
+ """Main Streamlit app"""
270
+
271
+ # Initialize
272
+ init_session_state()
273
+
274
+ # Fixed temperature (không hiển thị cho user)
275
+ FIXED_TEMPERATURE = 0.7
276
+
277
+ # Header
278
+ st.title(f"{config.APP_ICON} Chatbot Wisdom")
279
+ st.markdown("### 🧠 Hệ thống Multi-Agent với Conservative Q-Learning")
280
+ st.markdown("---")
281
+
282
+ # Sidebar
283
+ with st.sidebar:
284
+ st.header(config.SIDEBAR_TITLE)
285
+
286
+ # New chat button
287
+ if st.button("🆕 Cuộc trò chuyện mới", use_container_width=True):
288
+ st.session_state.messages = []
289
+ st.session_state.action_history = []
290
+ st.session_state.chatbot.clear_history()
291
+ # Create new session in memory
292
+ st.session_state.memory_manager.create_new_session()
293
+ st.success("✅ Đã tạo cuộc trò chuyện mới!")
294
+ st.rerun()
295
+
296
+ st.divider()
297
+
298
+ # Session history
299
+ st.subheader("💾 Lịch sử hội thoại")
300
+ sessions = st.session_state.memory_manager.get_all_sessions()
301
+
302
+ if sessions:
303
+ st.caption(f"Tổng cộng: {len(sessions)} phiên")
304
+
305
+ # Show last 5 sessions
306
+ for session in sessions[:5]:
307
+ session_id = session['session_id']
308
+ created_at = datetime.fromisoformat(session['created_at']).strftime("%d/%m/%Y %H:%M")
309
+ msg_count = session['message_count']
310
+
311
+ col1, col2 = st.columns([3, 1])
312
+ with col1:
313
+ st.caption(f"📝 {created_at} ({msg_count} tin nhắn)")
314
+ with col2:
315
+ if st.button("🗑️", key=f"del_{session_id}"):
316
+ st.session_state.memory_manager.delete_session(session_id)
317
+ st.rerun()
318
+ else:
319
+ st.info("Chưa có lịch sử")
320
+
321
+ st.divider()
322
+
323
+ # Agent status
324
+ st.subheader("🤖 Trạng thái Agent")
325
+ st.success("✅ CQL Agent: Hoạt động")
326
+ st.success("✅ Communication Agent: Hoạt động")
327
+ st.success("✅ Drawing Agent: Hoạt động")
328
+
329
+ st.divider()
330
+
331
+ # Statistics
332
+ st.subheader("📊 Thống kê")
333
+ st.metric("Số tin nhắn", len(st.session_state.messages))
334
+
335
+ # Action distribution
336
+ if st.session_state.action_history:
337
+ distribution = st.session_state.chatbot.get_action_distribution()
338
+ fig = create_action_distribution_chart(distribution, chart_id=f"dist_{len(st.session_state.action_history)}")
339
+ if fig:
340
+ st.plotly_chart(fig, use_container_width=True, key=f"action_dist_{len(st.session_state.action_history)}")
341
+
342
+ st.divider()
343
+
344
+ # Info
345
+ st.subheader("ℹ️ Thông tin")
346
+ st.info(f"""
347
+ **Model**: CQL Agent
348
+ **State Dim**: {config.STATE_DIM}
349
+ **Actions**: {config.ACTION_DIM}
350
+ **Device**: {config.DEVICE}
351
+ """)
352
+
353
+ # Main chat area
354
+ chat_container = st.container()
355
+
356
+ # Display chat messages
357
+ with chat_container:
358
+ for idx, message in enumerate(st.session_state.messages):
359
+ with st.chat_message(message["role"]):
360
+ st.markdown(message["content"])
361
+
362
+ # Display action badge for assistant messages
363
+ if message["role"] == "assistant" and "action" in message:
364
+ action = message["action"]
365
+ action_name = message["action_name"]
366
+ st.markdown(
367
+ display_action_badge(action, action_name),
368
+ unsafe_allow_html=True
369
+ )
370
+
371
+ # Display Q-values chart
372
+ if "q_values" in message:
373
+ with st.expander("📊 Xem Q-Values"):
374
+ fig = create_q_values_chart(message["q_values"], chart_id=f"qval_{idx}")
375
+ st.plotly_chart(fig, use_container_width=True, key=f"q_values_{idx}")
376
+
377
+ # Display image if available
378
+ if message.get("image_path"):
379
+ st.image(message["image_path"])
380
+
381
+ # Chat input
382
+ if prompt := st.chat_input("💬 Nhập tin nhắn của bạn..."):
383
+ # Add user message
384
+ st.session_state.messages.append({"role": "user", "content": prompt})
385
+
386
+ # Display user message
387
+ with st.chat_message("user"):
388
+ st.markdown(prompt)
389
+
390
+ # Generate response với fixed temperature
391
+ with st.chat_message("assistant"):
392
+ with st.spinner("🤔 Đang suy nghĩ..."):
393
+ response_data = st.session_state.chatbot.chat(prompt, FIXED_TEMPERATURE)
394
+
395
+ # Display response
396
+ st.markdown(response_data['response'])
397
+
398
+ # Display action badge
399
+ st.markdown(
400
+ display_action_badge(response_data['action'], response_data['action_name']),
401
+ unsafe_allow_html=True
402
+ )
403
+
404
+ # Display Q-values
405
+ with st.expander("📊 Xem Q-Values"):
406
+ msg_idx = len(st.session_state.messages)
407
+ fig = create_q_values_chart(response_data['q_values'], chart_id=f"qval_new_{msg_idx}")
408
+ st.plotly_chart(fig, use_container_width=True, key=f"q_values_new_{msg_idx}")
409
+
410
+ # Display image if available
411
+ if response_data.get('image_path'):
412
+ st.image(response_data['image_path'])
413
+
414
+ # Add assistant message to history
415
+ st.session_state.messages.append({
416
+ "role": "assistant",
417
+ "content": response_data['response'],
418
+ "action": response_data['action'],
419
+ "action_name": response_data['action_name'],
420
+ "q_values": response_data['q_values'],
421
+ "image_path": response_data.get('image_path')
422
+ })
423
+
424
+ # Update action history
425
+ st.session_state.action_history.append(response_data['action'])
426
+
427
+ # Rerun to update UI
428
+ st.rerun()
429
+
430
+
431
+ if __name__ == "__main__":
432
+ main()
chatbot_engine.py ADDED
@@ -0,0 +1,238 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ CQL Chatbot Engine - Main chatbot logic integrating CQL agent with multi-agent system
3
+ """
4
+ import sys
5
+ import os
6
+ sys.path.append(os.path.join(os.path.dirname(__file__), 'Conservative Q-learning'))
7
+
8
+ import torch
9
+ import numpy as np
10
+ from transformers import T5Tokenizer, T5EncoderModel
11
+ from typing import Dict, List, Tuple
12
+ import config
13
+ from memory_manager import MemoryManager
14
+ from cql_agent import CQLAgent
15
+ from communication_agent import CommunicationAgent
16
+ from drawing_agent import DrawingAgent
17
+
18
+
19
class CQLChatbot:
    """Multi-agent chatbot driven by a Conservative Q-Learning (CQL) policy.

    Each user message is embedded with a T5 encoder; the trained CQL agent
    scores the embedding and selects an action (see config.ACTION_MAPPING):
    0 -> Communication Agent, 1 -> Drawing Agent, 2 -> Clarification.
    """

    def __init__(self, model_path: str = None, memory_manager: MemoryManager = None):
        """
        Initialize CQL Chatbot with all components.

        Args:
            model_path: Path to saved CQL model; defaults to config.MODEL_PATH.
            memory_manager: Memory manager instance for conversation storage
                (optional; when None, history is kept only in memory).
        """
        print("🚀 Initializing CQL Chatbot System...")

        # Set device (falls back to CPU when CUDA is unavailable)
        self.device = torch.device(config.DEVICE if torch.cuda.is_available() else 'cpu')
        print(f"📱 Using device: {self.device}")

        # Load T5 encoder used to turn text into state embeddings
        print("📚 Loading T5 encoder...")
        self.tokenizer = T5Tokenizer.from_pretrained(config.T5_MODEL_NAME)
        self.encoder = T5EncoderModel.from_pretrained(config.T5_MODEL_NAME).to(self.device)
        self.encoder.eval()  # inference only: disables dropout etc.
        print("✅ T5 encoder loaded")

        # Load CQL agent (the decision maker)
        print("🧠 Loading CQL agent...")
        self.cql_agent = CQLAgent(
            state_dim=config.STATE_DIM,
            action_dim=config.ACTION_DIM,
            is_continuous=False,
            device=self.device
        )
        self.cql_agent.load_model(model_path or str(config.MODEL_PATH))
        print("✅ CQL agent loaded")

        # Initialize sub-agents that actually produce responses/images
        print("👥 Initializing sub-agents...")
        self.communication_agent = CommunicationAgent()
        self.drawing_agent = DrawingAgent()
        print("✅ All agents initialized")

        # Optional persistent storage for the conversation
        self.memory_manager = memory_manager
        if self.memory_manager:
            print("💾 Memory manager enabled")

        # In-process rolling conversation history (list of role/content dicts)
        self.conversation_history: List[Dict] = []

        print("🎉 CQL Chatbot System ready!\n")

    def encode_text(self, text: str) -> np.ndarray:
        """
        Encode text into a single T5 embedding vector.

        Args:
            text: Input text

        Returns:
            1-D embedding of size config.STATE_DIM (768 for t5-base),
            obtained by mean-pooling the encoder's last hidden state.
        """
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512
        ).to(self.device)

        with torch.no_grad():
            outputs = self.encoder(**inputs)
            # Mean pooling over the sequence dimension -> one fixed-size vector
            embedding = outputs.last_hidden_state.mean(dim=1).cpu().numpy().flatten()

        return embedding

    def get_action(self, text: str) -> Tuple[int, np.ndarray]:
        """
        Get the CQL agent's decision for the input text.

        Args:
            text: User input text

        Returns:
            Tuple of (action_index, q_values) where q_values are the raw
            critic-1 scores for every action (used only for visualization).
        """
        embedding = self.encode_text(text)

        # Greedy action from the trained policy (no exploration)
        action = self.cql_agent.select_action(embedding, evaluate=True)

        # Re-run critic 1 on the normalized state to expose all Q-values
        state = self.cql_agent.normalizer.normalize(embedding)
        state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
        with torch.no_grad():
            q_values = self.cql_agent.critic_1(state_tensor).cpu().numpy().flatten()

        return action, q_values

    def chat(self, user_message: str, temperature: float = 0.7) -> Dict:
        """
        Process one user message and generate a response.

        Args:
            user_message: User's input message
            temperature: Response creativity for the communication agent

        Returns:
            Dict with keys 'response', 'action', 'action_name',
            'q_values' (list of floats) and 'image_path' (None unless
            an image was generated).
        """
        # Let the CQL policy pick an action first
        action, q_values = self.get_action(user_message)
        action_name = config.ACTION_MAPPING[action]

        print(f"\n🤖 CQL Decision: {action_name} (Action {action})")
        print(f"📊 Q-values: {q_values}")

        response_text = ""
        image_path = None

        # Keyword override: explicit drawing requests always go to the
        # Drawing Agent regardless of what the CQL policy suggested.
        drawing_keywords = ['vẽ', 'sketch', 'phác thảo', 'hình', 'ảnh', 'tranh', 'draw', 'paint', 'create image', 'generate']
        is_drawing_request = any(keyword in user_message.lower() for keyword in drawing_keywords)
        if is_drawing_request:
            print("🎨 Drawing keywords detected! Forcing Drawing Agent.")
            action = 1
            action_name = config.ACTION_MAPPING[1]

        # Execute the finally-selected action
        if action == 0:  # Communication Agent
            response_text = self.communication_agent.generate_response(
                user_message,
                self.conversation_history,
                temperature
            )

        elif action == 1:  # Drawing Agent
            response_text, image_path = self.drawing_agent.generate_sketch(user_message)

        elif action == 2:  # Clarification - deliberately falls back to Communication
            print("⚠️ CQL suggested Clarification. Using Communication Agent.")
            response_text = self.communication_agent.generate_response(
                user_message,
                self.conversation_history,
                temperature
            )
            # Report the exchange as a Communication action downstream
            action = 0
            action_name = config.ACTION_MAPPING[0]

        # Record the exchange in the rolling in-memory history
        self.conversation_history.append({
            'role': 'user',
            'content': user_message
        })
        self.conversation_history.append({
            'role': 'assistant',
            'content': response_text,
            'action': action,
            'action_name': action_name
        })

        # Trim to the configured window so the prompt context stays bounded
        if len(self.conversation_history) > config.MAX_HISTORY_LENGTH:
            self.conversation_history = self.conversation_history[-config.MAX_HISTORY_LENGTH:]

        # Persist to disk when a memory manager is attached
        if self.memory_manager:
            self.memory_manager.save_message('user', user_message)
            self.memory_manager.save_message(
                'assistant',
                response_text,
                {
                    'action': action,
                    'action_name': action_name,
                    'q_values': q_values.tolist()
                }
            )

        return {
            'response': response_text,
            'action': action,
            'action_name': action_name,
            'q_values': q_values.tolist(),
            'image_path': image_path
        }

    def _generate_clarification_request(self, user_message: str) -> str:
        """Return a random clarification prompt for an unclear message.

        NOTE(review): currently unused — chat() maps the Clarification
        action onto the Communication Agent instead of calling this.
        """
        import random  # stdlib; hoisted to the top of the method for clarity

        clarifications = [
            f"Xin lỗi, tôi chưa hiểu rõ yêu cầu của bạn: '{user_message}'. Bạn có thể nói rõ hơn được không?",
            f"Tôi cần thêm thông tin để hiểu câu hỏi của bạn. Bạn muốn tôi làm gì với: '{user_message}'?",
            f"Câu hỏi của bạn chưa rõ ràng. Bạn có thể diễn đạt lại không?",
            f"Hmm, tôi không chắc bạn đang hỏi gì. Bạn có thể cung cấp thêm chi tiết không?"
        ]
        return random.choice(clarifications)

    def clear_history(self):
        """Clear the in-memory conversation history (disk sessions untouched)."""
        self.conversation_history = []
        print("🗑️ Conversation history cleared")

    def get_action_distribution(self) -> Dict[str, int]:
        """Count how often each action was taken in the current conversation.

        Returns:
            Mapping of action name -> count, pre-seeded with zeros for every
            action in config.ACTION_MAPPING.
        """
        distribution = {name: 0 for name in config.ACTION_MAPPING.values()}

        for msg in self.conversation_history:
            if msg.get('role') == 'assistant' and 'action_name' in msg:
                action_name = msg['action_name']
                # .get guards against names outside ACTION_MAPPING, should
                # the stored history ever come from an older configuration
                distribution[action_name] = distribution.get(action_name, 0) + 1

        return distribution
communication_agent.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Communication Agent - Gemini API Version (Stable)
3
+ Backup GPT-2 version available in communication_agent_gpt2.py
4
+ """
5
+ import google.generativeai as genai
6
+ from typing import List, Dict
7
+ import os
8
+ import config
9
+
10
class CommunicationAgent:
    """Conversational sub-agent backed by Google's Gemini API."""

    def __init__(self, api_key: str = None):
        """Initialize the agent; disabled gracefully when no API key is set."""
        self.api_key = api_key or os.getenv("GEMINI_API_KEY", "")

        self.enabled = bool(self.api_key)
        if self.enabled:
            genai.configure(api_key=self.api_key)
            self.model = genai.GenerativeModel(config.GEMINI_MODEL)
            print("✅ Communication Agent (Gemini) ready!")
        else:
            self.model = None
            print("⚠️ Warning: GEMINI_API_KEY not set. Please add it to .env file")

        # System instructions prepended to every prompt for consistent tone
        self.system_context = (
            "You are a helpful, friendly AI assistant. "
            "Respond naturally and conversationally. "
            "Keep responses concise (2-3 sentences). "
            "Be warm and engaging. "
            "If you don't understand, ask for clarification politely."
        )

    def generate_response(
        self,
        user_message: str,
        conversation_history: List[Dict] = None,
        temperature: float = 0.7
    ) -> str:
        """
        Generate a conversational reply via Gemini.

        Args:
            user_message: User's input message
            conversation_history: Previous conversation context
            temperature: Response creativity (0.0-1.0)

        Returns:
            Generated response text (or an error string on failure).
        """
        if not self.enabled:
            return "⚠️ Gemini API not configured. Please add GEMINI_API_KEY to .env file."

        try:
            history_block = self._build_context(conversation_history)

            prompt = f"""{self.system_context}

{history_block}

User: {user_message}
Assistant:"""

            # Deliberately bounded output; errors are reported, not raised
            reply = self.model.generate_content(
                prompt,
                generation_config={
                    'temperature': temperature,
                    'top_p': 0.95,
                    'top_k': 40,
                    'max_output_tokens': 200,
                },
            )
            return reply.text.strip()

        except Exception as e:
            return f"Sorry, an error occurred: {str(e)}"

    def _build_context(self, conversation_history: List[Dict] = None) -> str:
        """Render the most recent turns of history as a prompt prefix."""
        if not conversation_history:
            return ""

        lines = ["Previous conversation:"]
        # Only the five most recent messages are included for context
        for entry in conversation_history[-5:]:
            speaker = "User" if entry.get('role') == 'user' else "Assistant"
            lines.append(f"{speaker}: {entry.get('content', '')}")
        return "\n".join(lines) + "\n"
config.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Configuration file for CQL Chatbot System
3
+ """
4
+ import os
5
+ from pathlib import Path
6
+ from dotenv import load_dotenv
7
+
8
+ # Load environment variables
9
+ load_dotenv()
10
+
11
+ # Project paths
12
+ PROJECT_ROOT = Path(__file__).parent
13
+ MODEL_PATH = PROJECT_ROOT / "Conservative Q-learning" / "saved_agent_1"
14
+
15
+ # Model configuration
16
+ STATE_DIM = 768 # T5-base embedding dimension
17
+ ACTION_DIM = 3 # 0: Chat, 1: Sketch, 2: Clarify
18
+ DEVICE = os.getenv("DEVICE", "cuda")
19
+
20
+ # Action mapping
21
+ ACTION_MAPPING = {
22
+ 0: "Communication Agent",
23
+ 1: "Drawing Agent",
24
+ 2: "Clarification"
25
+ }
26
+
27
+ ACTION_DESCRIPTIONS = {
28
+ 0: "💬 Trò chuyện thông thường",
29
+ 1: "🎨 Vẽ sketch/hình ảnh",
30
+ 2: "❓ Cần làm rõ thêm"
31
+ }
32
+
33
+ # T5 Model
34
+ T5_MODEL_NAME = "t5-base"
35
+
36
+ # Gemini API (Primary - Stable)
37
+ GEMINI_API_KEY = os.getenv("GEMINI_API_KEY", "")
38
+ GEMINI_MODEL = "gemini-2.5-flash"
39
+
40
+ # GPT-2 Model (Backup - Local, available in communication_agent_gpt2.py)
41
+ GPT2_MODEL_NAME = "gpt2" # Can use "gpt2-medium" or "gpt2-large" for better quality
42
+
43
+ # Streamlit UI Configuration
44
+ APP_TITLE = "🤖 Chatbot Wisdom - CQL Multi-Agent System"
45
+ APP_ICON = "🧠"
46
+ SIDEBAR_TITLE = "⚙️ Cài đặt"
47
+
48
+ # Chat settings
49
+ MAX_HISTORY_LENGTH = 50
50
+ DEFAULT_TEMPERATURE = 0.7
51
+ DEFAULT_MAX_TOKENS = 512
drawing_agent.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Drawing Agent - Integrated with Stable Diffusion Model
3
+ Uses pre-trained Diffusion model from Agent_Diffusion folder
4
+ """
5
+ import sys
6
+ import os
7
+ sys.path.append(os.path.join(os.path.dirname(__file__), 'Agent_Diffusion'))
8
+
9
+ import torch
10
+ from typing import Optional, Tuple
11
+ from pathlib import Path
12
+ import random
13
+ from PIL import Image
14
+ from diffusers import StableDiffusionPipeline
15
+ from stable_diffusion import MiniDiffusionPipeline
16
+
17
+
18
class DrawingAgent:
    """Sub-agent that turns text prompts into images with a mini Stable Diffusion.

    The heavy diffusion pipeline is lazily loaded on first use so that
    constructing the agent stays cheap at app startup.
    """

    def __init__(self):
        """Initialize Drawing Agent (model itself loads lazily)."""
        self.enabled = True
        self.output_dir = Path("generated_images")
        self.output_dir.mkdir(exist_ok=True)

        # Model paths (checkpoints live next to the app in Agent_Diffusion/)
        self.base_model_id = "runwayml/stable-diffusion-v1-5"
        self.unet_path = Path("Agent_Diffusion/unet-mini.safetensors")
        self.vae_path = Path("Agent_Diffusion/vae-finetuned.safetensors")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Shrunken UNet configuration applied when the mini pipeline is built
        self.tiny_unet_config = {
            "unet_block_out_channels": (128, 256, 512),
        }

        # Lazy loading - only load when first needed
        self.pipeline = None
        self.model_loaded = False

        print("✅ Drawing Agent initialized (Diffusion model will load on first use)")

    def _load_model(self):
        """Load the diffusion pipeline (lazy; sets self.enabled=False on failure)."""
        if self.model_loaded:
            return

        try:
            print("🎨 Loading Diffusion model...")

            # FIX: .safetensors checkpoints must be read with safetensors —
            # torch.load expects pickle/zip checkpoints and fails on this
            # format. Imported lazily to keep module import light.
            from safetensors.torch import load_file

            # Fail early with a clear message when checkpoints are missing
            for path in (self.unet_path, self.vae_path):
                if not path.exists():
                    raise FileNotFoundError(f"Missing model file: {path}")

            # Build the mini pipeline container (project-local helper)
            container = MiniDiffusionPipeline(
                base_model_id=self.base_model_id,
                device=self.device,
                config_overrides=self.tiny_unet_config
            )

            # Load fine-tuned UNet / VAE weights
            print(f"Loading UNet from {self.unet_path}...")
            container.unet.load_state_dict(load_file(str(self.unet_path), device=self.device))

            print(f"Loading VAE from {self.vae_path}...")
            container.vae.load_state_dict(load_file(str(self.vae_path), device=self.device))

            # Assemble a standard diffusers inference pipeline
            torch_dtype = torch.float16 if self.device == "cuda" else torch.float32

            self.pipeline = StableDiffusionPipeline(
                unet=container.unet,
                vae=container.vae,
                text_encoder=container.text_encoder,
                tokenizer=container.tokenizer,
                scheduler=container.noise_scheduler,
                safety_checker=None,
                feature_extractor=None,
            ).to(self.device)

            if self.device == "cuda":
                self.pipeline.to(dtype=torch_dtype)

            self.pipeline.set_progress_bar_config(disable=True)
            self.model_loaded = True

            print("✅ Diffusion model loaded successfully!")

        except Exception as e:
            # Degrade to the text-only fallback rather than crashing the app
            print(f"❌ Error loading Diffusion model: {e}")
            self.enabled = False

    def generate_sketch(self, prompt: str) -> Tuple[str, Optional[str]]:
        """
        Generate a sketch based on the prompt using the diffusion model.

        Args:
            prompt: Description of what to draw

        Returns:
            Tuple of (response_text, image_path); image_path is None when
            generation failed or the model is unavailable.
        """
        try:
            # Load model if not loaded
            if not self.model_loaded:
                self._load_model()

            if not self.enabled or self.pipeline is None:
                return self._fallback_response(prompt)

            # Strip filler words to get the core subject
            clean_prompt = self.parse_sketch_request(prompt)

            print(f"🎨 Generating image for: '{clean_prompt}'")

            # Random seed makes each request reproducible via its reported seed
            current_seed = random.randint(0, 2**32 - 1)
            generator = torch.Generator(device=self.device).manual_seed(current_seed)

            with torch.no_grad():
                image = self.pipeline(
                    prompt=clean_prompt,
                    num_inference_steps=50,
                    generator=generator,
                    guidance_scale=7.5
                ).images[0]

            # FIX: derive the filename from the seed — the old 4-digit random
            # name could silently overwrite earlier images on collision.
            image_path = self.output_dir / f"sketch_{current_seed}.png"
            image.save(str(image_path))

            print(f"✅ Image saved to: {image_path}")

            # FIX: real newlines instead of literal "\n" text in the markdown
            response_text = (
                "🎨 **Image Generated!**\n\n"
                f"Prompt: **{clean_prompt}**\n"
                f"Seed: {current_seed}\n\n"
                "Here's your generated image:"
            )

            return response_text, str(image_path)

        except Exception as e:
            error_msg = f"Sorry, I encountered an error while generating the image: {str(e)}"
            print(f"❌ Drawing error: {e}")
            return error_msg, None

    def _fallback_response(self, prompt: str) -> Tuple[str, None]:
        """Text-only fallback used when the diffusion model cannot load."""
        clean_prompt = self.parse_sketch_request(prompt)

        response_text = (
            "🎨 **Drawing Request Received**\n\n"
            f"I understand you want me to draw: **{clean_prompt}**\n\n"
            "⚠️ **Note**: Diffusion model failed to load. "
            "Please check that the model files exist in Agent_Diffusion folder."
        )

        return response_text, None

    def parse_sketch_request(self, user_message: str) -> str:
        """
        Extract the core drawing subject from a request.

        FIX: keywords are removed as whole words only. The previous plain
        str.replace() also deleted keyword letters inside other words
        (e.g. removing 'a' turned "cat" into "ct").

        Args:
            user_message: User's request message

        Returns:
            Cleaned prompt for image generation (original message when the
            cleaned version is empty or shorter than 3 characters).
        """
        import re  # local import: the module does not import re at top level

        # Multi-word phrases first so they are matched before their parts
        drawing_keywords = [
            'phác thảo', 'cho tôi', 'cho mình', 'giúp tôi', 'help me', 'for me',
            'vẽ', 'sketch', 'hình', 'ảnh', 'tranh',
            'draw', 'paint', 'create', 'make', 'generate',
            'một', 'a', 'an', 'the'
        ]

        prompt = user_message.lower()

        # Remove keywords as whole words (unicode-aware boundaries)
        for keyword in drawing_keywords:
            prompt = re.sub(r'\b' + re.escape(keyword) + r'\b', ' ', prompt)

        # Collapse extra whitespace
        prompt = ' '.join(prompt.split()).strip()

        # If (nearly) empty after cleaning, fall back to the original message
        if not prompt or len(prompt) < 3:
            prompt = user_message

        return prompt

    def is_drawing_request(self, user_message: str) -> bool:
        """
        Check if the message is a drawing request.

        Args:
            user_message: User's message

        Returns:
            True if any drawing keyword occurs (case-insensitive substring).
        """
        drawing_keywords = [
            'vẽ', 'sketch', 'phác thảo', 'draw', 'paint',
            'create image', 'generate image', 'make picture'
        ]

        message_lower = user_message.lower()
        return any(keyword in message_lower for keyword in drawing_keywords)
memory_manager.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Memory Manager - Lưu trữ và quản lý lịch sử hội thoại
3
+ """
4
+ import json
5
+ import os
6
+ from datetime import datetime
7
+ from typing import List, Dict, Optional
8
+ from pathlib import Path
9
+
10
+
11
class MemoryManager:
    """Persist chat sessions as JSON files on disk.

    Each session lives in ``<storage_dir>/session_<id>.json`` and holds the
    session id, a creation timestamp and the ordered list of messages.
    """

    def __init__(self, storage_dir: str = "conversation_history"):
        """
        Initialize the memory manager.

        Args:
            storage_dir: Directory where session files are stored.
        """
        self.storage_dir = Path(storage_dir)
        self.storage_dir.mkdir(exist_ok=True)
        self.current_session_file = None

    def create_new_session(self) -> str:
        """
        Start a fresh session file and make it the current one.

        Returns:
            The new (timestamp-based) session id.
        """
        session_id = datetime.now().strftime("%Y%m%d_%H%M%S")
        self.current_session_file = self.storage_dir / f"session_{session_id}.json"

        self._save_session({
            "session_id": session_id,
            "created_at": datetime.now().isoformat(),
            "messages": [],
        })
        return session_id

    def save_message(self, role: str, content: str, metadata: Dict = None):
        """
        Append one message to the current session (creating one if needed).

        Args:
            role: 'user' or 'assistant'
            content: Message text
            metadata: Extra fields merged into the stored message
                (action, q_values, etc.)
        """
        if not self.current_session_file:
            self.create_new_session()

        record = {
            "role": role,
            "content": content,
            "timestamp": datetime.now().isoformat(),
        }
        if metadata:
            record.update(metadata)

        session = self._load_session()
        session["messages"].append(record)
        self._save_session(session)

    def load_session(self, session_id: str) -> Optional[Dict]:
        """
        Load a stored session by id.

        Args:
            session_id: Id of the session to load

        Returns:
            Session data, or None when no such file exists.
        """
        path = self.storage_dir / f"session_{session_id}.json"
        if not path.exists():
            return None
        return json.loads(path.read_text(encoding='utf-8'))

    def get_all_sessions(self) -> List[Dict]:
        """
        List summaries for every stored session.

        Returns:
            Session info dicts, newest first.
        """
        summaries = []
        for path in self.storage_dir.glob("session_*.json"):
            try:
                data = json.loads(path.read_text(encoding='utf-8'))
                summaries.append({
                    "session_id": data["session_id"],
                    "created_at": data["created_at"],
                    "message_count": len(data["messages"]),
                    "file": str(path),
                })
            except Exception as e:
                # Skip unreadable/corrupt files instead of failing the listing
                print(f"Error loading session {path}: {e}")

        summaries.sort(key=lambda item: item["created_at"], reverse=True)
        return summaries

    def delete_session(self, session_id: str) -> bool:
        """
        Delete a stored session file.

        Args:
            session_id: Id of the session to delete

        Returns:
            True when the file existed and was removed.
        """
        path = self.storage_dir / f"session_{session_id}.json"
        if not path.exists():
            return False
        path.unlink()
        return True

    def _load_session(self) -> Dict:
        """Read the current session file, or return an empty default."""
        if self.current_session_file and self.current_session_file.exists():
            return json.loads(self.current_session_file.read_text(encoding='utf-8'))
        return {
            "session_id": "default",
            "created_at": datetime.now().isoformat(),
            "messages": [],
        }

    def _save_session(self, session_data: Dict):
        """Write session data to the current session file (no-op when unset)."""
        if not self.current_session_file:
            return
        self.current_session_file.write_text(
            json.dumps(session_data, ensure_ascii=False, indent=2),
            encoding='utf-8',
        )

    def get_current_messages(self) -> List[Dict]:
        """Return all messages of the current session."""
        return self._load_session().get("messages", [])

    def export_session(self, session_id: str, output_file: str):
        """
        Export a stored session to another file.

        Args:
            session_id: Id of the session
            output_file: Destination file path

        Returns:
            True on success, False when the session does not exist.
        """
        session_data = self.load_session(session_id)
        if not session_data:
            return False

        Path(output_file).write_text(
            json.dumps(session_data, ensure_ascii=False, indent=2),
            encoding='utf-8',
        )
        return True
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ streamlit>=1.31.0
2
+ torch>=2.0.0
3
+ transformers>=4.36.0
4
+ numpy>=1.24.0
5
+ pillow>=10.0.0
6
+ google-generativeai>=0.3.0
7
+ python-dotenv>=1.0.0
8
+ plotly>=5.18.0
9
+ sentencepiece>=0.1.99
10
+ diffusers>=0.25.0
11
+ accelerate>=0.25.0
12
+ safetensors>=0.4.0