Larer committed on
Commit
5c19a88
·
1 Parent(s): 65e04f2

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. diffusion-dpo-ocr/check_video_resolution.py +194 -0
  2. diffusion-dpo-ocr/prepare_roadtext.py +625 -0
  3. diffusion-dpo-ocr/results/roadtext_eval_results_BSRGAN.json +17 -0
  4. diffusion-dpo-ocr/results/roadtext_eval_results_DP2O-SR.json +17 -0
  5. diffusion-dpo-ocr/results/roadtext_eval_results_DiT4SR.json +17 -0
  6. diffusion-dpo-ocr/results/roadtext_eval_results_DiffBIR.json +17 -0
  7. diffusion-dpo-ocr/results/roadtext_eval_results_FaithDiff.json +17 -0
  8. diffusion-dpo-ocr/results/roadtext_eval_results_Ours.json +17 -0
  9. diffusion-dpo-ocr/results/roadtext_eval_results_Real-ESRGAN.json +17 -0
  10. diffusion-dpo-ocr/results/roadtext_eval_results_SUPSR.json +17 -0
  11. diffusion-dpo-ocr/results/roadtext_eval_results_SeeSR.json +17 -0
  12. diffusion-dpo-ocr/results/roadtext_eval_results_StableSR.json +17 -0
  13. diffusion-dpo-ocr/results/roadtext_eval_results_SwinIR.json +17 -0
  14. diffusion-dpo-ocr/results/roadtext_eval_results_gt.json +17 -0
  15. diffusion-dpo-ocr/results/roadtext_eval_results_sample00.json +17 -0
  16. diffusion-dpo-ocr/results/roadtext_eval_results_zoomlr.json +17 -0
  17. diffusion-dpo-ocr/roadtext_eval_results_output.json +17 -0
  18. diffusion-dpo-ocr/test_roadtext.py +514 -0
  19. diffusion-dpo-ocr/verify_roadtext_annotations.py +223 -0
  20. diffusion-dpo-test/DIAGNOSTIC_CHECKLIST.md +297 -0
  21. diffusion-dpo-test/DIV2K-val/sobolev-400/0000843-seed-0.png +0 -0
  22. diffusion-dpo-test/__pycache__/color_fix.cpython-310.pyc +0 -0
  23. diffusion-dpo-test/analyze_lora_magnitude.py +179 -0
  24. diffusion-dpo-test/check_lora_keys.py +76 -0
  25. diffusion-dpo-test/color_fix.py +119 -0
  26. diffusion-dpo-test/compare.py +73 -0
  27. diffusion-dpo-test/compare_checkpoints.py +147 -0
  28. diffusion-dpo-test/data_val/0000009-seed-0.png +0 -0
  29. diffusion-dpo-test/data_val/0000010-seed-0.png +0 -0
  30. diffusion-dpo-test/fix_lora_keys.py +132 -0
  31. diffusion-dpo-test/inspect_safetensor.py +115 -0
  32. diffusion-dpo-test/metrics.json +142 -0
  33. diffusion-dpo-test/results-test/DIV2K-val-epoch10/0000263-seed-0.png +0 -0
  34. diffusion-dpo-test/results-test/DIV2K-val-epoch10/0000463-seed-0.png +0 -0
  35. diffusion-dpo-test/results-test/DIV2K-val-epoch10/0000563-seed-0.png +0 -0
  36. diffusion-dpo-test/results-test/DIV2K-val-epoch10/0000763-seed-0.png +0 -0
  37. diffusion-dpo-test/results-test/DIV2K-val-epoch10/0000863-seed-0.png +0 -0
  38. diffusion-dpo-test/results-test/DrealSR/sony_160_x4.png +0 -0
  39. diffusion-dpo-test/results-test/DrealSR/sony_189_x4.png +0 -0
  40. diffusion-dpo-test/src/flux/__pycache__/block.cpython-310.pyc +0 -0
  41. diffusion-dpo-test/src/flux/__pycache__/block.cpython-311.pyc +0 -0
  42. diffusion-dpo-test/src/flux/__pycache__/condition.cpython-310.pyc +0 -0
  43. diffusion-dpo-test/src/flux/__pycache__/condition.cpython-311.pyc +0 -0
  44. diffusion-dpo-test/src/flux/__pycache__/generate.cpython-310.pyc +0 -0
  45. diffusion-dpo-test/src/flux/__pycache__/generate.cpython-311.pyc +0 -0
  46. diffusion-dpo-test/src/flux/__pycache__/lora_controller.cpython-310.pyc +0 -0
  47. diffusion-dpo-test/src/flux/__pycache__/lora_controller.cpython-311.pyc +0 -0
  48. diffusion-dpo-test/src/flux/__pycache__/pipeline_tools.cpython-310.pyc +0 -0
  49. diffusion-dpo-test/src/flux/__pycache__/pipeline_tools.cpython-311.pyc +0 -0
  50. diffusion-dpo-test/src/flux/__pycache__/transformer.cpython-310.pyc +0 -0
diffusion-dpo-ocr/check_video_resolution.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Check the resolution of every video under the RoadText1K Videos directory
3
+ """
4
+
5
+ import os
6
+ from pathlib import Path
7
+ from collections import Counter
8
+ import cv2
9
+ from tqdm import tqdm
10
+
11
+
12
# ============================================================================
# Configuration - edit the values below
# ============================================================================
CONFIG = {
    # RoadText1K Videos directory
    'videos_dir': '/home/wanghongbo06/baipurui/DATA/RoadText1k/Videos',

    # Target resolution (optional, used to check whether each video matches)
    'target_resolution': (1920, 1080),  # e.g. (1920, 1080), or None to only collect statistics

    # Which dataset split to check
    'split': 'test',  # 'test', 'train', 'val', or 'all' to check every split
}
# ============================================================================
26
+
27
+
28
def get_video_resolution(video_path: Path) -> 'dict | None':
    """
    Read basic stream properties of a video file with OpenCV.

    Args:
        video_path: Path of the video file to probe.

    Returns:
        A dict with the keys 'width', 'height', 'fps' and 'frame_count'
        as reported by OpenCV, or None if the video cannot be opened.
    """
    cap = cv2.VideoCapture(str(video_path))
    try:
        if not cap.isOpened():
            return None

        return {
            'width': int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
            'height': int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
            'fps': cap.get(cv2.CAP_PROP_FPS),
            'frame_count': int(cap.get(cv2.CAP_PROP_FRAME_COUNT)),
        }
    finally:
        # Always give the handle back, including on a failed open or an
        # unexpected error while querying properties.
        cap.release()
52
+
53
+
54
def main():
    """
    Collect resolution statistics for the configured split(s).

    Reads CONFIG, probes every .mp4/.avi file found directly in each split
    directory or one level below it, prints the resolution distribution
    (and the videos that differ from the target resolution, when one is
    configured) and writes 'video_resolution_report.json' into the parent
    of the videos directory.
    """
    import json

    videos_dir = Path(CONFIG['videos_dir'])
    split = CONFIG['split']
    target_res = CONFIG['target_resolution']

    print("="*60)
    print("Video Resolution Checker")
    print("="*60)
    print(f"Videos directory: {videos_dir}")
    print(f"Split: {split}")
    if target_res:
        print(f"Target resolution: {target_res[0]}x{target_res[1]}")
    print()

    # Decide which split directories to scan
    check_dirs = ['test', 'train', 'val'] if split == 'all' else [split]

    all_resolutions = []
    resolution_stats = Counter()
    mismatched_videos = []
    total_videos = 0

    for split_name in check_dirs:
        split_dir = videos_dir / split_name

        if not split_dir.exists():
            print(f"Warning: Directory not found: {split_dir}")
            continue

        print(f"\nChecking {split_name}...")

        # Videos may sit directly in the split directory or one level below it
        found = []
        for entry in sorted(split_dir.iterdir()):
            if entry.is_dir():
                for pattern in ('*.mp4', '*.avi'):
                    found.extend(entry.glob(pattern))
            elif entry.suffix in ['.mp4', '.avi']:
                found.append(entry)
        found.sort()

        print(f"Found {len(found)} videos")

        for video_path in tqdm(found, desc=f"Processing {split_name}"):
            info = get_video_resolution(video_path)

            if info is None:
                print(f"Warning: Cannot read {video_path.name}")
                continue

            w, h = info['width'], info['height']
            resolution = (w, h)

            record = {
                'file': video_path.name,
                'split': split_name,
                'width': w,
                'height': h,
                'fps': info['fps'],
                'frame_count': info['frame_count'],
            }
            all_resolutions.append(record)
            resolution_stats[resolution] += 1

            # Remember videos that differ from the target resolution
            if target_res and resolution != target_res:
                mismatched_videos.append({
                    'file': video_path.name,
                    'split': split_name,
                    'resolution': f"{w}x{h}",
                    'expected': f"{target_res[0]}x{target_res[1]}",
                })

            total_videos += 1

    # Print the statistics
    print("\n" + "="*60)
    print("STATISTICS")
    print("="*60)
    print(f"Total videos checked: {total_videos}")
    print(f"Unique resolutions: {len(resolution_stats)}")
    print()

    print("Resolution distribution:")
    print("-" * 60)
    for (w, h), n_videos in resolution_stats.most_common():
        share = n_videos / total_videos * 100
        print(f" {w:4d} x {h:4d}: {n_videos:4d} videos ({share:5.1f}%)")

    # Print the most common resolution
    if resolution_stats:
        (top_w, top_h), top_count = resolution_stats.most_common(1)[0]
        print(f"\nMost common resolution: {top_w}x{top_h} ({top_count} videos)")

    # When a target resolution is configured, list the videos that differ
    if target_res:
        print("\n" + "="*60)
        print("RESOLUTION MATCHING")
        print("="*60)
        if not mismatched_videos:
            print(f"✓ All videos match target resolution {target_res[0]}x{target_res[1]}")
        else:
            print(f"Videos NOT matching target resolution: {len(mismatched_videos)}")
            print("\nMismatched videos (first 20):")
            for item in mismatched_videos[:20]:
                print(f" {item['split']}/{item['file']}: {item['resolution']} (expected {item['expected']})")
            if len(mismatched_videos) > 20:
                print(f" ... and {len(mismatched_videos) - 20} more")

    # Save the detailed report
    output_file = videos_dir.parent / 'video_resolution_report.json'
    report = {
        'total_videos': total_videos,
        'resolution_distribution': {f"{w}x{h}": count for (w, h), count in resolution_stats.items()},
        'videos': all_resolutions,
    }

    if target_res:
        report['target_resolution'] = f"{target_res[0]}x{target_res[1]}"
        report['mismatched_count'] = len(mismatched_videos)
        report['mismatched_videos'] = mismatched_videos

    with open(output_file, 'w') as f:
        json.dump(report, f, indent=2)

    print(f"\nDetailed report saved to: {output_file}")
190
+
191
+
192
# Run the resolution check only when this file is executed as a script.
if __name__ == '__main__':
    main()
194
+
diffusion-dpo-ocr/prepare_roadtext.py ADDED
@@ -0,0 +1,625 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
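+ RoadText1K preprocessing script
3
+ Extract the annotated frames from the videos, crop a crop_size x crop_size patch (512x512 by default) from each frame, and merge the Localisation and Text_Transcription annotations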
4
+
5
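+ Workflow:
6
+ 1. Run this script to generate the GT images (512x512 crops) and the merged annotations
7
+ 2. Generate the LR and SR images with your own pipeline
8
+ 3. Run test_roadtext.py to evaluate OCR performance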
9
+ """
10
+
11
+ import os
12
+ import json
13
+ import random
14
+ from pathlib import Path
15
+ from typing import Dict, List, Tuple
16
+
17
+ from PIL import Image
18
+ import numpy as np
19
+ from tqdm import tqdm
20
+ import cv2
21
+
22
+
23
# ============================================================================
# Configuration - edit the values below
# ============================================================================
CONFIG = {
    # RoadText1K root directory
    'roadtext_root': '/home/wanghongbo06/baipurui/DATA/RoadText1k',

    # Which dataset split to use (test/train/val)
    'split': 'test',

    # Output directory
    'output_dir': '/home/wanghongbo06/baipurui/DATA/RoadText1k_patch_crop',

    # Crop size (side length of the square region cropped from the original frame)
    'crop_size': 512,

    # Crop strategy: 'center', 'random', 'text_center' (centred on the text regions)
    'crop_strategy': 'text_center',

    # Minimum fraction of a text box that must lie inside the crop for the box to be kept
    # 1.0 keeps only boxes that are fully inside the crop (recommended, avoids coordinate issues)
    # 0.7 keeps boxes that are at least 70% inside the crop (coordinates may then fall outside the crop)
    'min_box_overlap': 1.0,

    # Minimum number of valid text boxes a crop must contain
    'min_text_boxes': 1,

    # Upper limit on the number of output images
    'max_frames': 1000,

    # Random seed
    'seed': 42,
}
# ============================================================================
57
+
58
+
59
def load_localisation_annotations(root_dir: str) -> Dict:
    """
    Load every Localisation annotation file under the dataset root.

    Each '*.json' file in <root_dir>/Ground_truths/Localisation is expected
    to hold a list of per-image dicts; entries are indexed by their 'name'.

    Args:
        root_dir: RoadText1K root directory.

    Returns:
        Dict mapping image name -> annotation dict. Files that cannot be
        read or parsed, files that are not lists, and list items that are
        not dicts or have no non-empty string 'name' are skipped.
    """
    loc_dir = Path(root_dir) / 'Ground_truths' / 'Localisation'
    all_annotations = {}

    print("Loading Localisation annotations...")
    json_files = sorted(loc_dir.glob('*.json'))
    print(f"Found {len(json_files)} JSON files")

    for json_file in json_files:
        try:
            with open(json_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers both JSON syntax errors and bad encodings
            print(f"Error loading {json_file.name}: {e}")
            continue

        if not isinstance(data, list):
            print(f"Warning: {json_file.name} is not a list format")
            continue

        for item in data:
            # One malformed entry must not discard the rest of the file
            if not isinstance(item, dict):
                continue
            img_name = item.get('name', '')
            if isinstance(img_name, str) and img_name:
                all_annotations[img_name] = item

    print(f"Loaded {len(all_annotations)} image annotations")
    return all_annotations
85
+
86
+
87
def load_text_transcriptions(root_dir: str) -> Dict:
    """
    Load every Text_Transcription file under the dataset root.

    Each JSON file in <root_dir>/Ground_truths/Text_Transcription is
    expected to hold a dict {video_name: {label_id: text}}; entries for the
    same video coming from several files are merged.

    Args:
        root_dir: RoadText1K root directory.

    Returns:
        Dict mapping video name -> {label_id: text}. Files that cannot be
        read or parsed, or that are not dicts, are skipped.
    """
    text_dir = Path(root_dir) / 'Ground_truths' / 'Text_Transcription'
    all_texts = {}

    print("Loading Text_Transcription annotations...")
    # '*.json' already matches double-extension names such as 'x.json.json',
    # so a single pattern lists every file exactly once.
    json_files = sorted(text_dir.glob('*.json'))
    print(f"Found {len(json_files)} JSON files")

    for json_file in json_files:
        try:
            with open(json_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers both JSON syntax errors and bad encodings
            print(f"Error loading {json_file.name}: {e}")
            continue

        if not isinstance(data, dict):
            print(f"Warning: {json_file.name} is not a dict format")
            continue

        for video_name, texts in data.items():
            video_texts = all_texts.setdefault(video_name, {})
            if isinstance(texts, dict):
                video_texts.update(texts)

    print(f"Loaded texts for {len(all_texts)} videos")
    return all_texts
115
+
116
+
117
def get_box_bounds(box: List[float]) -> Tuple[float, float, float, float]:
    """Return the axis-aligned bounds (x1, y1, x2, y2) of a flat polygon."""
    # The polygon is stored flat as [x, y, x, y, ...]: even slots are x, odd slots are y
    x_coords = box[0::2]
    y_coords = box[1::2]
    left, right = min(x_coords), max(x_coords)
    top, bottom = min(y_coords), max(y_coords)
    return left, top, right, bottom
123
+
124
+
125
def is_box_fully_inside(box: List[float], crop_x: int, crop_y: int, crop_size: int) -> bool:
    """Tell whether the bounds of `box` lie entirely inside the square crop window."""
    left, top, right, bottom = get_box_bounds(box)
    # Borders are inclusive: a box touching the crop edge still counts as inside
    fits_horizontally = left >= crop_x and right <= crop_x + crop_size
    fits_vertically = top >= crop_y and bottom <= crop_y + crop_size
    return fits_horizontally and fits_vertically
131
+
132
+
133
def calc_box_overlap(box: List[float], crop_x: int, crop_y: int, crop_size: int) -> float:
    """
    Fraction of the box's bounding rectangle that falls inside the crop window.

    Returns:
        A value in 0.0 - 1.0; 0.0 for degenerate boxes (non-positive area)
        and for boxes that do not intersect the crop window.
    """
    bx1, by1, bx2, by2 = get_box_bounds(box)
    area = (bx2 - bx1) * (by2 - by1)
    if area <= 0:
        return 0.0

    # Intersection rectangle between the box bounds and the crop window
    ix1 = max(bx1, crop_x)
    iy1 = max(by1, crop_y)
    ix2 = min(bx2, crop_x + crop_size)
    iy2 = min(by2, crop_y + crop_size)

    has_overlap = ix1 < ix2 and iy1 < iy2
    if not has_overlap:
        return 0.0

    return ((ix2 - ix1) * (iy2 - iy1)) / area
158
+
159
+
160
def clip_polygon_to_crop(poly: List[float], crop_x: int, crop_y: int, crop_size: int) -> List[float]:
    """
    Express a flat polygon in crop coordinates.

    Every vertex is shifted by the crop origin and then clamped to
    [0, crop_size], so vertices outside the crop window end up on its border.
    """
    def _to_crop(value, origin):
        # Translate first, then clamp as a safety net
        return max(0, min(value - origin, crop_size))

    result = []
    for idx in range(0, len(poly), 2):
        result += [_to_crop(poly[idx], crop_x), _to_crop(poly[idx + 1], crop_y)]
    return result
180
+
181
+
182
def find_best_crop_position(
    polygons: List[List[float]],
    img_w: int,
    img_h: int,
    crop_size: int,
    strategy: str = 'text_center',
    min_overlap: float = 0.7
) -> Tuple[int, int, List[int]]:
    """
    Choose the top-left corner of the crop window and list the boxes it keeps.

    Args:
        polygons: Flat text-box polygons in original image coordinates.
        img_w, img_h: Original image size.
        crop_size: Side length of the square crop.
        strategy: 'center', 'text_center' (centred on the mean of the box
            centres) or 'random'; any other value behaves like 'center'.
        min_overlap: Minimum fraction of a box that must lie inside the crop
            for it to be kept; values >= 0.99 require full containment.

    Returns:
        (crop_x, crop_y, valid_box_indices)
    """
    centred_x = max(0, (img_w - crop_size) // 2)
    centred_y = max(0, (img_h - crop_size) // 2)

    if not polygons:
        return centred_x, centred_y, []

    if strategy == 'text_center':
        # Aim the crop at the mean of all box centres
        bounds = [get_box_bounds(poly) for poly in polygons]
        mean_cx = sum((b[0] + b[2]) / 2 for b in bounds) / len(bounds)
        mean_cy = sum((b[1] + b[3]) / 2 for b in bounds) / len(bounds)

        # Keep the window inside the image
        crop_x = max(0, min(int(mean_cx - crop_size / 2), img_w - crop_size))
        crop_y = max(0, min(int(mean_cy - crop_size / 2), img_h - crop_size))
    elif strategy == 'random':
        span_x = max(0, img_w - crop_size)
        span_y = max(0, img_h - crop_size)
        crop_x = random.randint(0, span_x) if span_x > 0 else 0
        crop_y = random.randint(0, span_y) if span_y > 0 else 0
    else:
        # 'center' and unknown strategies both use the centred window
        crop_x, crop_y = centred_x, centred_y

    # Strict mode keeps fully contained boxes only; otherwise use the overlap ratio
    strict = min_overlap >= 0.99
    valid_indices = []
    for idx, poly in enumerate(polygons):
        if strict:
            keep = is_box_fully_inside(poly, crop_x, crop_y, crop_size)
        else:
            keep = calc_box_overlap(poly, crop_x, crop_y, crop_size) >= min_overlap
        if keep:
            valid_indices.append(idx)

    return crop_x, crop_y, valid_indices
255
+
256
+
257
def extract_frame_with_crop(
    video_path: Path,
    frame_idx: int,
    output_dir: Path,
    crop_x: int,
    crop_y: int,
    crop_size: int,
) -> Tuple[str, bool]:
    """
    Read one frame from a video, crop it and save the crop as a PNG.

    The file is written to output_dir as '<video stem>-<frame_idx, 7 digits>.png'.

    Returns:
        (saved_filename, success); ('', False) when the video cannot be
        opened or the frame cannot be read.
    """
    failure = ('', False)

    capture = cv2.VideoCapture(str(video_path))
    if not capture.isOpened():
        return failure

    # Seek to the requested frame, grab it, then release the handle right away
    # NOTE(review): frame_idx is used directly as the OpenCV frame position -
    # confirm the annotation frame numbers use the same base.
    capture.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
    ok, bgr_frame = capture.read()
    capture.release()

    if not ok:
        return failure

    # OpenCV delivers BGR; convert to RGB before handing the array to PIL
    rgb_image = Image.fromarray(cv2.cvtColor(bgr_frame, cv2.COLOR_BGR2RGB))
    crop_box = (crop_x, crop_y, crop_x + crop_size, crop_y + crop_size)
    patch = rgb_image.crop(crop_box)

    out_name = f"{video_path.stem}-{frame_idx:07d}.png"
    patch.save(output_dir / out_name)

    return out_name, True
297
+
298
+
299
def adjust_polygon_for_crop(poly: List[float], crop_x: int, crop_y: int) -> List[float]:
    """Shift a flat polygon so its coordinates are relative to the crop origin."""
    shifted = []
    for idx in range(0, len(poly), 2):
        shifted.extend((poly[idx] - crop_x, poly[idx + 1] - crop_y))
    return shifted
306
+
307
+
308
def merge_annotations(loc_ann: Dict, text_dict: Dict) -> Dict:
    """
    Merge one Localisation entry with the Text_Transcription texts.

    Args:
        loc_ann: A single Localisation annotation item; its 'videoName' and
            'labels' (each with 'id' and 'box2d') fields are read.
        text_dict: Text_Transcription dict {video_name: {label_id: text}}.

    Returns:
        The merged annotation: {
            'polygons': [[x1,y1,x2,y1,x2,y2,x1,y2], ...],
            'texts': ['text1', 'text2', ...],
            'ignore': [False, False, ...]
        }
        Labels without a 'box2d' are dropped. A missing or null
        transcription becomes '' and is flagged as ignored, as are '###'
        and 'illegible' (case-insensitive).
    """
    video_name = loc_ann.get('videoName', '')
    labels = loc_ann.get('labels') or []  # 'labels' may be present but None

    # Look up the texts of this video. Retry with the string form of the
    # name so a numeric 'videoName' still finds a string-keyed entry.
    video_texts = text_dict.get(video_name)
    if not isinstance(video_texts, dict):
        video_texts = text_dict.get(str(video_name))
    if not isinstance(video_texts, dict):
        video_texts = {}

    polygons = []
    texts = []
    ignore = []

    for label in labels:
        if not isinstance(label, dict):
            continue

        label_id = str(label.get('id', ''))
        box2d = label.get('box2d', {})
        if not box2d:
            continue

        # Expand the box2d corners into a 4-point polygon
        x1, y1 = box2d.get('x1', 0), box2d.get('y1', 0)
        x2, y2 = box2d.get('x2', 0), box2d.get('y2', 0)
        polygons.append([x1, y1, x2, y1, x2, y2, x1, y2])

        # Normalise the transcription to str so null / numeric JSON values
        # cannot break the comparisons below
        raw_text = video_texts.get(label_id, '')
        text = '' if raw_text is None else str(raw_text)
        texts.append(text)

        # Empty text and the known "unreadable" markers are ignored
        ignore.append(text in ('', '###') or text.lower() == 'illegible')

    return {
        'polygons': polygons,
        'texts': texts,
        'ignore': ignore,
    }
361
+
362
+
363
def scale_polygon(polygon: List[float], scale_x: float, scale_y: float) -> List[float]:
    """Multiply the x coordinates of a flat polygon by scale_x and the y coordinates by scale_y."""
    result = []
    for idx in range(0, len(polygon), 2):
        result.extend((polygon[idx] * scale_x, polygon[idx + 1] * scale_y))
    return result
370
+
371
+
372
def main():
    """
    Build the cropped RoadText1K evaluation set described by CONFIG.

    Steps:
        1. Load the Localisation and Text_Transcription annotations.
        2. Index the annotated frames of every video found under
           Videos/<split>, sub-sampling at random when there are more than
           CONFIG['max_frames'] of them.
        3. For each selected frame choose a crop window; skip the frame when
           fewer than CONFIG['min_text_boxes'] boxes survive, otherwise save
           the crop to <output_dir>/gt and record the kept boxes in crop
           coordinates.
        4. Write <output_dir>/annotations.json and print a summary.
    """
    random.seed(CONFIG['seed'])
    np.random.seed(CONFIG['seed'])

    root_dir = Path(CONFIG['roadtext_root'])
    output_dir = Path(CONFIG['output_dir'])
    split = CONFIG['split']
    crop_size = CONFIG['crop_size']
    crop_strategy = CONFIG['crop_strategy']
    min_box_overlap = CONFIG['min_box_overlap']
    min_text_boxes = CONFIG['min_text_boxes']
    max_frames = CONFIG['max_frames']

    # Create the output directory
    gt_dir = output_dir / 'gt'
    gt_dir.mkdir(parents=True, exist_ok=True)

    print(f"RoadText1K root: {root_dir}")
    print(f"Split: {split}")
    print(f"Output dir: {output_dir}")
    print(f"Crop size: {crop_size}x{crop_size}")
    print(f"Crop strategy: {CONFIG['crop_strategy']}")
    print(f"Min text boxes per crop: {CONFIG['min_text_boxes']}")
    print()

    # Load the annotations
    loc_annotations = load_localisation_annotations(root_dir)
    text_transcriptions = load_text_transcriptions(root_dir)

    # Locate the video directory
    video_dir = root_dir / 'Videos' / split
    if not video_dir.exists():
        print(f"Error: Video directory not found: {video_dir}")
        return

    # Videos may sit directly in the split directory or one level below it
    video_files = []
    for subdir in sorted(video_dir.iterdir()):
        if subdir.is_dir():
            video_files.extend(subdir.glob('*.mp4'))
            video_files.extend(subdir.glob('*.avi'))
        elif subdir.suffix in ['.mp4', '.avi']:
            video_files.append(subdir)

    video_files = sorted(video_files)
    total_videos = len(video_files)
    print(f"找到 {total_videos} 个视频")
    print()

    new_annotations = {}
    processed_count = 0
    total_frames = 0

    print("Processing videos...")
    print("Step 1: 找出每个视频中有标注的帧...")

    # Step 1: index the annotated frames per video, keeping the annotation key
    video_annotated_frames = {}  # {video_name: [(frame_num, annotation_key), ...]}
    video_name_to_path = {vp.stem: vp for vp in video_files}

    for key in loc_annotations:
        # Keys look like "200_frames/170-0000001.jpg" or "test_frames/701-0000001.jpg";
        # only the last path component carries the video name and frame number.
        filename = key.split('/')[-1]

        # Split on the LAST dash so video names that contain '-' stay intact
        video_name_candidate, dash, frame_part = filename.rpartition('-')
        if not dash:
            continue

        try:
            frame_num = int(frame_part.split('.')[0])
        except ValueError:
            # Not a "<video>-<frame>.<ext>" name; skip this annotation only
            continue

        # Keep the frame only if the video belongs to the selected split.
        # Annotation keys are unique, so no duplicate check is needed here.
        if video_name_candidate in video_name_to_path:
            video_annotated_frames.setdefault(video_name_candidate, []).append((frame_num, key))

    print(f"找到 {len(video_annotated_frames)} 个视频有标注")
    total_annotated_frames = sum(len(frames) for frames in video_annotated_frames.values())
    print(f"总共有 {total_annotated_frames} 帧有标注")

    # When max_frames is set, draw a random subset of the annotated frames
    if max_frames is not None and total_annotated_frames > max_frames:
        print(f"\n限制最大帧数为 {max_frames},随机选取中...")

        # Flatten to [(video_name, frame_num, annotation_key), ...]
        all_frames = [
            (video_name, frame_num, ann_key)
            for video_name, frame_list in video_annotated_frames.items()
            for frame_num, ann_key in frame_list
        ]

        selected_frames = random.sample(all_frames, max_frames)

        # Regroup the sample per video
        video_annotated_frames = {}
        for video_name, frame_num, ann_key in selected_frames:
            video_annotated_frames.setdefault(video_name, []).append((frame_num, ann_key))

        print(f"随机选取 {max_frames} 帧,涉及 {len(video_annotated_frames)} 个视频")

    print()
    print("Step 2: 提取帧并 Crop(保留有效文本框)...")
    print(f" Crop 尺寸: {crop_size}x{crop_size}")
    print(f" Crop 策略: {crop_strategy}")
    print(f" 最小重叠比例: {min_box_overlap:.0%}")
    print(f" 最小文本框数: {min_text_boxes}")
    print()

    skipped_no_boxes = 0

    for video_path in tqdm(video_files, desc="Videos"):
        video_name = video_path.stem

        # Selected frames of this video: [(frame_num, ann_key), ...]
        annotated_frame_info = video_annotated_frames.get(video_name, [])
        if not annotated_frame_info:
            continue

        # Read the video resolution; it bounds the crop window
        cap = cv2.VideoCapture(str(video_path))
        if not cap.isOpened():
            continue
        orig_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        orig_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        cap.release()

        if orig_w == 0 or orig_h == 0:
            continue

        # NOTE(review): frame_num is passed unchanged to extract_frame_with_crop,
        # which uses it as the OpenCV frame position - confirm the annotation
        # frame numbers use the same base.
        for frame_num, ann_key in annotated_frame_info:
            loc_ann = loc_annotations.get(ann_key)
            if loc_ann is None:
                continue

            # Merge boxes and transcriptions
            merged_ann = merge_annotations(loc_ann, text_transcriptions)

            if not merged_ann['polygons']:
                skipped_no_boxes += 1
                continue

            # Choose the crop window
            crop_x, crop_y, valid_indices = find_best_crop_position(
                merged_ann['polygons'],
                orig_w,
                orig_h,
                crop_size,
                crop_strategy,
                min_box_overlap
            )

            # Require enough surviving text boxes
            if len(valid_indices) < min_text_boxes:
                skipped_no_boxes += 1
                continue

            # Extract and crop the frame
            img_filename, success = extract_frame_with_crop(
                video_path, frame_num, gt_dir,
                crop_x, crop_y, crop_size
            )

            if not success:
                continue

            # Keep only the valid boxes, expressed in crop coordinates
            cropped_polygons = []
            cropped_texts = []
            cropped_ignore = []

            for i in valid_indices:
                clipped_poly = clip_polygon_to_crop(
                    merged_ann['polygons'][i], crop_x, crop_y, crop_size
                )
                cropped_polygons.append(clipped_poly)
                cropped_texts.append(merged_ann['texts'][i])
                cropped_ignore.append(merged_ann['ignore'][i])

            new_annotations[img_filename] = {
                'polygons': cropped_polygons,
                'texts': cropped_texts,
                'ignore': cropped_ignore,
                'original_name': ann_key,
                'crop_position': [crop_x, crop_y],
            }

            total_frames += 1

        processed_count += 1

    # Save the annotations
    ann_output_path = output_dir / 'annotations.json'
    with open(ann_output_path, 'w', encoding='utf-8') as f:
        json.dump(new_annotations, f, indent=2, ensure_ascii=False)

    # Count the text boxes that are not flagged as ignored
    total_boxes = sum(len(ann['texts']) for ann in new_annotations.values())
    valid_boxes = sum(
        sum(1 for ig in ann['ignore'] if not ig)
        for ann in new_annotations.values()
    )

    print()
    print("="*60)
    print("完成!")
    print("="*60)
    print(f"处理视频数: {processed_count}")
    print(f"提取帧数: {total_frames}")
    print(f"跳过帧数(文本框不足): {skipped_no_boxes}")
    print(f"有标注的图片数: {len(new_annotations)}")
    print(f"文本框总数: {total_boxes}")
    print(f"有效文本框: {valid_boxes}")
    print(f"每张图平均文本框: {total_boxes/max(1,len(new_annotations)):.1f}")
    print()
    print("输出文件:")
    print(f" GT images ({crop_size}x{crop_size}): {gt_dir}")
    print(f" Annotations: {ann_output_path}")
    print()
    print("下一步:")
    print(f" 1. 用你的方式生成 LR images (如 128x128)")
    print(f" 2. 超分得到 SR images ({crop_size}x{crop_size})")
    print(f" 3. 将 SR images 保存到 {output_dir}/sr")
    print(f" 4. 运行 test_roadtext.py 评估")
621
+
622
+
623
# Run the preprocessing only when this file is executed as a script.
if __name__ == '__main__':
    main()
625
+
diffusion-dpo-ocr/results/roadtext_eval_results_BSRGAN.json ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "TP": 51,
3
+ "FP": 80,
4
+ "FN": 1920,
5
+ "Precision": 0.3893129770992366,
6
+ "Recall": 0.0258751902587519,
7
+ "F1-Score": 0.04852521408182683,
8
+ "OCR_detections": 131,
9
+ "GT_boxes": 1971,
10
+ "Detection_rate": 0.06646372399797057,
11
+ "Det_Precision": 0.6183206106870229,
12
+ "Det_Recall": 0.0410958904109589,
13
+ "Det_F1": 0.07706945765937202,
14
+ "Det_matched": 81,
15
+ "Text_matched": 51,
16
+ "eval_mode": "end2end"
17
+ }
diffusion-dpo-ocr/results/roadtext_eval_results_DP2O-SR.json ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "TP": 314,
3
+ "FP": 1182,
4
+ "FN": 1657,
5
+ "Precision": 0.20989304812834225,
6
+ "Recall": 0.15930999492643327,
7
+ "F1-Score": 0.181136429189501,
8
+ "OCR_detections": 1496,
9
+ "GT_boxes": 1971,
10
+ "Detection_rate": 0.7590055809233891,
11
+ "Det_Precision": 0.5274064171122995,
12
+ "Det_Recall": 0.4003044140030441,
13
+ "Det_F1": 0.4551485434092875,
14
+ "Det_matched": 789,
15
+ "Text_matched": 314,
16
+ "eval_mode": "end2end"
17
+ }
diffusion-dpo-ocr/results/roadtext_eval_results_DiT4SR.json ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "TP": 226,
3
+ "FP": 903,
4
+ "FN": 1745,
5
+ "Precision": 0.20017714791851196,
6
+ "Recall": 0.11466260781329274,
7
+ "F1-Score": 0.14580645161290323,
8
+ "OCR_detections": 1129,
9
+ "GT_boxes": 1971,
10
+ "Detection_rate": 0.5728056823947235,
11
+ "Det_Precision": 0.5234720992028343,
12
+ "Det_Recall": 0.2998477929984779,
13
+ "Det_F1": 0.3812903225806451,
14
+ "Det_matched": 591,
15
+ "Text_matched": 226,
16
+ "eval_mode": "end2end"
17
+ }
diffusion-dpo-ocr/results/roadtext_eval_results_DiffBIR.json ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "TP": 172,
3
+ "FP": 687,
4
+ "FN": 1799,
5
+ "Precision": 0.20023282887077998,
6
+ "Recall": 0.08726534753932014,
7
+ "F1-Score": 0.1215547703180212,
8
+ "OCR_detections": 859,
9
+ "GT_boxes": 1971,
10
+ "Detection_rate": 0.43581938102486045,
11
+ "Det_Precision": 0.5448195576251456,
12
+ "Det_Recall": 0.2374429223744292,
13
+ "Det_F1": 0.33074204946996466,
14
+ "Det_matched": 468,
15
+ "Text_matched": 172,
16
+ "eval_mode": "end2end"
17
+ }
diffusion-dpo-ocr/results/roadtext_eval_results_FaithDiff.json ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "TP": 210,
3
+ "FP": 1068,
4
+ "FN": 1761,
5
+ "Precision": 0.1643192488262911,
6
+ "Recall": 0.106544901065449,
7
+ "F1-Score": 0.12927054478301014,
8
+ "OCR_detections": 1278,
9
+ "GT_boxes": 1971,
10
+ "Detection_rate": 0.6484018264840182,
11
+ "Det_Precision": 0.4890453834115806,
12
+ "Det_Recall": 0.31709791983764585,
13
+ "Det_F1": 0.38473376423514927,
14
+ "Det_matched": 625,
15
+ "Text_matched": 210,
16
+ "eval_mode": "end2end"
17
+ }
diffusion-dpo-ocr/results/roadtext_eval_results_Ours.json ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "TP": 233,
3
+ "FP": 1483,
4
+ "FN": 1738,
5
+ "Precision": 0.1357808857808858,
6
+ "Recall": 0.11821410451547437,
7
+ "F1-Score": 0.12639001898562516,
8
+ "OCR_detections": 1716,
9
+ "GT_boxes": 1971,
10
+ "Detection_rate": 0.8706240487062404,
11
+ "Det_Precision": 0.3275058275058275,
12
+ "Det_Recall": 0.28513444951801115,
13
+ "Det_F1": 0.3048548955790616,
14
+ "Det_matched": 562,
15
+ "Text_matched": 233,
16
+ "eval_mode": "end2end"
17
+ }
diffusion-dpo-ocr/results/roadtext_eval_results_Real-ESRGAN.json ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "TP": 64,
3
+ "FP": 191,
4
+ "FN": 1907,
5
+ "Precision": 0.25098039215686274,
6
+ "Recall": 0.032470826991374935,
7
+ "F1-Score": 0.05750224618149146,
8
+ "OCR_detections": 255,
9
+ "GT_boxes": 1971,
10
+ "Detection_rate": 0.1293759512937595,
11
+ "Det_Precision": 0.5882352941176471,
12
+ "Det_Recall": 0.076103500761035,
13
+ "Det_F1": 0.1347708894878706,
14
+ "Det_matched": 150,
15
+ "Text_matched": 64,
16
+ "eval_mode": "end2end"
17
+ }
diffusion-dpo-ocr/results/roadtext_eval_results_SUPSR.json ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "TP": 219,
3
+ "FP": 1267,
4
+ "FN": 1752,
5
+ "Precision": 0.14737550471063257,
6
+ "Recall": 0.1111111111111111,
7
+ "F1-Score": 0.126699450390512,
8
+ "OCR_detections": 1486,
9
+ "GT_boxes": 1971,
10
+ "Detection_rate": 0.7539320142059868,
11
+ "Det_Precision": 0.4878869448183042,
12
+ "Det_Recall": 0.3678335870116692,
13
+ "Det_F1": 0.41943881978594155,
14
+ "Det_matched": 725,
15
+ "Text_matched": 219,
16
+ "eval_mode": "end2end"
17
+ }
diffusion-dpo-ocr/results/roadtext_eval_results_SeeSR.json ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "TP": 0,
3
+ "FP": 0,
4
+ "FN": 0,
5
+ "Precision": 0,
6
+ "Recall": 0,
7
+ "F1-Score": 0,
8
+ "OCR_detections": 0,
9
+ "GT_boxes": 0,
10
+ "Detection_rate": 0,
11
+ "Det_Precision": 0,
12
+ "Det_Recall": 0,
13
+ "Det_F1": 0,
14
+ "Det_matched": 0,
15
+ "Text_matched": 0,
16
+ "eval_mode": "end2end"
17
+ }
diffusion-dpo-ocr/results/roadtext_eval_results_StableSR.json ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "TP": 141,
3
+ "FP": 672,
4
+ "FN": 1830,
5
+ "Precision": 0.17343173431734318,
6
+ "Recall": 0.0715372907153729,
7
+ "F1-Score": 0.10129310344827586,
8
+ "OCR_detections": 813,
9
+ "GT_boxes": 1971,
10
+ "Detection_rate": 0.4124809741248097,
11
+ "Det_Precision": 0.5276752767527675,
12
+ "Det_Recall": 0.2176560121765601,
13
+ "Det_F1": 0.3081896551724138,
14
+ "Det_matched": 429,
15
+ "Text_matched": 141,
16
+ "eval_mode": "end2end"
17
+ }
diffusion-dpo-ocr/results/roadtext_eval_results_SwinIR.json ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "TP": 66,
3
+ "FP": 182,
4
+ "FN": 1905,
5
+ "Precision": 0.2661290322580645,
6
+ "Recall": 0.0334855403348554,
7
+ "F1-Score": 0.05948625506985128,
8
+ "OCR_detections": 248,
9
+ "GT_boxes": 1971,
10
+ "Detection_rate": 0.12582445459157787,
11
+ "Det_Precision": 0.5806451612903226,
12
+ "Det_Recall": 0.0730593607305936,
13
+ "Det_F1": 0.12978819287967552,
14
+ "Det_matched": 144,
15
+ "Text_matched": 66,
16
+ "eval_mode": "end2end"
17
+ }
diffusion-dpo-ocr/results/roadtext_eval_results_gt.json ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "TP": 953,
3
+ "FP": 576,
4
+ "FN": 1018,
5
+ "Precision": 0.6232831916285154,
6
+ "Recall": 0.4835109081684424,
7
+ "F1-Score": 0.5445714285714285,
8
+ "OCR_detections": 1529,
9
+ "GT_boxes": 1971,
10
+ "Detection_rate": 0.7757483510908169,
11
+ "Det_Precision": 0.6487900588620014,
12
+ "Det_Recall": 0.5032978183663115,
13
+ "Det_F1": 0.5668571428571428,
14
+ "Det_matched": 992,
15
+ "Text_matched": 953,
16
+ "eval_mode": "end2end"
17
+ }
diffusion-dpo-ocr/results/roadtext_eval_results_sample00.json ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "TP": 0,
3
+ "FP": 0,
4
+ "FN": 0,
5
+ "Precision": 0,
6
+ "Recall": 0,
7
+ "F1-Score": 0,
8
+ "OCR_detections": 0,
9
+ "GT_boxes": 0,
10
+ "Detection_rate": 0,
11
+ "Det_Precision": 0,
12
+ "Det_Recall": 0,
13
+ "Det_F1": 0,
14
+ "Det_matched": 0,
15
+ "Text_matched": 0,
16
+ "eval_mode": "end2end"
17
+ }
diffusion-dpo-ocr/results/roadtext_eval_results_zoomlr.json ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "TP": 8,
3
+ "FP": 106,
4
+ "FN": 1963,
5
+ "Precision": 0.07017543859649122,
6
+ "Recall": 0.004058853373921867,
7
+ "F1-Score": 0.007673860911270983,
8
+ "OCR_detections": 114,
9
+ "GT_boxes": 1971,
10
+ "Detection_rate": 0.0578386605783866,
11
+ "Det_Precision": 0.07017543859649122,
12
+ "Det_Recall": 0.004058853373921867,
13
+ "Det_F1": 0.007673860911270983,
14
+ "Det_matched": 8,
15
+ "Text_matched": 8,
16
+ "eval_mode": "end2end"
17
+ }
diffusion-dpo-ocr/roadtext_eval_results_output.json ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "TP": 167,
3
+ "FP": 707,
4
+ "FN": 1804,
5
+ "Precision": 0.19107551487414187,
6
+ "Recall": 0.08472856418061897,
7
+ "F1-Score": 0.11739894551845341,
8
+ "OCR_detections": 874,
9
+ "GT_boxes": 1971,
10
+ "Detection_rate": 0.44342973110096395,
11
+ "Det_Precision": 0.540045766590389,
12
+ "Det_Recall": 0.23947234906139014,
13
+ "Det_F1": 0.3318101933216168,
14
+ "Det_matched": 472,
15
+ "Text_matched": 167,
16
+ "eval_mode": "end2end"
17
+ }
diffusion-dpo-ocr/test_roadtext.py ADDED
@@ -0,0 +1,514 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ RoadText1K OCR 评估脚本
3
+ 评估超分图片在 OCR 任务上的 Precision 和 Recall
4
+
5
+ Metrics:
6
+ - Precision: TP / (TP + FP)
7
+ - Recall: TP / (TP + FN)
8
+ - F1-Score: 2 * Precision * Recall / (Precision + Recall)
9
+ """
10
+
11
+ import os
12
+ import json
13
+ from pathlib import Path
14
+ from typing import List, Dict, Tuple
15
+ import difflib
16
+
17
+ import numpy as np
18
+ from PIL import Image
19
+ from tqdm import tqdm
20
+
21
+
22
+ # ============================================================================
23
+ # 配置参数 - 请修改这里
24
+ # ============================================================================
25
CONFIG = {
    # Directory containing the super-resolved (SR) images to evaluate.
    'sr_dir': '/home/wanghongbo06/baipurui/results/RoadText/DreamClear/results/output',
    # 'sr_dir': '/home/wanghongbo06/baipurui/DATA/RoadText1k_patch_crop/SR-Eval/zoomlr',

    # Annotation file (generated by prepare_roadtext.py).
    'annotation_file': '/home/wanghongbo06/baipurui/DATA/RoadText1k_patch_crop/annotations.json',

    # OCR engine: 'paddleocr' or 'easyocr'.
    # NOTE(review): 'tesseract' was also listed here, but load_ocr_model()
    # raises ValueError for anything other than the two engines above.
    'ocr_engine': 'paddleocr',

    # Matching parameters.
    'iou_threshold': 0.3,  # IoU threshold for detection boxes
    'text_similarity_threshold': 0.3,  # text similarity threshold (0 = detection only, text is not checked)

    # Evaluation mode:
    #   'detection_only'   - evaluate detection only (text content is ignored)
    #   'end2end'          - end-to-end evaluation (detection + recognition)
    #   'recognition_only' - evaluate recognition only (crop on GT boxes, then recognise)
    # NOTE(review): match_detections() only special-cases 'detection_only';
    # 'recognition_only' currently follows the same path as 'end2end'.
    'eval_mode': 'end2end',

    # OCR configuration.
    'device': 'gpu',  # 'gpu' or 'cpu'

    # Debug options.
    'debug_visualize': True,  # whether to visualise the detection boxes of the first few images
    'debug_save_dir': './debug_vis',  # directory the visualisations are written to

    # Output.
    # NOTE(review): main() builds its own output path from the last component
    # of 'sr_dir' and does not read this key.
    'output': './roadtext_eval_results.json',
}
56
+ # ============================================================================
57
+
58
+
59
def load_ocr_model(engine='paddleocr', device='gpu'):
    """Instantiate the requested OCR backend.

    Args:
        engine: Backend name, either 'paddleocr' or 'easyocr'.
        device: 'gpu' or 'cpu'. Forwarded to PaddleOCR as ``device``; for
            EasyOCR it is converted into the boolean ``gpu`` flag.

    Returns:
        A ``(model, engine_name)`` tuple.

    Raises:
        ImportError: If the selected backend cannot be imported (an install
            hint is printed before re-raising).
        ValueError: If ``engine`` is not one of the supported names.
    """
    if engine not in ('paddleocr', 'easyocr'):
        raise ValueError(f"Unsupported OCR engine: {engine}")

    if engine == 'paddleocr':
        try:
            from paddleocr import PaddleOCR
            print("Loading PaddleOCR...")
            return PaddleOCR(lang='en', device=device), 'paddleocr'
        except ImportError:
            print("PaddleOCR not found. Install: pip install paddleocr")
            raise

    # Only 'easyocr' can reach this point.
    try:
        import easyocr
        print("Loading EasyOCR...")
        wants_gpu = device == 'gpu'
        return easyocr.Reader(['en'], gpu=wants_gpu), 'easyocr'
    except ImportError:
        print("EasyOCR not found. Install: pip install easyocr")
        raise
87
+
88
+
89
def run_ocr(image_path: str, ocr_model, engine: str) -> List[Dict]:
    """Run OCR on one image and normalise the backend output.

    Args:
        image_path: Path of the image; converted with ``str()`` before use.
        ocr_model: Model object returned by ``load_ocr_model``.
        engine: 'paddleocr' or 'easyocr'. Any other value yields an empty list.

    Returns:
        One dict per detection with keys 'polygon' (flat coordinate list
        ``[x1, y1, x2, y2, ...]``), 'text' and 'confidence'.
    """
    path_str = str(image_path)

    if engine == 'paddleocr':
        raw = ocr_model.ocr(path_str)
        # The first element holds the lines of the (single) input image and
        # may be empty / None when nothing was detected.
        if not (raw and raw[0]):
            return []
        return [
            {
                # [[x1, y1], [x2, y2], ...] -> [x1, y1, x2, y2, ...]
                'polygon': np.array(entry[0]).flatten().tolist(),
                'text': entry[1][0],
                'confidence': entry[1][1],
            }
            for entry in raw[0]
        ]

    if engine == 'easyocr':
        return [
            {
                'polygon': np.array(hit[0]).flatten().tolist(),
                'text': hit[1],
                'confidence': hit[2],
            }
            for hit in ocr_model.readtext(path_str)
        ]

    return []
128
+
129
+
130
def polygon_iou(poly1: List[float], poly2: List[float]) -> float:
    """Return the IoU of two polygons given as flat coordinate lists.

    Each polygon is ``[x1, y1, x2, y2, ...]``. Shapely is used when it can be
    imported; otherwise the axis-aligned bounding-box approximation from
    ``bbox_iou_from_polygon`` is returned instead.

    Returns:
        Intersection area divided by union area, or 0.0 when either polygon
        is invalid according to Shapely or the union has zero area.
    """
    try:
        from shapely.geometry import Polygon

        # Pair up consecutive (x, y) values to build the Shapely polygons.
        shape_a = Polygon([(poly1[k], poly1[k + 1]) for k in range(0, len(poly1), 2)])
        shape_b = Polygon([(poly2[k], poly2[k + 1]) for k in range(0, len(poly2), 2)])

        if not (shape_a.is_valid and shape_b.is_valid):
            return 0.0

        overlap_area = shape_a.intersection(shape_b).area
        combined_area = shape_a.union(shape_b).area
        return overlap_area / combined_area if combined_area != 0 else 0.0

    except ImportError:
        # Shapely is unavailable: approximate with bounding boxes.
        return bbox_iou_from_polygon(poly1, poly2)
157
+
158
+
159
def bbox_iou_from_polygon(poly1: List[float], poly2: List[float]) -> float:
    """Approximate polygon IoU with the IoU of their axis-aligned bounding boxes.

    Both polygons are flat lists ``[x1, y1, x2, y2, ...]``; even positions are
    x values and odd positions are y values.

    Returns:
        The bounding-box IoU, or 0.0 when the boxes do not overlap (touching
        edges count as no overlap) or the union area is not positive.
    """
    a_xs, a_ys = poly1[0::2], poly1[1::2]
    b_xs, b_ys = poly2[0::2], poly2[1::2]

    a_left, a_top, a_right, a_bottom = min(a_xs), min(a_ys), max(a_xs), max(a_ys)
    b_left, b_top, b_right, b_bottom = min(b_xs), min(b_ys), max(b_xs), max(b_ys)

    # Intersection rectangle.
    left = max(a_left, b_left)
    top = max(a_top, b_top)
    right = min(a_right, b_right)
    bottom = min(a_bottom, b_bottom)

    if right <= left or bottom <= top:
        return 0.0

    inter_area = (right - left) * (bottom - top)
    area_a = (a_right - a_left) * (a_bottom - a_top)
    area_b = (b_right - b_left) * (b_bottom - b_top)
    union_area = area_a + area_b - inter_area

    if union_area > 0:
        return inter_area / union_area
    return 0.0
187
+
188
+
189
def text_similarity(text1: str, text2: str) -> float:
    """Return a case-insensitive similarity score in [0, 1] for two strings.

    Both strings are lower-cased and stripped of surrounding whitespace.
    Identical strings score 1.0; otherwise the score is
    ``difflib.SequenceMatcher(...).ratio()`` (a matching-blocks ratio, not a
    true edit distance).
    """
    left = text1.lower().strip()
    right = text2.lower().strip()

    if left == right:
        return 1.0

    matcher = difflib.SequenceMatcher(None, left, right)
    return matcher.ratio()
199
+
200
+
201
def match_detections(
    pred_results: List[Dict],
    gt_annotations: Dict,
    iou_thresh: float = 0.5,
    text_sim_thresh: float = 0.5,
    eval_mode: str = 'end2end',
) -> Tuple[int, int, int, Dict]:
    """Greedily match the OCR predictions of one image against its GT boxes.

    Predictions are processed in the given order. Each prediction is paired
    with the not-yet-consumed, non-ignored GT box of highest IoU. The pair is
    a true positive when the IoU reaches ``iou_thresh`` and, unless
    ``eval_mode`` is 'detection_only', the text similarity reaches
    ``text_sim_thresh``. Every other prediction is a false positive.

    Args:
        pred_results: Dicts with keys 'polygon' and 'text'.
        gt_annotations: Dict with 'polygons', 'texts' and optionally 'ignore'
            (defaults to all False).
        iou_thresh: Minimum IoU for a localisation match.
        text_sim_thresh: Minimum text similarity for a recognition match.
        eval_mode: 'detection_only', 'end2end' or 'recognition_only'. Only
            'detection_only' is special-cased; the other two behave the same.

    Returns:
        ``(TP, FP, FN, details)`` where ``details`` holds 'det_matched'
        (distinct GT boxes localised by some prediction), 'text_matched' and
        'det_fp' (predictions with no GT box above the IoU threshold).
    """
    gt_polygons = gt_annotations['polygons']
    gt_texts = gt_annotations['texts']
    gt_ignore = gt_annotations.get('ignore', [False] * len(gt_texts))

    matched_gt = set()       # GT indices consumed by a true positive
    det_matched_gt = set()   # GT indices already credited in 'det_matched'
    tp = 0
    fp = 0

    details = {
        'det_matched': 0,   # GT boxes localised (IoU passed)
        'text_matched': 0,  # GT boxes whose text was also recognised
        'det_fp': 0,        # predictions without any localisation match
    }

    for pred in pred_results:
        pred_poly = pred['polygon']
        pred_text = pred['text']

        best_iou = 0
        best_gt_idx = -1

        # Find the best still-available GT box for this prediction.
        for gt_idx, (gt_poly, _gt_text, ignore) in enumerate(
            zip(gt_polygons, gt_texts, gt_ignore)
        ):
            if ignore or gt_idx in matched_gt:
                continue
            iou = polygon_iou(pred_poly, gt_poly)
            if iou > best_iou:
                best_iou = iou
                best_gt_idx = gt_idx

        if best_iou < iou_thresh or best_gt_idx < 0:
            # No GT box is localised by this prediction.
            fp += 1
            details['det_fp'] += 1
            continue

        # A GT box stays available after a text mismatch, so several
        # predictions can hit the same box. Credit it only once; otherwise
        # the detection-only recall/precision derived from 'det_matched'
        # would be inflated.
        if best_gt_idx not in det_matched_gt:
            det_matched_gt.add(best_gt_idx)
            details['det_matched'] += 1

        if eval_mode == 'detection_only':
            # Localisation alone decides the match.
            tp += 1
            matched_gt.add(best_gt_idx)
        elif text_similarity(pred_text, gt_texts[best_gt_idx]) >= text_sim_thresh:
            tp += 1
            matched_gt.add(best_gt_idx)
            details['text_matched'] += 1
        else:
            # Right place, wrong text: counted as a false positive.
            fp += 1

    # False negatives: non-ignored GT boxes that no true positive consumed.
    fn = sum(
        1
        for gt_idx, ignore in enumerate(gt_ignore)
        if not ignore and gt_idx not in matched_gt
    )

    return tp, fp, fn, details
283
+
284
+
285
def visualize_detections(
    img_path: Path,
    gt_ann: Dict,
    pred_results: List[Dict],
    save_path: Path,
):
    """Draw GT boxes (green) and OCR boxes (red) on the image and save it.

    Ignored GT boxes are not drawn. Each box is labelled with the first ten
    characters of its text. Nothing is written when the image cannot be read.
    """
    import cv2

    canvas = cv2.imread(str(img_path))
    if canvas is None:
        return

    green = (0, 255, 0)
    red = (0, 0, 255)
    font = cv2.FONT_HERSHEY_SIMPLEX

    # Ground-truth boxes; the label sits just above the first vertex.
    gt_polys = gt_ann['polygons']
    gt_texts = gt_ann['texts']
    gt_flags = gt_ann.get('ignore', [False] * len(gt_ann['texts']))
    for poly, text, ignore in zip(gt_polys, gt_texts, gt_flags):
        if ignore:
            continue
        pts = np.array(poly).reshape(-1, 2).astype(np.int32)
        cv2.polylines(canvas, [pts], True, green, 2)
        anchor = (int(pts[0][0]), int(pts[0][1]) - 5)
        cv2.putText(canvas, f"GT:{text[:10]}", anchor, font, 0.4, green, 1)

    # OCR detections; the label sits just below the first vertex.
    for pred in pred_results:
        poly = pred['polygon']
        text = pred['text']
        pts = np.array(poly).reshape(-1, 2).astype(np.int32)
        cv2.polylines(canvas, [pts], True, red, 2)
        anchor = (int(pts[0][0]), int(pts[0][1]) + 15)
        cv2.putText(canvas, f"OCR:{text[:10]}", anchor, font, 0.4, red, 1)

    cv2.imwrite(str(save_path), canvas)
321
+
322
+
323
def evaluate_dataset(
    sr_dir: str,
    annotation_file: str,
    ocr_model,
    engine: str,
    debug: bool = False,
) -> Dict:
    """Run OCR on every annotated image and aggregate dataset-level metrics.

    For each image name in the annotation file the image is looked up in
    ``sr_dir`` (falling back to the same stem with a '.jpg' extension), OCR
    is run, and the detections are matched against the ground truth with the
    thresholds and evaluation mode taken from the module-level ``CONFIG``.

    Images that cannot be found are skipped and do not contribute to any
    count. They are now reported in a warning and in the result, because a
    wrong ``sr_dir`` otherwise silently produces an all-zero result.

    Args:
        sr_dir: Directory with the images to evaluate.
        annotation_file: JSON file mapping image name to its annotation.
        ocr_model: Model object returned by ``load_ocr_model``.
        engine: Engine name passed on to ``run_ocr``.
        debug: Currently unused; kept for interface compatibility.

    Returns:
        Dict with TP/FP/FN, end-to-end Precision/Recall/F1-Score, detection
        only metrics, raw counters, the evaluation mode, and the number of
        evaluated and missing images. 'GT_boxes' is TP + FN, i.e. it covers
        only non-ignored boxes of images that were actually evaluated.
    """
    sr_dir = Path(sr_dir)
    eval_mode = CONFIG.get('eval_mode', 'end2end')
    debug_visualize = CONFIG.get('debug_visualize', False)
    debug_save_dir = Path(CONFIG.get('debug_save_dir', './debug_vis'))

    if debug_visualize:
        debug_save_dir.mkdir(parents=True, exist_ok=True)

    # Load the annotations.
    with open(annotation_file, 'r', encoding='utf-8') as f:
        annotations = json.load(f)

    # Annotation statistics (over all images, found or not).
    total_gt_boxes = 0
    total_ignored = 0
    for ann in annotations.values():
        total_gt_boxes += len(ann['texts'])
        total_ignored += sum(ann.get('ignore', [False] * len(ann['texts'])))
    print(f"\n标注统计: 总共 {len(annotations)} 张图片, {total_gt_boxes} 个文本框, {total_ignored} 个被忽略")
    print(f"评估模式: {eval_mode}")

    total_tp = 0
    total_fp = 0
    total_fn = 0

    # Aggregated per-image matching details.
    total_details = {
        'det_matched': 0,
        'text_matched': 0,
        'det_fp': 0,
    }

    print("Running OCR evaluation...")
    total_ocr_detections = 0
    debug_count = 0
    vis_count = 0
    evaluated_images = 0
    missing_images = []

    for img_name, gt_ann in tqdm(annotations.items()):
        img_path = sr_dir / img_name

        if not img_path.exists():
            # Try the same stem with a '.jpg' extension.
            img_path = sr_dir / (Path(img_name).stem + '.jpg')
            if not img_path.exists():
                missing_images.append(img_name)
                continue

        evaluated_images += 1

        # Run OCR.
        pred_results = run_ocr(img_path, ocr_model, engine)
        total_ocr_detections += len(pred_results)

        # Print details for the first three evaluated images.
        if debug_count < 3:
            print(f"\n[DEBUG] Image: {img_name}")
            print(f" GT boxes: {len(gt_ann['polygons'])}, ignored: {sum(gt_ann.get('ignore', []))}")
            print(f" OCR detections: {len(pred_results)}")
            if gt_ann['polygons']:
                print(f" GT[0] polygon: {gt_ann['polygons'][0][:4]}... text: '{gt_ann['texts'][0]}'")
            if pred_results:
                print(f" OCR[0] polygon: {pred_results[0]['polygon'][:4]}... text: '{pred_results[0]['text']}'")
            debug_count += 1

        # Save a visualisation for the first twenty evaluated images.
        if debug_visualize and vis_count < 20:
            vis_path = debug_save_dir / f"vis_{img_name}"
            visualize_detections(img_path, gt_ann, pred_results, vis_path)
            vis_count += 1

        # Match predictions against the ground truth.
        tp, fp, fn, details = match_detections(
            pred_results,
            gt_ann,
            iou_thresh=CONFIG['iou_threshold'],
            text_sim_thresh=CONFIG['text_similarity_threshold'],
            eval_mode=eval_mode,
        )

        total_tp += tp
        total_fp += fp
        total_fn += fn

        for k, v in details.items():
            total_details[k] += v

    print(f"\nOCR 总共检测到 {total_ocr_detections} 个文本框")

    if missing_images:
        print(
            f"\n[WARNING] {len(missing_images)} of {len(annotations)} annotated images "
            f"were not found under {sr_dir} and were skipped (e.g. {missing_images[0]})"
        )

    # End-to-end metrics.
    precision = total_tp / (total_tp + total_fp) if (total_tp + total_fp) > 0 else 0
    recall = total_tp / (total_tp + total_fn) if (total_tp + total_fn) > 0 else 0
    f1_score = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0

    # Valid GT boxes (ignored boxes excluded).
    total_gt_boxes_valid = total_tp + total_fn

    # Detection-only metrics (text content is not considered).
    det_precision = total_details['det_matched'] / total_ocr_detections if total_ocr_detections > 0 else 0
    det_recall = total_details['det_matched'] / total_gt_boxes_valid if total_gt_boxes_valid > 0 else 0
    det_f1 = 2 * det_precision * det_recall / (det_precision + det_recall) if (det_precision + det_recall) > 0 else 0

    return {
        'TP': total_tp,
        'FP': total_fp,
        'FN': total_fn,
        'Precision': precision,
        'Recall': recall,
        'F1-Score': f1_score,
        'OCR_detections': total_ocr_detections,
        'GT_boxes': total_gt_boxes_valid,
        'Detection_rate': total_ocr_detections / total_gt_boxes_valid if total_gt_boxes_valid > 0 else 0,
        # Detection-only metrics.
        'Det_Precision': det_precision,
        'Det_Recall': det_recall,
        'Det_F1': det_f1,
        'Det_matched': total_details['det_matched'],
        'Text_matched': total_details['text_matched'],
        'eval_mode': eval_mode,
        # Coverage of the annotation set.
        'Images_evaluated': evaluated_images,
        'Images_missing': len(missing_images),
    }
447
+
448
+
449
def main():
    """Evaluate the SR directory configured in ``CONFIG`` and save the metrics.

    The result file name is derived from the last path component of
    ``CONFIG['sr_dir']`` and written to the current working directory.
    """
    # Name the output after the last directory component of sr_dir.
    baseline_name = Path(CONFIG['sr_dir']).name
    output_path = Path(f"./roadtext_eval_results_{baseline_name}.json")
    rule = "=" * 60

    print(rule)
    print("RoadText1K OCR Evaluation")
    print(rule)
    print(f"SR images: {CONFIG['sr_dir']}")
    print(f"Annotations: {CONFIG['annotation_file']}")
    print(f"OCR engine: {CONFIG['ocr_engine']}")
    print(f"IoU threshold: {CONFIG['iou_threshold']}")
    print(f"Text similarity threshold: {CONFIG['text_similarity_threshold']}")
    print(f"Output will be saved to: {output_path}")
    print()

    # Load the OCR model.
    ocr_model, engine = load_ocr_model(CONFIG['ocr_engine'], device=CONFIG['device'])

    # Run the evaluation.
    results = evaluate_dataset(
        CONFIG['sr_dir'],
        CONFIG['annotation_file'],
        ocr_model,
        engine,
    )

    # Report.
    print("\n" + rule)
    print("EVALUATION RESULTS")
    print(rule)
    print(f"评估模式: {results.get('eval_mode', 'end2end')}")

    print("\n[Detection Statistics]")
    print(f" GT text boxes (valid): {results['GT_boxes']}")
    print(f" OCR detections: {results['OCR_detections']}")
    print(f" Detection matched (IoU): {results.get('Det_matched', 'N/A')}")
    print(f" Text matched: {results.get('Text_matched', 'N/A')}")

    print("\n[Detection-Only Metrics] (只看框位置,不看文字)")
    print(f" Det Precision: {results.get('Det_Precision', 0)*100:.2f}%")
    print(f" Det Recall: {results.get('Det_Recall', 0)*100:.2f}%")
    print(f" Det F1-Score: {results.get('Det_F1', 0)*100:.2f}%")

    print("\n[End-to-End Metrics] (检测 + 识别)")
    print(f" True Positives: {results['TP']}")
    print(f" False Positives: {results['FP']}")
    print(f" False Negatives: {results['FN']}")
    print(f" Precision: {results['Precision']*100:.2f}%")
    print(f" Recall: {results['Recall']*100:.2f}%")
    print(f" F1-Score: {results['F1-Score']*100:.2f}%")

    # Persist the metrics as JSON.
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, 'w') as out_file:
        json.dump(results, out_file, indent=2)
    print(f"\nResults saved to {output_path}")
510
+
511
+
512
+ if __name__ == '__main__':
513
+ main()
514
+
diffusion-dpo-ocr/verify_roadtext_annotations.py ADDED
@@ -0,0 +1,223 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 验证 RoadText1K 预处理结果的脚本
3
+ 可视化 GT 标注框与 crop 后图像的对应关系,检查坐标是否正确
4
+ """
5
+
6
+ import os
7
+ import json
8
+ import random
9
+ from pathlib import Path
10
+ from typing import Dict, List
11
+
12
+ import cv2
13
+ import numpy as np
14
+ from PIL import Image
15
+
16
+
17
+ # ============================================================================
18
+ # 配置
19
+ # ============================================================================
20
CONFIG = {
    # Directory containing the ground-truth (GT) images.
    'gt_dir': '/home/wanghongbo06/baipurui/DATA/RoadText1k_patch_crop/gt',

    # Path of the annotation file.
    'annotation_file': '/home/wanghongbo06/baipurui/DATA/RoadText1k_patch_crop/annotations.json',

    # Output directory for the visualisations.
    'vis_output_dir': './verify_roadtext_vis',

    # Number of images to visualise.
    'num_samples': 20,

    # Random seed used when sampling the images.
    'seed': 42,
}
36
+ # ============================================================================
37
+
38
+
39
def draw_annotations(img: np.ndarray, ann: Dict, color=(0, 255, 0), thickness=2) -> np.ndarray:
    """Return a copy of ``img`` with the annotation polygons and texts drawn.

    Args:
        img: Image array; it is copied, not modified.
        ann: Annotation dict with optional 'polygons', 'texts' and 'ignore'.
        color: Colour for regular boxes; ignored boxes are drawn in grey.
        thickness: Line thickness of the polygon outline.

    Returns:
        The annotated copy of the image.
    """
    canvas = img.copy()

    polygons = ann.get('polygons', [])
    texts = ann.get('texts', [])
    ignores = ann.get('ignore', [False] * len(texts))

    for poly, text, ignore in zip(polygons, texts, ignores):
        # Grey marks an ignored box.
        box_color = (128, 128, 128) if ignore else color

        pts = np.array(poly).reshape(-1, 2).astype(np.int32)
        cv2.polylines(canvas, [pts], True, box_color, thickness)

        # Put the label above the first vertex, or below it when the vertex
        # is too close to the top edge.
        first_x = int(pts[0][0])
        first_y = int(pts[0][1])
        label_y = first_y - 5
        if label_y < 15:
            label_y = first_y + 15

        # Shorten long texts for display.
        display_text = text[:15] + "..." if len(text) > 15 else text
        cv2.putText(
            canvas, display_text, (first_x, label_y),
            cv2.FONT_HERSHEY_SIMPLEX, 0.4, box_color, 1,
        )

    return canvas
68
+
69
+
70
def verify_single_image(
    img_path: Path,
    ann: Dict,
    save_path: Path,
) -> Dict:
    """Check one image's boxes against its size and save a visualisation.

    Returns:
        ``{'error': ...}`` when the image cannot be read; otherwise a stats
        dict with 'img_size' (width, height), 'num_boxes', 'num_ignored',
        'boxes_in_bounds' and 'boxes_out_of_bounds'.
    """
    img = cv2.imread(str(img_path))
    if img is None:
        return {'error': f'Cannot read image: {img_path}'}

    h, w = img.shape[:2]
    polygons = ann.get('polygons', [])

    stats = {
        'img_size': (w, h),
        'num_boxes': len(polygons),
        'num_ignored': sum(ann.get('ignore', [])),
        'boxes_in_bounds': 0,
        'boxes_out_of_bounds': 0,
    }

    # A box is in bounds when all of its vertices lie inside the image.
    for poly in polygons:
        xs = poly[0::2]
        ys = poly[1::2]

        inside = min(xs) >= 0 and max(xs) <= w and min(ys) >= 0 and max(ys) <= h
        if inside:
            stats['boxes_in_bounds'] += 1
            continue

        stats['boxes_out_of_bounds'] += 1
        print(f" [WARNING] Box out of bounds: x=[{min(xs):.1f}, {max(xs):.1f}], y=[{min(ys):.1f}, {max(ys):.1f}], img={w}x{h}")

    # Draw the annotations and write the result.
    cv2.imwrite(str(save_path), draw_annotations(img, ann))

    return stats
109
+
110
+
111
def main():
    """Verify the pre-processed RoadText1K annotations.

    Prints global box statistics, visualises a random sample of images with
    their boxes, and prints the first boxes of the first three annotations.

    Changes from the previous version:
      * the expected image size for the global bounds check is read from
        ``CONFIG.get('patch_size', 512)`` instead of a hard-coded 512;
      * sample boxes are printed even when an annotation has no 'ignore' list
        (the flag then defaults to False);
      * unreadable sample images are reported instead of silently dropped.
    """
    random.seed(CONFIG['seed'])

    gt_dir = Path(CONFIG['gt_dir'])
    ann_file = Path(CONFIG['annotation_file'])
    vis_dir = Path(CONFIG['vis_output_dir'])
    vis_dir.mkdir(parents=True, exist_ok=True)

    # Side length assumed for every image in the global bounds check.
    patch_size = CONFIG.get('patch_size', 512)

    print("=" * 60)
    print("RoadText1K 标注验证脚本")
    print("=" * 60)
    print(f"GT dir: {gt_dir}")
    print(f"Annotations: {ann_file}")
    print(f"Output: {vis_dir}")
    print()

    # Make sure the inputs exist.
    if not ann_file.exists():
        print(f"Error: Annotation file not found: {ann_file}")
        return

    if not gt_dir.exists():
        print(f"Error: GT directory not found: {gt_dir}")
        return

    # Load the annotations.
    with open(ann_file, 'r', encoding='utf-8') as f:
        annotations = json.load(f)

    print(f"总共 {len(annotations)} 张图片有标注")

    # Global statistics over all annotations.
    total_boxes = 0
    total_ignored = 0
    total_in_bounds = 0
    total_out_of_bounds = 0

    for ann in annotations.values():
        polygons = ann.get('polygons', [])
        total_boxes += len(polygons)
        total_ignored += sum(ann.get('ignore', []))

        # Bounds check against the assumed patch size (images are not opened here).
        for poly in polygons:
            xs = poly[0::2]
            ys = poly[1::2]
            if min(xs) >= 0 and max(xs) <= patch_size and min(ys) >= 0 and max(ys) <= patch_size:
                total_in_bounds += 1
            else:
                total_out_of_bounds += 1

    print(f"\n全局统计:")
    print(f" 总文本框: {total_boxes}")
    print(f" 忽略框: {total_ignored}")
    print(f" 有效框 (非忽略): {total_boxes - total_ignored}")
    print(f" 框在图像范围内: {total_in_bounds}")
    print(f" 框超出范围: {total_out_of_bounds}")

    if total_out_of_bounds > 0:
        print(f"\n⚠️ 有 {total_out_of_bounds} 个框超出图像范围!这可能是坐标转换问题。")

    # Visualise a random sample of images.
    img_names = list(annotations.keys())
    num_samples = min(CONFIG['num_samples'], len(img_names))
    selected = random.sample(img_names, num_samples)

    print(f"\n随机选取 {num_samples} 张图片进行可视化...")

    for img_name in selected:
        ann = annotations[img_name]
        img_path = gt_dir / img_name

        if not img_path.exists():
            # Try the same stem with a '.jpg' extension.
            img_path = gt_dir / (Path(img_name).stem + '.jpg')

        if not img_path.exists():
            print(f" [SKIP] Image not found: {img_name}")
            continue

        save_path = vis_dir / f"verify_{img_name}"
        stats = verify_single_image(img_path, ann, save_path)

        if 'error' in stats:
            # Surface unreadable images instead of dropping them silently.
            print(f" [ERROR] {img_name}: {stats['error']}")
        else:
            print(f" [OK] {img_name}: {stats['num_boxes']} boxes, "
                  f"{stats['boxes_out_of_bounds']} out of bounds")

    print(f"\n可视化结果保存到: {vis_dir}")

    # Extra check: print a few annotation samples.
    print("\n" + "=" * 60)
    print("标注样例 (前 3 张图):")
    print("=" * 60)

    for i, (img_name, ann) in enumerate(list(annotations.items())[:3]):
        print(f"\n[{i+1}] {img_name}")
        print(f" crop_position: {ann.get('crop_position', 'N/A')}")
        print(f" num_boxes: {len(ann.get('polygons', []))}")

        polygons = ann.get('polygons', [])
        texts = ann.get('texts', [])
        # Default to "not ignored" so boxes are still listed without an
        # 'ignore' list; zip() would otherwise yield nothing.
        ignores = ann.get('ignore', [False] * len(texts))

        # Show only the first two boxes.
        for j, (poly, text, ignore) in enumerate(zip(polygons[:2], texts[:2], ignores[:2])):
            xs = poly[0::2]
            ys = poly[1::2]
            print(f" Box {j}: x=[{min(xs):.1f}, {max(xs):.1f}], y=[{min(ys):.1f}, {max(ys):.1f}], "
                  f"text='{text[:20]}', ignore={ignore}")
219
+
220
+
221
+ if __name__ == '__main__':
222
+ main()
223
+
diffusion-dpo-test/DIAGNOSTIC_CHECKLIST.md ADDED
@@ -0,0 +1,297 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 🔍 完整诊断检查清单
2
+
3
+ ## 问题描述
4
+ 训练了 15 和 60 个 epoch,测试结果**逐像素完全相同**(262144 个 RGB 值一模一样)
5
+
6
+ ---
7
+
8
+ ## ✅ 已确认正常的部分
9
+
10
+ ### 1. 训练动态正常
11
+ - ✅ `acc` 和 `l_acc` 在变化
12
+ - ✅ `loss` 在下降
13
+ - ✅ `Max lora_B Check` 从 `1.04e-05` 增长到 `2.23e-03`(200+ 倍增长)
14
+
15
+ ### 2. 代码逻辑正常
16
+ - ✅ `disable_adapter()` 使用了 `with` 语句(已修复)
17
+ - ✅ VAE 编码在 `no_grad()` 中
18
+ - ✅ `ref_pred` 正确 detach
19
+ - ✅ LoRA 键名已清理(去除 `base_model.model.` 前缀)
20
+ - ✅ 优化器在正确的时机创建(dtype 转换之后)
21
+
22
+ ### 3. x_embedder 为 0 是正常的
23
+ - ✅ 输入层通常不需要针对 DPO 偏好优化
24
+ - ✅ 其他深层(`single_transformer_blocks.*`)的权重在增长
25
+ - ✅ LoRA 是加法,不是乘法,0 不会阻塞梯度流
26
+
27
+ ---
28
+
29
+ ## ❓ 需要检查的部分
30
+
31
+ ### 🎯 检查 1: Checkpoint 权重是否真的在变化?
32
+
33
+ **这是最关键的检查!**
34
+
35
+ ```bash
36
+ # 在训练机器上运行
37
+ cd <训练脚本运行目录> # 即 run.sh 的执行目录
38
+
39
+ # 1. 确认文件存在
40
+ ls -lh results_1202_4/checkpoint-15/lora_train_unet/adapter_model.safetensors
41
+ ls -lh results_1202_4/checkpoint-60/lora_train_unet/adapter_model.safetensors
42
+
43
+ # 2. 对比权重
44
+ python compare_checkpoints.py \
45
+ results_1202_4/checkpoint-15/lora_train_unet/adapter_model.safetensors \
46
+ results_1202_4/checkpoint-60/lora_train_unet/adapter_model.safetensors
47
+ ```
48
+
49
+ **预期结果**:
50
+ - ✅ **正常**: "50%+ 的参数在变化",最大差异 > 1e-4
51
+ - ❌ **异常**: "完全相同" 或 "< 10% 参数变化"
52
+
53
+ **如果异常,说明**:
54
+ - 保存逻辑有问题
55
+ - 或者训练根本没生效(但这与 loss 下降矛盾)
56
+
57
+ ---
58
+
59
+ ### 🎯 检查 2: 测试代码加载的是哪个 checkpoint?
60
+
61
+ **问题**: 测试代码路径和训练输出路径不匹配
62
+
63
+ 训练输出:
64
+ ```
65
+ results_1202_4/checkpoint-XX/lora_train_unet/adapter_model.safetensors
66
+ ```
67
+
68
+ 测试代码加载:
69
+ ```python
70
+ # diffusion-dpo-test/test.py line 33
71
+ "/home/wanghongbo06/diffusion-dpo-adv/results_1130/checkpoint-50/..."
72
+ ```
73
+
74
+ 注意 `results_1130` vs `results_1202_4`!
75
+
76
+ **验证方法**:
77
+ ```bash
78
+ # 在测试机器上
79
+ # 1. 确认测试代码实际加载的文件
80
+ ls -lh /home/wanghongbo06/diffusion-dpo-adv/results_1130/checkpoint-50/lora_train_unet/adapter_model.safetensors
81
+ ls -lh /home/wanghongbo06/diffusion-dpo-adv/results_1130/checkpoint-500/lora_train_unet/adapter_model.safetensors
82
+
83
+ # 2. 检查文件修改时间(是否是最新训练的)
84
+ stat /home/wanghongbo06/diffusion-dpo-adv/results_1130/checkpoint-50/lora_train_unet/adapter_model.safetensors
85
+
86
+ # 3. 对比这两个文件
87
+ python compare_checkpoints.py \
88
+ /home/wanghongbo06/diffusion-dpo-adv/results_1130/checkpoint-50/lora_train_unet/adapter_model.safetensors \
89
+ /home/wanghongbo06/diffusion-dpo-adv/results_1130/checkpoint-500/lora_train_unet/adapter_model.safetensors
90
+ ```
91
+
92
+ **如果这两个文件完全相同**:
93
+ - 这本身就能解释输出逐像素相同:测试代码加载的是 `results_1130` 下的旧 checkpoint,而不是最新训练的 `results_1202_4`
94
+ - 需要更新测试代码的路径,或者将最新的 checkpoint 复制到测试机器
95
+
96
+ ---
97
+
98
+ ### 🎯 检查 3: 测试代码的 LoRA 融合逻辑
99
+
100
+ 测试代码使用了**两次 fuse_lora**:
101
+
102
+ ```python
103
+ # 第一次:SR base LoRA
104
+ pipe.load_lora_weights("...pytorch_lora_weights_v2.safetensors", adapter_name="sr")
105
+ pipe.fuse_lora(lora_scale=1.0, adapter_names=["sr"])
106
+ pipe.unload_lora_weights()
107
+
108
+ # 第二次:DPO trained LoRA
109
+ pipe.load_lora_weights("...adapter_model.safetensors", adapter_name="sr2")
110
+ pipe.fuse_lora(lora_scale=1.0, adapter_names=["sr2"])
111
+ pipe.unload_lora_weights()
112
+ ```
113
+
114
+ **潜在问题**:
115
+ 1. `unload_lora_weights()` 可能清除了刚融合的权重
116
+ 2. 第二次 `fuse_lora` 可能没有正确叠加到第一次的结果上
117
+ 3. 如果两个 LoRA 的权重完全相同,输出自然也相同
118
+
119
+ **验证方法 A: 分别单独加载 SR LoRA 和 DPO LoRA**
120
+ ```python
121
+ # 临时修改 test.py
122
+ pipe = FluxPipeline.from_pretrained(...).to("cuda")
123
+
124
+ # 只加载第一个 LoRA
125
+ pipe.load_lora_weights(
126
+ "/home/wanghongbo06/baipurui/CKPTs/FLUX_SR/pytorch_lora_weights_v2.safetensors",
127
+ adapter_name="sr"
128
+ )
129
+ pipe.fuse_lora(lora_scale=1.0, adapter_names=["sr"])
130
+ pipe.unload_lora_weights()
131
+
132
+ # 生成图像 -> 保存为 result_sr_only.png
133
+
134
+ # 然后只加载第二个 LoRA
135
+ pipe2 = FluxPipeline.from_pretrained(...).to("cuda")
136
+ pipe2.load_lora_weights(
137
+ "/home/wanghongbo06/diffusion-dpo-adv/results_1130/checkpoint-60/lora_train_unet/adapter_model.safetensors",
138
+ adapter_name="dpo"
139
+ )
140
+ pipe2.fuse_lora(lora_scale=1.0, adapter_names=["dpo"])
141
+
142
+ # 生成图像 -> 保存为 result_dpo_only.png
143
+ ```
144
+
145
+ 如果 `result_dpo_only.png` 在不同 epoch 之间也完全相同,说明问题在 checkpoint 本身。
146
+
147
+ **验证方法 B: 检查融合后的权重**
148
+ ```python
149
+ # 在 test.py 中添加调试代码
150
+ import torch
151
+
152
+ # 加载第一个 LoRA 后
153
+ pipe.load_lora_weights("...pytorch_lora_weights_v2.safetensors", adapter_name="sr")
154
+ pipe.fuse_lora(lora_scale=1.0, adapter_names=["sr"])
155
+
156
+ # 保存融合后的 transformer 权重快照
157
+ transformer_weights_after_sr = {
158
+ k: v.clone() for k, v in pipe.transformer.state_dict().items()
159
+ if 'single_transformer_blocks.30' in k
160
+ }
161
+
162
+ pipe.unload_lora_weights()
163
+
164
+ # 加载第二个 LoRA 后
165
+ pipe.load_lora_weights("...adapter_model.safetensors", adapter_name="sr2")
166
+ pipe.fuse_lora(lora_scale=1.0, adapter_names=["sr2"])
167
+
168
+ transformer_weights_after_dpo = {
169
+ k: v.clone() for k, v in pipe.transformer.state_dict().items()
170
+ if 'single_transformer_blocks.30' in k
171
+ }
172
+
173
+ # 对比
174
+ for k in transformer_weights_after_sr.keys():
175
+ diff = (transformer_weights_after_dpo[k] - transformer_weights_after_sr[k]).abs().max()
176
+ print(f"{k}: max_diff = {diff:.6e}")
177
+ ```
178
+
179
+ 如果所有 diff 都是 0,说明第二个 LoRA 没有被正确应用。
180
+
181
+ ---
182
+
183
+ ### 🎯 检查 4: Checkpoint 文件的完整性
184
+
185
+ ```bash
186
+ # 在训练机器上
187
+ python inspect_safetensor.py results_1202_4/checkpoint-60/lora_train_unet/adapter_model.safetensors
188
+ ```
189
+
190
+ **预期结果**:
191
+ - 应该有 100+ 个参数张量
192
+ - `single_transformer_blocks.*` 的 lora_B 应该有非零值
193
+ - 非零参数比例应该 > 50%
194
+
195
+ **如果异常**:
196
+ - 文件损坏
197
+ - 或者保存逻辑有问题
198
+
199
+ ---
200
+
201
+ ## 🔧 可能的修复方案
202
+
203
+ ### 方案 1: 如果 checkpoint 权重没有变化
204
+ **原因**: 保存逻辑有问题
205
+
206
+ **修复**: 检查 `save_model_hook` 中的状态字典获取逻辑
207
+
208
+ ```python
209
+ # train_single_lora.py line 672
210
+ full_state_dict = accelerator.get_state_dict(model)
211
+ ```
212
+
213
+ 可能需要改为:
214
+ ```python
215
+ from peft import get_peft_model_state_dict
216
+ full_state_dict = get_peft_model_state_dict(model, adapter_name="train_unet")
217
+ ```
218
+
219
+ ---
220
+
221
+ ### 方案 2: 如果测试代码路径错误
222
+ **原因**: 加载了旧的 checkpoint
223
+
224
+ **修复**: 更新测试代码的路径,或者将最新 checkpoint 复制到测试机器
225
+
226
+ ```bash
227
+ # 在训练机器上
228
+ scp -r results_1202_4/checkpoint-60 user@test-machine:/home/wanghongbo06/diffusion-dpo-adv/
229
+ ```
230
+
231
+ ---
232
+
233
+ ### 方案 3: 如果 LoRA 融合有问题
234
+ **原因**: `fuse_lora` 和 `unload_lora_weights` 的交互有问题
235
+
236
+ **修复**: 使用 `set_adapters` 而不是 `fuse_lora`
237
+
238
+ ```python
239
+ # 新的测试代码
240
+ pipe = FluxPipeline.from_pretrained(...).to("cuda")
241
+
242
+ # 加载两个 LoRA(不融合)
243
+ pipe.load_lora_weights("...pytorch_lora_weights_v2.safetensors", adapter_name="sr")
244
+ pipe.load_lora_weights("...adapter_model.safetensors", adapter_name="dpo")
245
+
246
+ # 同时启用两个 adapter
247
+ pipe.set_adapters(["sr", "dpo"], adapter_weights=[1.0, 1.0])
248
+
249
+ # 生成图像
250
+ result_img = generate(pipe, ...)
251
+ ```
252
+
253
+ ---
254
+
255
+ ## 📊 诊断流程图
256
+
257
+ ```
258
+ 开始
259
+
260
+ 检查 1: 对比 checkpoint 权重
261
+ ├─ 权重有变化 → 继续检查 2
262
+ └─ 权重无变化 → 【问题在训练/保存】→ 方案 1
263
+
264
+ 检查 2: 确认测试代码加载的路径
265
+ ├─ 路径正确 → 继续检查 3
266
+ └─ 路径错误 → 【问题在测试代码】→ 方案 2
267
+
268
+ 检查 3: 验证 LoRA 融合逻辑
269
+ ├─ 融合正确 → 继续检查 4
270
+ └─ 融合失败 → 【问题在测试代码】→ 方案 3
271
+
272
+ 检查 4: 检查 checkpoint 文件完整性
273
+ ├─ 文件完整 → 【未知问题,需要更深入调查】
274
+ └─ 文件损坏 → 【问题在保存】→ 方案 1
275
+ ```
276
+
277
+ ---
278
+
279
+ ## 🚀 立即执行
280
+
281
+ **请在训练机器上运行以下命令**:
282
+
283
+ ```bash
284
+ cd <run.sh 的执行目录>
285
+
286
+ # 1. 对比 checkpoint(最重要!)
287
+ python /data2/hongbo.wang/DPO-SR/compare_checkpoints.py \
288
+ results_1202_4/checkpoint-15/lora_train_unet/adapter_model.safetensors \
289
+ results_1202_4/checkpoint-60/lora_train_unet/adapter_model.safetensors
290
+
291
+ # 2. 检查单个 checkpoint
292
+ python /data2/hongbo.wang/DPO-SR/inspect_safetensor.py \
293
+ results_1202_4/checkpoint-60/lora_train_unet/adapter_model.safetensors
294
+ ```
295
+
296
+ **请将输出结果发给我,我会根据结果进一步诊断!**
297
+
diffusion-dpo-test/DIV2K-val/sobolev-400/0000843-seed-0.png ADDED
diffusion-dpo-test/__pycache__/color_fix.cpython-310.pyc ADDED
Binary file (3.7 kB). View file
 
diffusion-dpo-test/analyze_lora_magnitude.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ 分析 LoRA 权重的实际大小,理解为什么效果差异这么大
4
+ """
5
+ import os
6
+ os.environ["HF_HOME"] = "/home/wanghongbo06/.cache/huggingface"
7
+
8
+ import torch
9
+ from safetensors.torch import load_file
10
+ import numpy as np
11
+
12
def analyze_lora(path, name=""):
    """Print magnitude statistics for the LoRA matrices in one safetensors file.

    Keys containing ``lora_A`` and ``lora_B`` are summarised separately
    (mean, max absolute value, std; plus the non-zero ratio for ``lora_B``).
    Files that contain neither kind of key are reported as such instead of
    crashing on an empty reduction.

    Args:
        path: Path of the safetensors file to load.
        name: Label printed in the report header.

    Returns:
        The state dict returned by ``load_file(path)``.
    """
    print(f"\n{'='*80}")
    print(f"分析 LoRA: {name}")
    print(f"路径: {path}")
    print(f"{'='*80}")

    state_dict = load_file(path)

    # Split the keys into the two LoRA factors.
    lora_a_keys = [k for k in state_dict if "lora_A" in k]
    lora_b_keys = [k for k in state_dict if "lora_B" in k]

    print(f"\n总参数数: {len(state_dict)}")
    print(f" - lora_A: {len(lora_a_keys)} 个")
    print(f" - lora_B: {len(lora_b_keys)} 个")

    # lora_A statistics
    print(f"\n--- lora_A 统计 ---")
    a_means = []
    a_maxs = []
    a_stds = []
    for k in lora_a_keys:
        v = state_dict[k].float()
        a_means.append(v.mean().item())
        a_maxs.append(v.abs().max().item())
        a_stds.append(v.std().item())

    if lora_a_keys:
        print(f" Mean of means: {np.mean(a_means):.6e}")
        print(f" Max of maxs: {np.max(a_maxs):.6e}")
        print(f" Mean of stds: {np.mean(a_stds):.6e}")
    else:
        # np.max([]) raises ValueError, so report the absence explicitly.
        print(f" (没有找到 lora_A)")

    # lora_B statistics
    print(f"\n--- lora_B 统计 ---")
    b_means = []
    b_maxs = []
    b_stds = []
    b_nonzero_ratio = []
    for k in lora_b_keys:
        v = state_dict[k].float()
        b_means.append(v.mean().item())
        b_maxs.append(v.abs().max().item())
        b_stds.append(v.std().item())
        b_nonzero_ratio.append((v.abs() > 1e-10).float().mean().item())

    if lora_b_keys:
        print(f" Mean of means: {np.mean(b_means):.6e}")
        print(f" Max of maxs: {np.max(b_maxs):.6e}")
        print(f" Mean of stds: {np.mean(b_stds):.6e}")
        print(f" Avg non-zero ratio: {np.mean(b_nonzero_ratio)*100:.2f}%")
    else:
        print(f" (没有找到 lora_B)")

    # The weight update contributed by LoRA is B @ A; the product of the two
    # largest entries is only a crude magnitude indicator.
    print(f"\n--- LoRA 影响估算 ---")
    print(f" LoRA 输出 ≈ x @ A.T @ B.T")
    if lora_a_keys and lora_b_keys:
        print(f" |A| * |B| ≈ {np.max(a_maxs) * np.max(b_maxs):.6e}")
    else:
        print(f" (缺少 lora_A 或 lora_B,无法估算)")

    # The five lora_B tensors with the largest absolute entry (reuses b_maxs
    # instead of recomputing every maximum).
    print(f"\n--- 最大的 5 个 lora_B 层 ---")
    b_with_max = sorted(zip(lora_b_keys, b_maxs), key=lambda kv: -kv[1])
    for k, v in b_with_max[:5]:
        print(f" {v:.6e}: {k}")

    return state_dict
76
+
77
+
78
def compare_loras(path1, path2, name1="LoRA 1", name2="LoRA 2"):
    """Compare the ``lora_B`` tensors of two LoRA safetensors files.

    Only keys that contain ``lora_B`` and are present in both files are
    compared. Prints the ten keys with the largest element-wise difference
    and a few aggregate statistics.

    Args:
        path1: Path of the first safetensors file.
        path2: Path of the second safetensors file.
        name1: Label used for the first file in the report.
        name2: Label used for the second file in the report.
    """
    print(f"\n{'='*80}")
    print(f"对比 {name1} vs {name2}")
    print(f"{'='*80}")

    sd1 = load_file(path1)
    sd2 = load_file(path2)

    diffs = []
    for k, t1 in sd1.items():
        if "lora_B" not in k or k not in sd2:
            continue
        v1 = t1.float()
        v2 = sd2[k].float()
        diff = (v2 - v1).abs()
        diffs.append({
            'key': k,
            'max_diff': diff.max().item(),
            'mean_diff': diff.mean().item(),
            'val1_max': v1.abs().max().item(),
            'val2_max': v2.abs().max().item(),
        })

    if not diffs:
        # max() on an empty list raises ValueError; there is nothing to rank.
        print(f"\n没有找到两个文件共有的 lora_B 参数,无法对比")
        return

    # Largest change first.
    diffs.sort(key=lambda x: -x['max_diff'])

    print(f"\n变化最大的 10 个 lora_B 层:")
    print("-" * 100)
    for d in diffs[:10]:
        print(f" max_diff={d['max_diff']:.6e}, {name1}_max={d['val1_max']:.6e}, {name2}_max={d['val2_max']:.6e}")
        print(f" {d['key']}")

    # Aggregate statistics over the per-key maxima.
    all_max_diffs = [d['max_diff'] for d in diffs]
    print(f"\n总体统计:")
    print(f" 最大差异: {max(all_max_diffs):.6e}")
    print(f" 平均差异: {np.mean(all_max_diffs):.6e}")
    print(f" 差异 > 1e-4 的层数: {sum(1 for d in all_max_diffs if d > 1e-4)}")
    print(f" 差异 > 1e-5 的层数: {sum(1 for d in all_max_diffs if d > 1e-5)}")
118
+
119
+
120
def compare_with_sr_lora(sr_path, dpo_path):
    """Compare the magnitude of the up-projection weights of two LoRA files.

    For the SR file both naming schemes are accepted: ``lora_B`` and
    ``lora_up``. (``lora_up`` is the counterpart of ``lora_B``; ``lora_down``
    corresponds to ``lora_A`` and must not be mixed into this comparison.)
    For the DPO file only ``lora_B`` keys are considered.

    Args:
        sr_path: Path of the SR LoRA safetensors file.
        dpo_path: Path of the DPO LoRA safetensors file.
    """
    print(f"\n{'='*80}")
    print(f"对比 SR LoRA 和 DPO LoRA 的量级")
    print(f"{'='*80}")

    sr_sd = load_file(sr_path)
    dpo_sd = load_file(dpo_path)

    # Largest absolute entry of every up-projection tensor in the SR file.
    sr_b_maxs = [
        v.float().abs().max().item()
        for k, v in sr_sd.items()
        if "lora_B" in k or "lora_up" in k
    ]

    # Same quantity for the DPO file.
    dpo_b_maxs = [
        v.float().abs().max().item()
        for k, v in dpo_sd.items()
        if "lora_B" in k
    ]

    print(f"\nSR LoRA (lora_B/lora_up):")
    if sr_b_maxs:
        print(f" Max: {max(sr_b_maxs):.6e}")
        print(f" Mean: {np.mean(sr_b_maxs):.6e}")
    else:
        print(f" (没有找到 lora_B 或 lora_up)")
        # Show a few keys so the naming scheme can be identified by eye.
        print(f" SR LoRA keys 示例: {list(sr_sd.keys())[:5]}")

    print(f"\nDPO LoRA (lora_B):")
    if dpo_b_maxs:
        print(f" Max: {max(dpo_b_maxs):.6e}")
        print(f" Mean: {np.mean(dpo_b_maxs):.6e}")
    else:
        print(f" (没有找到 lora_B)")
        print(f" DPO LoRA keys 示例: {list(dpo_sd.keys())[:5]}")

    if sr_b_maxs and dpo_b_maxs:
        ratio = max(dpo_b_maxs) / max(sr_b_maxs) if max(sr_b_maxs) > 0 else float('inf')
        print(f"\n量级比较:")
        print(f" DPO / SR = {ratio:.4f}")
        if ratio > 1:
            print(f" ⚠️ DPO LoRA 比 SR LoRA 大 {ratio:.1f} 倍!这可能导致效果变差")
        elif ratio > 0:
            print(f" DPO LoRA 比 SR LoRA 小 {1/ratio:.1f} 倍")
        else:
            # An all-zero lora_B gives ratio == 0; 1/ratio would raise
            # ZeroDivisionError.
            print(f" DPO LoRA 的 lora_B 全为 0")
162
+
163
+
164
if __name__ == "__main__":
    # Files under inspection: two DPO checkpoints and the SR base LoRA.
    sr_lora = "/home/wanghongbo06/baipurui/CKPTs/FLUX_SR/pytorch_lora_weights_v2.safetensors"
    ckpt_15 = "/home/wanghongbo06/diffusion-dpo-adv/results_1202_4/checkpoint-15/lora_train_unet/adapter_model.safetensors"
    ckpt_105 = "/home/wanghongbo06/diffusion-dpo-adv/results_1202_4/checkpoint-105/lora_train_unet/adapter_model.safetensors"

    # 1. Per-checkpoint statistics.
    for label, ckpt in (
        ("DPO checkpoint-15", ckpt_15),
        ("DPO checkpoint-105", ckpt_105),
    ):
        analyze_lora(ckpt, label)

    # 2. How much did lora_B move between the two checkpoints?
    compare_loras(ckpt_15, ckpt_105, "ckpt-15", "ckpt-105")

    # 3. Relative magnitude of the DPO LoRA versus the SR LoRA.
    compare_with_sr_lora(sr_lora, ckpt_15)
179
+
diffusion-dpo-test/check_lora_keys.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """检查 LoRA safetensor 文件的 key 格式"""
3
+ from safetensors.torch import load_file
4
+ import sys
5
+
6
def check_keys(path):
    """Print the key layout of a LoRA safetensors file and judge its format.

    Lists the first 20 keys (sorted), reports which naming markers occur
    (``transformer.`` / ``base_model.`` prefixes, ``lora_A``, ``lora_B``,
    ``train_unet``) and prints a verdict on whether the keys match the
    ``transformer.<module>.lora_A.weight`` layout shown in the report.

    Args:
        path: Path of the safetensors file to inspect.
    """
    print(f"\n检查文件: {path}")
    print("=" * 80)

    state_dict = load_file(path)
    keys = list(state_dict.keys())

    print(f"总共 {len(keys)} 个 key\n")

    if not keys:
        # Indexing the first key below would raise IndexError, which the
        # caller would misreport as an unreadable file.
        print("❌ 文件中没有任何 key,无法分析格式")
        return

    print("前 20 个 key:")
    print("-" * 80)
    for key in sorted(keys)[:20]:
        print(f" {key}")

    print("\n..." if len(keys) > 20 else "")

    # Which naming markers are present anywhere in the file?
    print("\nKey 格式分析:")
    print("-" * 80)

    has_transformer_prefix = any(k.startswith("transformer.") for k in keys)
    has_base_model_prefix = any(k.startswith("base_model.") for k in keys)
    has_lora_A = any("lora_A" in k for k in keys)
    has_lora_B = any("lora_B" in k for k in keys)
    has_train_unet = any("train_unet" in k for k in keys)

    print(f" 包含 'transformer.' 前缀: {has_transformer_prefix}")
    print(f" 包含 'base_model.' 前缀: {has_base_model_prefix}")
    print(f" 包含 'lora_A': {has_lora_A}")
    print(f" 包含 'lora_B': {has_lora_B}")
    print(f" 包含 'train_unet': {has_train_unet}")

    # One complete key (first in file order) as a concrete example.
    sample_key = keys[0]
    print(f"\n示例 key: {sample_key}")
    print(f"示例 shape: {state_dict[sample_key].shape}")

    # Reference layout for comparison.
    print("\n" + "=" * 80)
    print("Diffusers load_lora_weights 期望的 key 格式:")
    print("-" * 80)
    print(" transformer.single_transformer_blocks.0.attn.to_k.lora_A.weight")
    print(" transformer.single_transformer_blocks.0.attn.to_k.lora_B.weight")
    print("\n您的 key 格式:")
    print(f" {sample_key}")

    # Verdict.
    print("\n" + "=" * 80)
    if has_train_unet:
        print("❌ 问题: 您的 key 包含 '.train_unet.' 后缀!")
        print(" Diffusers 期望: xxx.lora_A.weight")
        print(" 您的格式: xxx.lora_A.train_unet.weight")
        print("\n 这就是 LoRA 无法加载的原因!")
    elif not has_transformer_prefix:
        print("⚠️ 问题: 您的 key 缺少 'transformer.' 前缀!")
        print(" Diffusers 期望: transformer.xxx.lora_A.weight")
        print(f" 您的格式: {sample_key}")
    else:
        print("✅ Key 格式看起来正确")
64
+
65
if __name__ == "__main__":
    # Checkpoints whose key layout should be inspected.
    checkpoint_files = (
        "/home/wanghongbo06/diffusion-dpo-adv/results_1202_4/checkpoint-15/lora_train_unet/adapter_model.safetensors",
        "/home/wanghongbo06/diffusion-dpo-adv/results_1202_4/checkpoint-60/lora_train_unet/adapter_model.safetensors",
    )

    for checkpoint_file in checkpoint_files:
        # Keep going with the next file when one cannot be inspected.
        try:
            check_keys(checkpoint_file)
        except Exception as err:
            print(f"❌ 无法读取 {checkpoint_file}: {err}")
76
+
diffusion-dpo-test/color_fix.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ # --------------------------------------------------------------------------------
3
+ # Color fixed script from Li Yi (https://github.com/pkuliyi2015/sd-webui-stablesr/blob/master/srmodule/colorfix.py)
4
+ # --------------------------------------------------------------------------------
5
+ '''
6
+
7
+ import torch
8
+ from PIL import Image
9
+ from torch import Tensor
10
+ from torch.nn import functional as F
11
+
12
+ from torchvision.transforms import ToTensor, ToPILImage
13
+
14
def adain_color_fix(target: Image, source: Image):
    """Return ``target`` with its per-channel statistics matched to ``source``.

    Both images are converted to batched tensors, passed through
    ``adaptive_instance_normalization`` (``target`` as content, ``source`` as
    style), clamped to [0, 1] and converted back to a PIL image.
    """
    pil_to_tensor = ToTensor()
    target_batch = pil_to_tensor(target).unsqueeze(0)
    source_batch = pil_to_tensor(source).unsqueeze(0)

    matched = adaptive_instance_normalization(target_batch, source_batch)

    # Drop the batch axis and clamp in place before the PIL conversion.
    tensor_to_pil = ToPILImage()
    return tensor_to_pil(matched.squeeze(0).clamp_(0.0, 1.0))
28
+
29
def wavelet_color_fix(target: Image, source: Image):
    """Return ``target`` recombined with the low-frequency band of ``source``.

    Both images are converted to batched tensors, passed through
    ``wavelet_reconstruction`` (``target`` as content, ``source`` as style),
    clamped to [0, 1] and converted back to a PIL image.
    """
    pil_to_tensor = ToTensor()
    target_batch = pil_to_tensor(target).unsqueeze(0)
    source_batch = pil_to_tensor(source).unsqueeze(0)

    recombined = wavelet_reconstruction(target_batch, source_batch)

    # Drop the batch axis and clamp in place before the PIL conversion.
    tensor_to_pil = ToPILImage()
    return tensor_to_pil(recombined.squeeze(0).clamp_(0.0, 1.0))
43
+
44
def calc_mean_std(feat: Tensor, eps=1e-5):
    """Compute the per-sample, per-channel mean and std of a 4D tensor.

    Args:
        feat (Tensor): 4D tensor; statistics are taken over everything after
            the first two dimensions.
        eps (float): Added to the variance before the square root to avoid
            divide-by-zero. Default: 1e-5.

    Returns:
        Tuple ``(mean, std)``, each reshaped to ``(b, c, 1, 1)``.
    """
    shape = feat.size()
    assert len(shape) == 4, 'The input feature should be 4D tensor.'
    batch, channels = shape[:2]
    flat = feat.reshape(batch, channels, -1)
    std = (flat.var(dim=2) + eps).sqrt().reshape(batch, channels, 1, 1)
    mean = flat.mean(dim=2).reshape(batch, channels, 1, 1)
    return mean, std
58
+
59
def adaptive_instance_normalization(content_feat:Tensor, style_feat:Tensor):
    """Adaptive instance normalization.

    Normalizes ``content_feat`` with its own per-channel mean/std (from
    ``calc_mean_std``) and then rescales and shifts it with the per-channel
    mean/std of ``style_feat``.

    Args:
        content_feat (Tensor): The feature whose statistics are replaced.
        style_feat (Tensor): The feature that supplies the new statistics.
    """
    shape = content_feat.size()
    target_mean, target_std = calc_mean_std(style_feat)
    source_mean, source_std = calc_mean_std(content_feat)
    whitened = (content_feat - source_mean.expand(shape)) / source_std.expand(shape)
    return whitened * target_std.expand(shape) + target_mean.expand(shape)
72
+
73
def wavelet_blur(image: Tensor, radius: int):
    """Blur ``image`` with a fixed 3x3 kernel dilated by ``radius``.

    The kernel is applied to every channel independently (grouped
    convolution). The channel count is read from the input, so inputs with
    any number of channels are supported, not only 3-channel ones. The input
    is replicate-padded by ``radius`` on every side so the output has the
    same spatial size as the input.

    Args:
        image (Tensor): Input of shape (N, C, H, W).
        radius (int): Dilation of the kernel and width of the padding.

    Returns:
        Tensor with the same shape as ``image``.
    """
    channels = image.shape[1]
    # 3x3 smoothing kernel; the weights sum to 1.
    kernel_vals = [
        [0.0625, 0.125, 0.0625],
        [0.125, 0.25, 0.125],
        [0.0625, 0.125, 0.0625],
    ]
    kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device)
    # One (1, 3, 3) filter per input channel for the grouped convolution.
    kernel = kernel[None, None].repeat(channels, 1, 1, 1)
    image = F.pad(image, (radius, radius, radius, radius), mode='replicate')
    return F.conv2d(image, kernel, groups=channels, dilation=radius)
93
+
94
def wavelet_decomposition(image: Tensor, levels=5):
    """Split ``image`` into an accumulated high-frequency part and a low-frequency rest.

    At level ``i`` the current low-frequency image is blurred with
    ``wavelet_blur`` at radius ``2 ** i``; the detail removed by that blur is
    added to the high-frequency accumulator. With ``levels <= 0`` nothing is
    blurred and the result is ``(zeros, image)`` instead of an
    UnboundLocalError.

    Args:
        image (Tensor): Input tensor.
        levels (int): Number of blur levels. Default: 5.

    Returns:
        Tuple ``(high_freq, low_freq)``.
    """
    high_freq = torch.zeros_like(image)
    # Start from the unblurred image so the return value is defined even
    # when the loop body never runs.
    low_freq = image
    for i in range(levels):
        blurred = wavelet_blur(low_freq, 2 ** i)
        high_freq += (low_freq - blurred)
        low_freq = blurred

    return high_freq, low_freq
107
+
108
def wavelet_reconstruction(content_feat:Tensor, style_feat:Tensor):
    """
    Combine the high-frequency part of the content with the low-frequency
    part of the style, so that the content takes on the style's color.
    """
    # Only the high-frequency band of the content is kept.
    content_detail, _ = wavelet_decomposition(content_feat)
    # Only the low-frequency band of the style is kept.
    _, style_base = wavelet_decomposition(style_feat)
    return content_detail + style_base
diffusion-dpo-test/compare.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ import numpy as np
3
+
4
+
5
def compare_images(img_path1, img_path2, diff_output_path=None):
    """
    Compare two images pixel by pixel after converting both to RGB.

    The two images must have the same resolution.

    :param img_path1: path of the first image
    :param img_path2: path of the second image
    :param diff_output_path: when not None, the per-channel absolute
        difference is saved as an RGB image at this path
    :return: a dict with difference statistics
    :raises ValueError: when the two images differ in size
    """
    # Force RGB so RGBA / grayscale inputs are comparable.
    first = Image.open(img_path1).convert("RGB")
    second = Image.open(img_path2).convert("RGB")

    if first.size != second.size:
        raise ValueError(f"两张图分辨率不同: {first.size} vs {second.size}")

    # int16 keeps the subtraction from wrapping around; arrays are (H, W, 3).
    pixels_a = np.array(first, dtype=np.int16)
    pixels_b = np.array(second, dtype=np.int16)
    diff = np.abs(pixels_a - pixels_b)

    # Per-pixel difference strength: mean of the three channel differences.
    per_pixel_diff = diff.mean(axis=2)

    total_pixels = per_pixel_diff.size
    same_pixels = np.sum(per_pixel_diff == 0)
    different_pixels = total_pixels - same_pixels

    stats = {
        "total_pixels": int(total_pixels),
        "same_pixels": int(same_pixels),
        "different_pixels": int(different_pixels),
        # fraction of pixels that differ, between 0 and 1
        "different_ratio": different_pixels / total_pixels,
        # largest per-pixel mean difference, between 0 and 255
        "max_diff_per_pixel": float(per_pixel_diff.max()),
        # mean of the per-pixel mean differences
        "mean_diff_per_pixel": float(per_pixel_diff.mean()),
    }

    if diff_output_path is not None:
        # The differences already lie in 0..255, so they can be stored as-is.
        diff_pixels = np.clip(diff, 0, 255).astype(np.uint8)
        Image.fromarray(diff_pixels, mode="RGB").save(diff_output_path)

    return stats
62
+
63
+
64
if __name__ == "__main__":
    # Same sample rendered with DPO scale 0.0 and 1.0.
    stats = compare_images(
        "./results-test/dpo_scale_ablation/dpo_scale_0.0/0000010-seed-0.png",
        "./results-test/dpo_scale_ablation/dpo_scale_1.0/0000010-seed-0.png",
        diff_output_path="diff.png",
    )

    print("比较结果:")
    for stat_name in stats:
        print(f"{stat_name}: {stats[stat_name]}")
diffusion-dpo-test/compare_checkpoints.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """对比两个 checkpoint 的 safetensor 文件,检查权重是否真的在变化"""
3
+ import sys
4
+ from safetensors.torch import load_file
5
+ import torch
6
+ import os
7
+
8
+ # ==============================
9
+ # 在这里手动填写 checkpoint 路径
10
+ # ==============================
11
+ CHECKPOINT_1 = r"/home/wanghongbo06/diffusion-dpo-adv/results_1202_4/checkpoint-15/lora_train_unet/adapter_model.safetensors"
12
+ CHECKPOINT_2 = r"/home/wanghongbo06/diffusion-dpo-adv/results_1202_4/checkpoint-60/lora_train_unet/adapter_model.safetensors"
13
+ # ==============================
14
+
15
+
16
def compare_safetensors(path1, path2):
    """Compare two safetensors checkpoints tensor by tensor and print a report.

    The report contains: how many tensors are identical / changed, the tensor
    with the largest absolute change, the ``lora_B`` changes for a few
    selected layers, the ten ``lora_B`` tensors that changed most, and a
    final verdict. Differences are computed in float32. Any exception is
    caught, printed with its traceback, and not re-raised.

    Args:
        path1: Path of the first checkpoint file.
        path2: Path of the second checkpoint file.
    """
    print(f"\n{'='*80}")
    print(f"对比两个 checkpoint:")
    print(f" Checkpoint 1: {path1}")
    print(f" Checkpoint 2: {path2}")
    print(f"{'='*80}\n")

    try:
        state_dict1 = load_file(path1)
        state_dict2 = load_file(path2)

        # Both files must contain exactly the same keys.
        keys1 = set(state_dict1.keys())
        keys2 = set(state_dict2.keys())

        if keys1 != keys2:
            print("⚠️ 警告: 两个 checkpoint 的键不一致!")
            print(f" 只在 checkpoint1 中: {keys1 - keys2}")
            print(f" 只在 checkpoint2 中: {keys2 - keys1}")
            return

        total = len(keys1)
        if total == 0:
            # The percentage lines below would divide by zero.
            print("❌ 错误: checkpoint 中没有任何参数张量,无法对比")
            return

        print(f"✅ 两个 checkpoint 都有 {total} 个参数张量\n")

        identical_count = 0
        different_count = 0
        max_diff_info = None
        max_diff = 0

        layer_diffs = {}

        for key in sorted(keys1):
            # Convert before subtracting so half-precision checkpoints do not
            # lose small differences to rounding.
            tensor1 = state_dict1[key].float()
            tensor2 = state_dict2[key].float()

            if tensor1.shape != tensor2.shape:
                # Subtraction could broadcast silently and report nonsense.
                different_count += 1
                print(f"⚠️ 警告: {key} 的形状不一致: {tuple(tensor1.shape)} vs {tuple(tensor2.shape)}")
                continue

            abs_diff = (tensor2 - tensor1).abs()

            max_abs_diff = abs_diff.max().item()
            mean_abs_diff = abs_diff.mean().item()

            if max_abs_diff == 0:
                identical_count += 1
            else:
                different_count += 1

            if max_abs_diff > max_diff:
                max_diff = max_abs_diff
                max_diff_info = {
                    'key': key,
                    'max_diff': max_abs_diff,
                    'mean_diff': mean_abs_diff,
                    'tensor1_max': tensor1.abs().max().item(),
                    'tensor2_max': tensor2.abs().max().item(),
                }

            # Remember the change of each layer's lora_B tensor.
            if '.lora_B.' in key:
                layer_name = key.split('.lora_B.')[0]
                if layer_name not in layer_diffs:
                    layer_diffs[layer_name] = {
                        'max_diff': max_abs_diff,
                        'mean_diff': mean_abs_diff,
                        'key': key
                    }

        print(f"差异统计:")
        print(f" 完全相同的参数: {identical_count} / {total} ({identical_count/total*100:.2f}%)")
        print(f" 有变化的参数: {different_count} / {total} ({different_count/total*100:.2f}%)")
        print()

        if max_diff_info:
            print(f"最大权重变化:")
            print(f" 层: {max_diff_info['key']}")
            print(f" 最大绝对差异: {max_diff_info['max_diff']:.6e}")
            print(f" 平均绝对差异: {max_diff_info['mean_diff']:.6e}")
            print(f" Checkpoint1 最大值: {max_diff_info['tensor1_max']:.6e}")
            print(f" Checkpoint2 最大值: {max_diff_info['tensor2_max']:.6e}")
            print()

        # A handful of layers that are always reported (at most two matches
        # per substring).
        key_layers = ['x_embedder', 'transformer_blocks.0', 'transformer_blocks.9',
                      'single_transformer_blocks.30', 'proj_out']

        print("关键层的 lora_B 权重变化:")
        print("-" * 80)
        for layer_prefix in key_layers:
            matching = [k for k in layer_diffs if layer_prefix in k]
            for layer_name in matching[:2]:
                info = layer_diffs[layer_name]
                print(f"\n层: {layer_name}")
                print(f" 最大差异: {info['max_diff']:.6e}")
                print(f" 平均差异: {info['mean_diff']:.6e}")

                key = info['key']
                t1 = state_dict1[key].float()
                t2 = state_dict2[key].float()
                print(f" Checkpoint1: mean={t1.mean():.6e}, max={t1.abs().max():.6e}")
                print(f" Checkpoint2: mean={t2.mean():.6e}, max={t2.abs().max():.6e}")

        print("\n" + "="*80)
        print("lora_B 权重变化最大的前 10 个层:")
        print("-" * 80)
        sorted_layers = sorted(layer_diffs.items(), key=lambda x: x[1]['max_diff'], reverse=True)

        for i, (layer_name, info) in enumerate(sorted_layers[:10], 1):
            key = info['key']
            t1 = state_dict1[key].float()
            t2 = state_dict2[key].float()
            print(f"\n{i}. {layer_name}")
            print(f" 最大差异: {info['max_diff']:.6e}, 平均差异: {info['mean_diff']:.6e}")
            print(f" Ckpt1: mean={t1.mean():.6e}, max={t1.abs().max():.6e}")
            print(f" Ckpt2: mean={t2.mean():.6e}, max={t2.abs().max():.6e}")

        # Final verdict.
        print("\n" + "="*80)
        if different_count == 0:
            print("❌ 严重问题: 两个 checkpoint 完全相同,模型没有学习!")
        elif different_count < total * 0.1:
            print(f"⚠️ 警告: 只有 {different_count/total*100:.2f}% 的参数在变化,可能存在梯度阻塞")
        else:
            print(f"✅ 正常: {different_count/total*100:.2f}% 的参数在变化")
            if max_diff < 1e-6:
                print(f"⚠️ 但是: 最大变化只有 {max_diff:.6e},变化幅度可能太小")
        print("="*80)

    except Exception as e:
        print(f"❌ 错误: {e}")
        import traceback
        traceback.print_exc()
144
+
145
+
146
if __name__ == "__main__":
    # Two paths on the command line override the CHECKPOINT_1 / CHECKPOINT_2
    # defaults, so `python compare_checkpoints.py <ckpt_a> <ckpt_b>` compares
    # the files that were actually named.
    cli_paths = sys.argv[1:]
    if len(cli_paths) == 2:
        compare_safetensors(cli_paths[0], cli_paths[1])
    elif not cli_paths:
        compare_safetensors(CHECKPOINT_1, CHECKPOINT_2)
    else:
        print("用法: python compare_checkpoints.py [<checkpoint_1> <checkpoint_2>]")
        sys.exit(1)
diffusion-dpo-test/data_val/0000009-seed-0.png ADDED
diffusion-dpo-test/data_val/0000010-seed-0.png ADDED
diffusion-dpo-test/fix_lora_keys.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ 修复已保存的 LoRA checkpoint 的 key 格式
4
+ 将 PEFT 格式转换为 Diffusers load_lora_weights 期望的格式
5
+
6
+ PEFT 格式: x_embedder.lora_A.train_unet.weight
7
+ Diffusers 格式: transformer.x_embedder.lora_A.weight
8
+ """
9
+ import os
10
+ import sys
11
+ from safetensors.torch import load_file, save_file
12
+
13
def _normalize_lora_key(key):
    """Return ``key`` without ``base_model.model.`` / ``.train_unet.`` and with a ``transformer.`` prefix."""
    key = key.replace("base_model.model.", "").replace(".train_unet.", ".")
    if not key.startswith("transformer."):
        key = "transformer." + key
    return key


def fix_lora_keys(input_path, output_path=None):
    """Rewrite the keys of a LoRA safetensors file.

    Each key has ``base_model.model.`` removed, ``.train_unet.`` collapsed to
    ``.``, and ``transformer.`` prepended when missing. Every key is
    inspected, not only the first one. The result is written to a temporary
    file first; only after that succeeds is the original (when overwriting)
    renamed to ``<input_path>.backup`` and the new file moved into place.

    Args:
        input_path: Path of the safetensors file to read.
        output_path: Destination path. Defaults to ``input_path``, in which
            case the original is kept as ``<input_path>.backup``.

    Returns:
        The renamed state dict, or ``None`` when nothing was written (no
        tensors, keys already in the target layout, or colliding keys).
    """
    if output_path is None:
        output_path = input_path

    print(f"\n{'='*80}")
    print(f"修复 LoRA Key 格式")
    print(f" 输入: {input_path}")
    print(f" 输出: {output_path}")
    print(f"{'='*80}\n")

    state_dict = load_file(input_path)
    print(f"加载了 {len(state_dict)} 个参数\n")

    if not state_dict:
        print("❌ 文件中没有任何参数,无法修复")
        return

    sample_key = next(iter(state_dict))
    print(f"原始 key 格式示例: {sample_key}")

    # Decide from all keys whether anything has to change.
    renamed = {k: _normalize_lora_key(k) for k in state_dict}
    if any(".train_unet." in k for k in state_dict):
        print(" ✓ 检测到 '.train_unet.' 后缀,需要移除")
    if any(not k.startswith("transformer.") for k in state_dict):
        print(" ✓ 缺少 'transformer.' 前缀,需要添加")

    if all(old == new for old, new in renamed.items()):
        print("\n✅ Key 格式已经正确,无需修复!")
        return

    # Two source keys mapping to one target key would silently drop a tensor.
    if len(set(renamed.values())) != len(renamed):
        print("\n❌ 修复后出现重复的 key,已取消写入")
        return

    print("\n开始修复...")
    new_state_dict = {renamed[k]: v for k, v in state_dict.items()}

    new_sample_key = next(iter(new_state_dict))
    print(f"修复后 key 格式示例: {new_sample_key}")

    # Write to a temporary file first so a failed save cannot leave the
    # input path without a file.
    tmp_path = output_path + ".tmp"
    save_file(new_state_dict, tmp_path)

    if output_path == input_path:
        backup_path = input_path + ".backup"
        print(f"\n备份原文件到: {backup_path}")
        os.rename(input_path, backup_path)

    os.replace(tmp_path, output_path)
    print(f"\n✅ 已保存修复后的文件: {output_path}")

    # Re-read the written file and check every key.
    print("\n验证修复结果...")
    verify_dict = load_file(output_path)
    bad_keys = [
        k for k in verify_dict
        if not k.startswith("transformer.") or ".train_unet." in k
    ]

    if not bad_keys:
        print("✅ 验证通过!Key 格式正确")
    else:
        print(f"❌ 验证失败!Key 格式仍有问题: {bad_keys[0]}")

    return new_state_dict
90
+
91
+
92
def fix_checkpoint_dir(checkpoint_dir):
    """Run ``fix_lora_keys`` on ``lora_train_unet/adapter_model.safetensors`` inside ``checkpoint_dir``.

    Prints an error and does nothing when that file does not exist.
    """
    adapter_path = os.path.join(
        checkpoint_dir, "lora_train_unet", "adapter_model.safetensors"
    )

    if not os.path.exists(adapter_path):
        print(f"❌ 找不到文件: {adapter_path}")
        return

    fix_lora_keys(adapter_path)
101
+
102
+
103
if __name__ == "__main__":
    if len(sys.argv) >= 2:
        # Explicit target: either a .safetensors file or a checkpoint directory.
        path = sys.argv[1]
        if path.endswith(".safetensors"):
            fix_lora_keys(path)
        elif os.path.isdir(path):
            fix_checkpoint_dir(path)
        else:
            print(f"❌ 无效路径: {path}")
    else:
        # No argument: show the usage text, then process every checkpoint
        # directory found under the default results directory.
        usage_lines = (
            "用法:",
            " python fix_lora_keys.py <checkpoint_path>",
            " python fix_lora_keys.py <checkpoint_dir>",
            "",
            "示例:",
            " python fix_lora_keys.py results_1202_4/checkpoint-15/lora_train_unet/adapter_model.safetensors",
            " python fix_lora_keys.py results_1202_4/checkpoint-15",
            "",
        )
        for usage_line in usage_lines:
            print(usage_line)

        base_dir = "/home/wanghongbo06/diffusion-dpo-adv/results_1202_4"
        if not os.path.exists(base_dir):
            print(f"默认目录 {base_dir} 不存在")
        else:
            print(f"自动扫描 {base_dir} 下的所有 checkpoint...")
            for entry in sorted(os.listdir(base_dir)):
                if entry.startswith("checkpoint-"):
                    fix_checkpoint_dir(os.path.join(base_dir, entry))
132
+
diffusion-dpo-test/inspect_safetensor.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """检查 safetensor 文件的内容"""
3
+ from safetensors.torch import load_file
4
+ import torch
5
+
6
# ==========================================================
# Manually list the safetensors file paths to inspect here (several allowed).
# Example:
# SAFETENSOR_PATHS = [
#     r"/path/to/adapter_model1.safetensors",
#     r"/path/to/adapter_model2.safetensors",
# ]
# ==========================================================
SAFETENSOR_PATHS = [
    r"/home/wanghongbo06/diffusion-dpo-adv/results_1202_4/checkpoint-15/lora_train_unet/adapter_model.safetensors",
]
# ==========================================================
18
+
19
+
20
def _percent(part, whole):
    """Return ``part`` as a percentage of ``whole``; 0.0 when ``whole`` is 0."""
    return part / whole * 100 if whole else 0.0


def _tensor_stats(tensor, non_zero_count, num_params):
    """Build the summary-statistics entry stored for one LoRA tensor."""
    values = tensor.float()
    return {
        'mean': values.mean().item(),
        'std': values.std().item(),
        'max': values.max().item(),
        'min': values.min().item(),
        'non_zero_ratio': non_zero_count / num_params
    }


def inspect_safetensor(path):
    """Print a diagnostic summary of the tensors stored in a safetensors file.

    Loads the file with ``load_file`` and prints, in order:

    * the number of tensors, element-level zero / non-zero counts, and the
      number of tensors that are entirely zero,
    * LoRA statistics for layers whose name contains one of a fixed list of
      fragments (at most three layers per fragment),
    * the five layers whose ``lora_B`` tensor holds the largest-magnitude
      weight.

    Any exception is caught and reported together with its traceback instead of
    being re-raised, so one bad file does not stop the caller from inspecting
    the next one.

    Args:
        path: Path of the ``.safetensors`` file to inspect.

    Returns:
        None. All results are printed.
    """
    print(f"\n{'='*80}")
    print(f"检查文件: {path}")
    print(f"{'='*80}\n")

    try:
        state_dict = load_file(path)

        print(f"总共有 {len(state_dict)} 个参数张量\n")

        total_params = 0
        zero_params = 0
        non_zero_params = 0
        all_zero_tensors = 0

        # layer name -> {'lora_A': stats or None, 'lora_B': stats or None}
        layer_stats = {}

        for key, tensor in state_dict.items():
            num_params = tensor.numel()
            if num_params == 0:
                # Nothing to count, and mean/max/min are undefined for an
                # empty tensor; skipping keeps the rest of the report alive.
                continue

            non_zero_count = (tensor != 0).sum().item()

            # Count individual elements so the totals match their labels
            # ("non-zero parameters" / "zero parameters").
            total_params += num_params
            non_zero_params += non_zero_count
            zero_params += num_params - non_zero_count
            if non_zero_count == 0:
                all_zero_tensors += 1

            if '.lora_A.' not in key and '.lora_B.' not in key:
                continue

            layer_name = key.split('.lora_')[0]
            slot = 'lora_A' if '.lora_A.' in key else 'lora_B'
            entry = layer_stats.setdefault(
                layer_name, {'lora_A': None, 'lora_B': None}
            )
            entry[slot] = _tensor_stats(tensor, non_zero_count, num_params)

        print(f"参数统计:")
        print(f" 总参数数: {total_params:,}")
        print(f" 非零参数: {non_zero_params:,} ({_percent(non_zero_params, total_params):.2f}%)")
        print(f" 零参数: {zero_params:,} ({_percent(zero_params, total_params):.2f}%)")
        print(f" 全零张量: {all_zero_tensors} / {len(state_dict)}")
        print()

        key_layers = ['x_embedder', 'transformer_blocks.0', 'transformer_blocks.9',
                      'single_transformer_blocks.30', 'proj_out']

        print("关键层统计:")
        print("-" * 80)
        for fragment in key_layers:
            matching_layers = [name for name in layer_stats if fragment in name]
            # Cap the output at three layers per fragment.
            for full_layer in matching_layers[:3]:
                stats = layer_stats[full_layer]
                print(f"\n层: {full_layer}")
                for slot in ('lora_A', 'lora_B'):
                    slot_stats = stats[slot]
                    if slot_stats is not None:
                        print(f" {slot}: mean={slot_stats['mean']:.6e}, "
                              f"max={slot_stats['max']:.6e}, "
                              f"非零比例={slot_stats['non_zero_ratio']*100:.2f}%")

        print("\n" + "="*80)
        print("lora_B 权重最大的前 5 个层:")
        print("-" * 80)
        lora_b_layers = [(name, entry['lora_B'])
                         for name, entry in layer_stats.items()
                         if entry['lora_B'] is not None]
        # Rank by the largest absolute weight; looking at 'max' alone would
        # under-rank layers whose strongest weight is negative.
        lora_b_layers.sort(
            key=lambda item: max(abs(item[1]['max']), abs(item[1]['min'])),
            reverse=True,
        )

        for i, (layer_name, stats) in enumerate(lora_b_layers[:5], 1):
            print(f"{i}. {layer_name}")
            print(f" mean={stats['mean']:.6e}, max={stats['max']:.6e}, "
                  f"非零比例={stats['non_zero_ratio']*100:.2f}%")

    except Exception as e:
        # Top-level diagnostic boundary: report the failure and keep going.
        print(f"❌ 错误: {e}")
        import traceback
        traceback.print_exc()
110
+
111
+
112
if __name__ == "__main__":
    # Inspect every manually configured file, in the order listed.
    for path in SAFETENSOR_PATHS:
        inspect_safetensor(path)
diffusion-dpo-test/metrics.json ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "summary": {
3
+ "avg_inference_time_sec": 23.599896386852894,
4
+ "std_inference_time_sec": 6.796912376816178,
5
+ "min_inference_time_sec": 11.098074495792389,
6
+ "max_inference_time_sec": 28.76287134224549,
7
+ "median_inference_time_sec": 26.921020144131035,
8
+ "p95_inference_time_sec": 28.102163634379394,
9
+ "p99_inference_time_sec": 28.270882021300498,
10
+ "throughput_single_gpu_per_sec": 0.042373067390121394,
11
+ "throughput_parallel_per_sec": 0.08250964031858578,
12
+ "peak_memory_mb": 33811.7392578125,
13
+ "peak_memory_gb": 33.01927661895752,
14
+ "total_images": 100,
15
+ "warmup_images": 8,
16
+ "measured_images": 92,
17
+ "model_load_time_sec": 52.43063974380493,
18
+ "inference_wall_time_sec": 1211.9795894622803,
19
+ "total_time_sec": 1264.4102292060852,
20
+ "num_gpus": 2
21
+ },
22
+ "per_gpu_metrics": {
23
+ "1": {
24
+ "inference_times": [
25
+ 27.854881670325994,
26
+ 27.943492013029754,
27
+ 27.888656420167536,
28
+ 27.804196048993617,
29
+ 27.924615799915045,
30
+ 27.952690774109215,
31
+ 27.882505219895393,
32
+ 27.896253335289657,
33
+ 28.09671812877059,
34
+ 27.85268287314102,
35
+ 27.742382244672626,
36
+ 28.76287134224549,
37
+ 28.061799827963114,
38
+ 28.054252550937235,
39
+ 27.92581091215834,
40
+ 27.954700701870024,
41
+ 27.952801280654967,
42
+ 27.932732206769288,
43
+ 28.13525341823697,
44
+ 27.90626529790461,
45
+ 28.22222373681143,
46
+ 27.82774563319981,
47
+ 27.953424714971334,
48
+ 28.111765013076365,
49
+ 27.881029167678207,
50
+ 27.89013520721346,
51
+ 27.974640298169106,
52
+ 28.015457140281796,
53
+ 28.10881925234571,
54
+ 27.839136187918484,
55
+ 27.980245498009026,
56
+ 27.941782257054,
57
+ 27.858478610869497,
58
+ 18.717311061918736,
59
+ 11.598434458952397,
60
+ 11.579710375983268,
61
+ 11.584534626919776,
62
+ 11.59768055099994,
63
+ 11.199870459735394,
64
+ 11.196961861103773,
65
+ 11.191290821880102,
66
+ 11.286846159026027,
67
+ 11.2286821231246,
68
+ 11.219772449228913,
69
+ 11.226453838404268,
70
+ 11.211608635727316
71
+ ],
72
+ "warmup_time": 115.76587492786348,
73
+ "peak_memory_mb": 33810.9892578125,
74
+ "allocated_memory_mb": 33196.46240234375,
75
+ "reserved_memory_mb": 34506.0,
76
+ "total_images": 50,
77
+ "avg_inference_time": 23.434121787122898,
78
+ "std_inference_time": 7.309697334208333,
79
+ "throughput": 0.04267281740208,
80
+ "memory_efficiency": 96.20489886496189
81
+ },
82
+ "0": {
83
+ "inference_times": [
84
+ 26.802028878591955,
85
+ 26.797191261779517,
86
+ 26.82992110401392,
87
+ 26.958114746958017,
88
+ 26.761777761392295,
89
+ 26.84792256169021,
90
+ 26.746575728990138,
91
+ 27.076817566063255,
92
+ 26.943395509850234,
93
+ 26.85071571683511,
94
+ 26.841394792776555,
95
+ 27.631777914240956,
96
+ 26.833319212775677,
97
+ 26.806214389856905,
98
+ 26.8738283761777,
99
+ 26.895515635609627,
100
+ 26.940260547678918,
101
+ 27.09848231682554,
102
+ 26.858372538816184,
103
+ 26.918541864026338,
104
+ 26.94270030176267,
105
+ 26.95117418281734,
106
+ 26.760037765838206,
107
+ 26.82909763790667,
108
+ 26.831056244205683,
109
+ 26.92349842423573,
110
+ 26.80812106281519,
111
+ 26.730416806880385,
112
+ 27.080423870123923,
113
+ 27.055579679086804,
114
+ 27.27255716919899,
115
+ 27.117452350910753,
116
+ 26.751979127991945,
117
+ 26.876946676988155,
118
+ 26.70633071102202,
119
+ 26.0279257488437,
120
+ 25.012685468886048,
121
+ 11.098074495792389,
122
+ 11.139604906085879,
123
+ 11.128950875252485,
124
+ 11.235403429716825,
125
+ 11.146971406880766,
126
+ 11.146004664245993,
127
+ 11.10717493435368,
128
+ 11.114083136897534,
129
+ 11.114445879124105
130
+ ],
131
+ "warmup_time": 111.6877530622296,
132
+ "peak_memory_mb": 33811.7392578125,
133
+ "allocated_memory_mb": 33197.21240234375,
134
+ "reserved_memory_mb": 34506.0,
135
+ "total_images": 50,
136
+ "avg_inference_time": 23.76567098658289,
137
+ "std_inference_time": 6.237739828068353,
138
+ "throughput": 0.042077499118983785,
139
+ "memory_efficiency": 96.20707239999928
140
+ }
141
+ }
142
+ }
diffusion-dpo-test/results-test/DIV2K-val-epoch10/0000263-seed-0.png ADDED
diffusion-dpo-test/results-test/DIV2K-val-epoch10/0000463-seed-0.png ADDED
diffusion-dpo-test/results-test/DIV2K-val-epoch10/0000563-seed-0.png ADDED
diffusion-dpo-test/results-test/DIV2K-val-epoch10/0000763-seed-0.png ADDED
diffusion-dpo-test/results-test/DIV2K-val-epoch10/0000863-seed-0.png ADDED
diffusion-dpo-test/results-test/DrealSR/sony_160_x4.png ADDED
diffusion-dpo-test/results-test/DrealSR/sony_189_x4.png ADDED
diffusion-dpo-test/src/flux/__pycache__/block.cpython-310.pyc ADDED
Binary file (5.96 kB). View file
 
diffusion-dpo-test/src/flux/__pycache__/block.cpython-311.pyc ADDED
Binary file (14.6 kB). View file
 
diffusion-dpo-test/src/flux/__pycache__/condition.cpython-310.pyc ADDED
Binary file (3.32 kB). View file
 
diffusion-dpo-test/src/flux/__pycache__/condition.cpython-311.pyc ADDED
Binary file (5.83 kB). View file
 
diffusion-dpo-test/src/flux/__pycache__/generate.cpython-310.pyc ADDED
Binary file (5.79 kB). View file
 
diffusion-dpo-test/src/flux/__pycache__/generate.cpython-311.pyc ADDED
Binary file (11.1 kB). View file
 
diffusion-dpo-test/src/flux/__pycache__/lora_controller.cpython-310.pyc ADDED
Binary file (3 kB). View file
 
diffusion-dpo-test/src/flux/__pycache__/lora_controller.cpython-311.pyc ADDED
Binary file (5.13 kB). View file
 
diffusion-dpo-test/src/flux/__pycache__/pipeline_tools.cpython-310.pyc ADDED
Binary file (1.42 kB). View file
 
diffusion-dpo-test/src/flux/__pycache__/pipeline_tools.cpython-311.pyc ADDED
Binary file (2.56 kB). View file
 
diffusion-dpo-test/src/flux/__pycache__/transformer.cpython-310.pyc ADDED
Binary file (4.26 kB). View file