Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- diffusion-dpo-ocr/check_video_resolution.py +194 -0
- diffusion-dpo-ocr/prepare_roadtext.py +625 -0
- diffusion-dpo-ocr/results/roadtext_eval_results_BSRGAN.json +17 -0
- diffusion-dpo-ocr/results/roadtext_eval_results_DP2O-SR.json +17 -0
- diffusion-dpo-ocr/results/roadtext_eval_results_DiT4SR.json +17 -0
- diffusion-dpo-ocr/results/roadtext_eval_results_DiffBIR.json +17 -0
- diffusion-dpo-ocr/results/roadtext_eval_results_FaithDiff.json +17 -0
- diffusion-dpo-ocr/results/roadtext_eval_results_Ours.json +17 -0
- diffusion-dpo-ocr/results/roadtext_eval_results_Real-ESRGAN.json +17 -0
- diffusion-dpo-ocr/results/roadtext_eval_results_SUPSR.json +17 -0
- diffusion-dpo-ocr/results/roadtext_eval_results_SeeSR.json +17 -0
- diffusion-dpo-ocr/results/roadtext_eval_results_StableSR.json +17 -0
- diffusion-dpo-ocr/results/roadtext_eval_results_SwinIR.json +17 -0
- diffusion-dpo-ocr/results/roadtext_eval_results_gt.json +17 -0
- diffusion-dpo-ocr/results/roadtext_eval_results_sample00.json +17 -0
- diffusion-dpo-ocr/results/roadtext_eval_results_zoomlr.json +17 -0
- diffusion-dpo-ocr/roadtext_eval_results_output.json +17 -0
- diffusion-dpo-ocr/test_roadtext.py +514 -0
- diffusion-dpo-ocr/verify_roadtext_annotations.py +223 -0
- diffusion-dpo-test/DIAGNOSTIC_CHECKLIST.md +297 -0
- diffusion-dpo-test/DIV2K-val/sobolev-400/0000843-seed-0.png +0 -0
- diffusion-dpo-test/__pycache__/color_fix.cpython-310.pyc +0 -0
- diffusion-dpo-test/analyze_lora_magnitude.py +179 -0
- diffusion-dpo-test/check_lora_keys.py +76 -0
- diffusion-dpo-test/color_fix.py +119 -0
- diffusion-dpo-test/compare.py +73 -0
- diffusion-dpo-test/compare_checkpoints.py +147 -0
- diffusion-dpo-test/data_val/0000009-seed-0.png +0 -0
- diffusion-dpo-test/data_val/0000010-seed-0.png +0 -0
- diffusion-dpo-test/fix_lora_keys.py +132 -0
- diffusion-dpo-test/inspect_safetensor.py +115 -0
- diffusion-dpo-test/metrics.json +142 -0
- diffusion-dpo-test/results-test/DIV2K-val-epoch10/0000263-seed-0.png +0 -0
- diffusion-dpo-test/results-test/DIV2K-val-epoch10/0000463-seed-0.png +0 -0
- diffusion-dpo-test/results-test/DIV2K-val-epoch10/0000563-seed-0.png +0 -0
- diffusion-dpo-test/results-test/DIV2K-val-epoch10/0000763-seed-0.png +0 -0
- diffusion-dpo-test/results-test/DIV2K-val-epoch10/0000863-seed-0.png +0 -0
- diffusion-dpo-test/results-test/DrealSR/sony_160_x4.png +0 -0
- diffusion-dpo-test/results-test/DrealSR/sony_189_x4.png +0 -0
- diffusion-dpo-test/src/flux/__pycache__/block.cpython-310.pyc +0 -0
- diffusion-dpo-test/src/flux/__pycache__/block.cpython-311.pyc +0 -0
- diffusion-dpo-test/src/flux/__pycache__/condition.cpython-310.pyc +0 -0
- diffusion-dpo-test/src/flux/__pycache__/condition.cpython-311.pyc +0 -0
- diffusion-dpo-test/src/flux/__pycache__/generate.cpython-310.pyc +0 -0
- diffusion-dpo-test/src/flux/__pycache__/generate.cpython-311.pyc +0 -0
- diffusion-dpo-test/src/flux/__pycache__/lora_controller.cpython-310.pyc +0 -0
- diffusion-dpo-test/src/flux/__pycache__/lora_controller.cpython-311.pyc +0 -0
- diffusion-dpo-test/src/flux/__pycache__/pipeline_tools.cpython-310.pyc +0 -0
- diffusion-dpo-test/src/flux/__pycache__/pipeline_tools.cpython-311.pyc +0 -0
- diffusion-dpo-test/src/flux/__pycache__/transformer.cpython-310.pyc +0 -0
diffusion-dpo-ocr/check_video_resolution.py
ADDED
|
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
检查 RoadText1K Videos 目录下所有视频的分辨率
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from collections import Counter
|
| 8 |
+
import cv2
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
# ============================================================================
|
| 13 |
+
# 配置参数 - 请修改这里
|
| 14 |
+
# ============================================================================
|
| 15 |
+
CONFIG = {
|
| 16 |
+
# RoadText1K Videos 目录
|
| 17 |
+
'videos_dir': '/home/wanghongbo06/baipurui/DATA/RoadText1k/Videos',
|
| 18 |
+
|
| 19 |
+
# 目标分辨率 (可选,用于检查是否匹配)
|
| 20 |
+
'target_resolution': (1920, 1080), # 例如: (1920, 1080) 或 None 只统计
|
| 21 |
+
|
| 22 |
+
# 检查哪个数据集
|
| 23 |
+
'split': 'test', # 'test', 'train', 'val', 或 'all' 检查全部
|
| 24 |
+
}
|
| 25 |
+
# ============================================================================
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def get_video_resolution(video_path: Path) -> tuple:
|
| 29 |
+
"""
|
| 30 |
+
获取视频分辨率
|
| 31 |
+
|
| 32 |
+
Returns:
|
| 33 |
+
(width, height) 或 None 如果无法读取
|
| 34 |
+
"""
|
| 35 |
+
cap = cv2.VideoCapture(str(video_path))
|
| 36 |
+
if not cap.isOpened():
|
| 37 |
+
return None
|
| 38 |
+
|
| 39 |
+
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
| 40 |
+
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
| 41 |
+
fps = cap.get(cv2.CAP_PROP_FPS)
|
| 42 |
+
frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
| 43 |
+
|
| 44 |
+
cap.release()
|
| 45 |
+
|
| 46 |
+
return {
|
| 47 |
+
'width': width,
|
| 48 |
+
'height': height,
|
| 49 |
+
'fps': fps,
|
| 50 |
+
'frame_count': frame_count,
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def main():
|
| 55 |
+
videos_dir = Path(CONFIG['videos_dir'])
|
| 56 |
+
split = CONFIG['split']
|
| 57 |
+
target_res = CONFIG['target_resolution']
|
| 58 |
+
|
| 59 |
+
print("="*60)
|
| 60 |
+
print("Video Resolution Checker")
|
| 61 |
+
print("="*60)
|
| 62 |
+
print(f"Videos directory: {videos_dir}")
|
| 63 |
+
print(f"Split: {split}")
|
| 64 |
+
if target_res:
|
| 65 |
+
print(f"Target resolution: {target_res[0]}x{target_res[1]}")
|
| 66 |
+
print()
|
| 67 |
+
|
| 68 |
+
# 确定要检查的目录
|
| 69 |
+
if split == 'all':
|
| 70 |
+
check_dirs = ['test', 'train', 'val']
|
| 71 |
+
else:
|
| 72 |
+
check_dirs = [split]
|
| 73 |
+
|
| 74 |
+
all_resolutions = []
|
| 75 |
+
resolution_stats = Counter()
|
| 76 |
+
mismatched_videos = []
|
| 77 |
+
|
| 78 |
+
total_videos = 0
|
| 79 |
+
|
| 80 |
+
for split_name in check_dirs:
|
| 81 |
+
split_dir = videos_dir / split_name
|
| 82 |
+
|
| 83 |
+
if not split_dir.exists():
|
| 84 |
+
print(f"Warning: Directory not found: {split_dir}")
|
| 85 |
+
continue
|
| 86 |
+
|
| 87 |
+
print(f"\nChecking {split_name}...")
|
| 88 |
+
|
| 89 |
+
# 查找所有子目录中的视频文件(支持嵌套结构)
|
| 90 |
+
video_files = []
|
| 91 |
+
for subdir in sorted(split_dir.iterdir()):
|
| 92 |
+
if subdir.is_dir():
|
| 93 |
+
# 递归查找子目录中的视频
|
| 94 |
+
video_files.extend(subdir.glob('*.mp4'))
|
| 95 |
+
video_files.extend(subdir.glob('*.avi'))
|
| 96 |
+
elif subdir.suffix in ['.mp4', '.avi']:
|
| 97 |
+
# 视频文件直接在 split 目录下
|
| 98 |
+
video_files.append(subdir)
|
| 99 |
+
|
| 100 |
+
video_files = sorted(video_files)
|
| 101 |
+
|
| 102 |
+
print(f"Found {len(video_files)} videos")
|
| 103 |
+
|
| 104 |
+
for video_path in tqdm(video_files, desc=f"Processing {split_name}"):
|
| 105 |
+
info = get_video_resolution(video_path)
|
| 106 |
+
|
| 107 |
+
if info is None:
|
| 108 |
+
print(f"Warning: Cannot read {video_path.name}")
|
| 109 |
+
continue
|
| 110 |
+
|
| 111 |
+
width = info['width']
|
| 112 |
+
height = info['height']
|
| 113 |
+
resolution = (width, height)
|
| 114 |
+
|
| 115 |
+
all_resolutions.append({
|
| 116 |
+
'file': video_path.name,
|
| 117 |
+
'split': split_name,
|
| 118 |
+
'width': width,
|
| 119 |
+
'height': height,
|
| 120 |
+
'fps': info['fps'],
|
| 121 |
+
'frame_count': info['frame_count'],
|
| 122 |
+
})
|
| 123 |
+
|
| 124 |
+
resolution_stats[resolution] += 1
|
| 125 |
+
|
| 126 |
+
# 检查是否匹配目标分辨率
|
| 127 |
+
if target_res:
|
| 128 |
+
if resolution != target_res:
|
| 129 |
+
mismatched_videos.append({
|
| 130 |
+
'file': video_path.name,
|
| 131 |
+
'split': split_name,
|
| 132 |
+
'resolution': f"{width}x{height}",
|
| 133 |
+
'expected': f"{target_res[0]}x{target_res[1]}",
|
| 134 |
+
})
|
| 135 |
+
|
| 136 |
+
total_videos += 1
|
| 137 |
+
|
| 138 |
+
# 打印统计结果
|
| 139 |
+
print("\n" + "="*60)
|
| 140 |
+
print("STATISTICS")
|
| 141 |
+
print("="*60)
|
| 142 |
+
print(f"Total videos checked: {total_videos}")
|
| 143 |
+
print(f"Unique resolutions: {len(resolution_stats)}")
|
| 144 |
+
print()
|
| 145 |
+
|
| 146 |
+
print("Resolution distribution:")
|
| 147 |
+
print("-" * 60)
|
| 148 |
+
for (width, height), count in resolution_stats.most_common():
|
| 149 |
+
percentage = count / total_videos * 100
|
| 150 |
+
print(f" {width:4d} x {height:4d}: {count:4d} videos ({percentage:5.1f}%)")
|
| 151 |
+
|
| 152 |
+
# 打印最常见的分辨率
|
| 153 |
+
if resolution_stats:
|
| 154 |
+
most_common = resolution_stats.most_common(1)[0]
|
| 155 |
+
print(f"\nMost common resolution: {most_common[0][0]}x{most_common[0][1]} ({most_common[1]} videos)")
|
| 156 |
+
|
| 157 |
+
# 如果有目标分辨率,显示不匹配的
|
| 158 |
+
if target_res:
|
| 159 |
+
print("\n" + "="*60)
|
| 160 |
+
print("RESOLUTION MATCHING")
|
| 161 |
+
print("="*60)
|
| 162 |
+
if mismatched_videos:
|
| 163 |
+
print(f"Videos NOT matching target resolution: {len(mismatched_videos)}")
|
| 164 |
+
print("\nMismatched videos (first 20):")
|
| 165 |
+
for item in mismatched_videos[:20]:
|
| 166 |
+
print(f" {item['split']}/{item['file']}: {item['resolution']} (expected {item['expected']})")
|
| 167 |
+
if len(mismatched_videos) > 20:
|
| 168 |
+
print(f" ... and {len(mismatched_videos) - 20} more")
|
| 169 |
+
else:
|
| 170 |
+
print(f"✓ All videos match target resolution {target_res[0]}x{target_res[1]}")
|
| 171 |
+
|
| 172 |
+
# 保存详细结果
|
| 173 |
+
output_file = videos_dir.parent / 'video_resolution_report.json'
|
| 174 |
+
import json
|
| 175 |
+
report = {
|
| 176 |
+
'total_videos': total_videos,
|
| 177 |
+
'resolution_distribution': {f"{w}x{h}": count for (w, h), count in resolution_stats.items()},
|
| 178 |
+
'videos': all_resolutions,
|
| 179 |
+
}
|
| 180 |
+
|
| 181 |
+
if target_res:
|
| 182 |
+
report['target_resolution'] = f"{target_res[0]}x{target_res[1]}"
|
| 183 |
+
report['mismatched_count'] = len(mismatched_videos)
|
| 184 |
+
report['mismatched_videos'] = mismatched_videos
|
| 185 |
+
|
| 186 |
+
with open(output_file, 'w') as f:
|
| 187 |
+
json.dump(report, f, indent=2)
|
| 188 |
+
|
| 189 |
+
print(f"\nDetailed report saved to: {output_file}")
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
if __name__ == '__main__':
|
| 193 |
+
main()
|
| 194 |
+
|
diffusion-dpo-ocr/prepare_roadtext.py
ADDED
|
@@ -0,0 +1,625 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
RoadText1K 预处理脚本
|
| 3 |
+
从视频中提取帧,resize 到 512x512,合并 Localisation 和 Text_Transcription 标注
|
| 4 |
+
|
| 5 |
+
使用流程:
|
| 6 |
+
1. 运行此脚本生成 GT images (512x512) 和合并后的标注
|
| 7 |
+
2. 用你的方式生成 LR 和 SR images
|
| 8 |
+
3. 运行 test_roadtext.py 评估 OCR 性能
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import os
|
| 12 |
+
import json
|
| 13 |
+
import random
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
from typing import Dict, List, Tuple
|
| 16 |
+
|
| 17 |
+
from PIL import Image
|
| 18 |
+
import numpy as np
|
| 19 |
+
from tqdm import tqdm
|
| 20 |
+
import cv2
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
# ============================================================================
|
| 24 |
+
# 配置参数 - 请修改这里
|
| 25 |
+
# ============================================================================
|
| 26 |
+
CONFIG = {
|
| 27 |
+
# RoadText1K 根目录
|
| 28 |
+
'roadtext_root': '/home/wanghongbo06/baipurui/DATA/RoadText1k',
|
| 29 |
+
|
| 30 |
+
# 使用哪个数据集 (test/train/val)
|
| 31 |
+
'split': 'test',
|
| 32 |
+
|
| 33 |
+
# 输出目录
|
| 34 |
+
'output_dir': '/home/wanghongbo06/baipurui/DATA/RoadText1k_patch_crop',
|
| 35 |
+
|
| 36 |
+
# Crop 尺寸 (从原图中 crop 这个大小的区域)
|
| 37 |
+
'crop_size': 512,
|
| 38 |
+
|
| 39 |
+
# Crop 策略: 'center', 'random', 'text_center' (以文本区域为中心)
|
| 40 |
+
'crop_strategy': 'text_center',
|
| 41 |
+
|
| 42 |
+
# 最小文本框保留比例 (文本框至少要有这么多比例在 crop 区域内才保留)
|
| 43 |
+
# 1.0 表示只保留完全在 crop 区域内的文本框(推荐,避免坐标问题)
|
| 44 |
+
# 0.7 表示文本框至少 70% 在 crop 区域内(但可能导致坐标超出边界)
|
| 45 |
+
'min_box_overlap': 1.0,
|
| 46 |
+
|
| 47 |
+
# 最小文本框数量 (crop 后至少要有这么多有效文本框)
|
| 48 |
+
'min_text_boxes': 1,
|
| 49 |
+
|
| 50 |
+
# 最终图片数量限制
|
| 51 |
+
'max_frames': 1000,
|
| 52 |
+
|
| 53 |
+
# 随机种子
|
| 54 |
+
'seed': 42,
|
| 55 |
+
}
|
| 56 |
+
# ============================================================================
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def load_localisation_annotations(root_dir: str) -> Dict:
|
| 60 |
+
"""加载所有 Localisation 标注"""
|
| 61 |
+
loc_dir = Path(root_dir) / 'Ground_truths' / 'Localisation'
|
| 62 |
+
all_annotations = {}
|
| 63 |
+
|
| 64 |
+
print("Loading Localisation annotations...")
|
| 65 |
+
json_files = sorted(loc_dir.glob('*.json'))
|
| 66 |
+
print(f"Found {len(json_files)} JSON files")
|
| 67 |
+
|
| 68 |
+
for json_file in json_files:
|
| 69 |
+
try:
|
| 70 |
+
with open(json_file, 'r', encoding='utf-8') as f:
|
| 71 |
+
data = json.load(f)
|
| 72 |
+
# data 应该是一个列表
|
| 73 |
+
if isinstance(data, list):
|
| 74 |
+
for item in data:
|
| 75 |
+
img_name = item.get('name', '')
|
| 76 |
+
if img_name:
|
| 77 |
+
all_annotations[img_name] = item
|
| 78 |
+
else:
|
| 79 |
+
print(f"Warning: {json_file.name} is not a list format")
|
| 80 |
+
except Exception as e:
|
| 81 |
+
print(f"Error loading {json_file.name}: {e}")
|
| 82 |
+
|
| 83 |
+
print(f"Loaded {len(all_annotations)} image annotations")
|
| 84 |
+
return all_annotations
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def load_text_transcriptions(root_dir: str) -> Dict:
|
| 88 |
+
"""加载所有 Text_Transcription 标注"""
|
| 89 |
+
text_dir = Path(root_dir) / 'Ground_truths' / 'Text_Transcription'
|
| 90 |
+
all_texts = {}
|
| 91 |
+
|
| 92 |
+
print("Loading Text_Transcription annotations...")
|
| 93 |
+
# 支持 *.json 和 *.json.json (双扩展名)
|
| 94 |
+
json_files = sorted(list(text_dir.glob('*.json')) + list(text_dir.glob('*.json.json')))
|
| 95 |
+
print(f"Found {len(json_files)} JSON files")
|
| 96 |
+
|
| 97 |
+
for json_file in json_files:
|
| 98 |
+
try:
|
| 99 |
+
with open(json_file, 'r', encoding='utf-8') as f:
|
| 100 |
+
data = json.load(f)
|
| 101 |
+
# data 应该是一个字典 {video_name: {label_id: text}}
|
| 102 |
+
if isinstance(data, dict):
|
| 103 |
+
for video_name, texts in data.items():
|
| 104 |
+
if video_name not in all_texts:
|
| 105 |
+
all_texts[video_name] = {}
|
| 106 |
+
if isinstance(texts, dict):
|
| 107 |
+
all_texts[video_name].update(texts)
|
| 108 |
+
else:
|
| 109 |
+
print(f"Warning: {json_file.name} is not a dict format")
|
| 110 |
+
except Exception as e:
|
| 111 |
+
print(f"Error loading {json_file.name}: {e}")
|
| 112 |
+
|
| 113 |
+
print(f"Loaded texts for {len(all_texts)} videos")
|
| 114 |
+
return all_texts
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def get_box_bounds(box: List[float]) -> Tuple[float, float, float, float]:
|
| 118 |
+
"""从多边形获取边界框 (x1, y1, x2, y2)"""
|
| 119 |
+
# box2d 格式 [x1,y1,x2,y1,x2,y2,x1,y2] 或其他多边形格式
|
| 120 |
+
xs = [box[i] for i in range(0, len(box), 2)]
|
| 121 |
+
ys = [box[i] for i in range(1, len(box), 2)]
|
| 122 |
+
return min(xs), min(ys), max(xs), max(ys)
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def is_box_fully_inside(box: List[float], crop_x: int, crop_y: int, crop_size: int) -> bool:
|
| 126 |
+
"""检查框是否完全在 crop 区域内"""
|
| 127 |
+
x1, y1, x2, y2 = get_box_bounds(box)
|
| 128 |
+
cx1, cy1 = crop_x, crop_y
|
| 129 |
+
cx2, cy2 = crop_x + crop_size, crop_y + crop_size
|
| 130 |
+
return x1 >= cx1 and y1 >= cy1 and x2 <= cx2 and y2 <= cy2
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def calc_box_overlap(box: List[float], crop_x: int, crop_y: int, crop_size: int) -> float:
|
| 134 |
+
"""
|
| 135 |
+
计算文本框与 crop 区域的重叠比例
|
| 136 |
+
返回值: 0.0 - 1.0,表示文本框有多少比例在 crop 区域内
|
| 137 |
+
"""
|
| 138 |
+
x1, y1, x2, y2 = get_box_bounds(box)
|
| 139 |
+
box_area = (x2 - x1) * (y2 - y1)
|
| 140 |
+
if box_area <= 0:
|
| 141 |
+
return 0.0
|
| 142 |
+
|
| 143 |
+
# crop 区域边界
|
| 144 |
+
cx1, cy1 = crop_x, crop_y
|
| 145 |
+
cx2, cy2 = crop_x + crop_size, crop_y + crop_size
|
| 146 |
+
|
| 147 |
+
# 计算交集
|
| 148 |
+
inter_x1 = max(x1, cx1)
|
| 149 |
+
inter_y1 = max(y1, cy1)
|
| 150 |
+
inter_x2 = min(x2, cx2)
|
| 151 |
+
inter_y2 = min(y2, cy2)
|
| 152 |
+
|
| 153 |
+
if inter_x1 >= inter_x2 or inter_y1 >= inter_y2:
|
| 154 |
+
return 0.0
|
| 155 |
+
|
| 156 |
+
inter_area = (inter_x2 - inter_x1) * (inter_y2 - inter_y1)
|
| 157 |
+
return inter_area / box_area
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
def clip_polygon_to_crop(poly: List[float], crop_x: int, crop_y: int, crop_size: int) -> List[float]:
|
| 161 |
+
"""
|
| 162 |
+
将多边形坐标转换到 crop 坐标系
|
| 163 |
+
|
| 164 |
+
当 min_box_overlap = 1.0 时,框完全在 crop 内,只需要平移
|
| 165 |
+
当 min_box_overlap < 1.0 时,需要将超出边界的坐标裁剪到边界
|
| 166 |
+
"""
|
| 167 |
+
clipped = []
|
| 168 |
+
for i in range(0, len(poly), 2):
|
| 169 |
+
x = poly[i] - crop_x # 先平移
|
| 170 |
+
y = poly[i+1] - crop_y
|
| 171 |
+
|
| 172 |
+
# 裁剪到 [0, crop_size] 范围内(安全措施)
|
| 173 |
+
x = max(0, min(x, crop_size))
|
| 174 |
+
y = max(0, min(y, crop_size))
|
| 175 |
+
|
| 176 |
+
clipped.append(x)
|
| 177 |
+
clipped.append(y)
|
| 178 |
+
|
| 179 |
+
return clipped
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def find_best_crop_position(
|
| 183 |
+
polygons: List[List[float]],
|
| 184 |
+
img_w: int,
|
| 185 |
+
img_h: int,
|
| 186 |
+
crop_size: int,
|
| 187 |
+
strategy: str = 'text_center',
|
| 188 |
+
min_overlap: float = 0.7
|
| 189 |
+
) -> Tuple[int, int, List[int]]:
|
| 190 |
+
"""
|
| 191 |
+
找到最佳的 crop 位置
|
| 192 |
+
|
| 193 |
+
Args:
|
| 194 |
+
polygons: 文本框多边形列表
|
| 195 |
+
img_w, img_h: 原图尺寸
|
| 196 |
+
crop_size: crop 尺寸
|
| 197 |
+
strategy: crop 策略
|
| 198 |
+
min_overlap: 最小重叠比例,文本框至少这么多比例在 crop 内才保留
|
| 199 |
+
|
| 200 |
+
Returns:
|
| 201 |
+
(crop_x, crop_y, valid_box_indices)
|
| 202 |
+
"""
|
| 203 |
+
if not polygons:
|
| 204 |
+
crop_x = max(0, (img_w - crop_size) // 2)
|
| 205 |
+
crop_y = max(0, (img_h - crop_size) // 2)
|
| 206 |
+
return crop_x, crop_y, []
|
| 207 |
+
|
| 208 |
+
if strategy == 'center':
|
| 209 |
+
crop_x = max(0, (img_w - crop_size) // 2)
|
| 210 |
+
crop_y = max(0, (img_h - crop_size) // 2)
|
| 211 |
+
|
| 212 |
+
elif strategy == 'text_center':
|
| 213 |
+
# 计算所有文本框的中心,以平均中心为 crop 中心
|
| 214 |
+
centers_x = []
|
| 215 |
+
centers_y = []
|
| 216 |
+
for poly in polygons:
|
| 217 |
+
x1, y1, x2, y2 = get_box_bounds(poly)
|
| 218 |
+
centers_x.append((x1 + x2) / 2)
|
| 219 |
+
centers_y.append((y1 + y2) / 2)
|
| 220 |
+
|
| 221 |
+
avg_cx = sum(centers_x) / len(centers_x)
|
| 222 |
+
avg_cy = sum(centers_y) / len(centers_y)
|
| 223 |
+
|
| 224 |
+
crop_x = int(avg_cx - crop_size / 2)
|
| 225 |
+
crop_y = int(avg_cy - crop_size / 2)
|
| 226 |
+
|
| 227 |
+
# 边界检查
|
| 228 |
+
crop_x = max(0, min(crop_x, img_w - crop_size))
|
| 229 |
+
crop_y = max(0, min(crop_y, img_h - crop_size))
|
| 230 |
+
|
| 231 |
+
elif strategy == 'random':
|
| 232 |
+
max_x = max(0, img_w - crop_size)
|
| 233 |
+
max_y = max(0, img_h - crop_size)
|
| 234 |
+
crop_x = random.randint(0, max_x) if max_x > 0 else 0
|
| 235 |
+
crop_y = random.randint(0, max_y) if max_y > 0 else 0
|
| 236 |
+
|
| 237 |
+
else:
|
| 238 |
+
crop_x = max(0, (img_w - crop_size) // 2)
|
| 239 |
+
crop_y = max(0, (img_h - crop_size) // 2)
|
| 240 |
+
|
| 241 |
+
# 找出哪些文本框的重叠比例 >= min_overlap
|
| 242 |
+
valid_indices = []
|
| 243 |
+
for i, poly in enumerate(polygons):
|
| 244 |
+
if min_overlap >= 0.99:
|
| 245 |
+
# 严格模式:只保留完全在 crop 区域内的框
|
| 246 |
+
if is_box_fully_inside(poly, crop_x, crop_y, crop_size):
|
| 247 |
+
valid_indices.append(i)
|
| 248 |
+
else:
|
| 249 |
+
# 宽松模式:保留 overlap >= min_overlap 的框
|
| 250 |
+
overlap = calc_box_overlap(poly, crop_x, crop_y, crop_size)
|
| 251 |
+
if overlap >= min_overlap:
|
| 252 |
+
valid_indices.append(i)
|
| 253 |
+
|
| 254 |
+
return crop_x, crop_y, valid_indices
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
def extract_frame_with_crop(
|
| 258 |
+
video_path: Path,
|
| 259 |
+
frame_idx: int,
|
| 260 |
+
output_dir: Path,
|
| 261 |
+
crop_x: int,
|
| 262 |
+
crop_y: int,
|
| 263 |
+
crop_size: int,
|
| 264 |
+
) -> Tuple[str, bool]:
|
| 265 |
+
"""
|
| 266 |
+
从视频中提取指定帧并 crop
|
| 267 |
+
|
| 268 |
+
Returns:
|
| 269 |
+
(saved_filename, success)
|
| 270 |
+
"""
|
| 271 |
+
cap = cv2.VideoCapture(str(video_path))
|
| 272 |
+
if not cap.isOpened():
|
| 273 |
+
return '', False
|
| 274 |
+
|
| 275 |
+
# 跳到指定帧
|
| 276 |
+
cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
|
| 277 |
+
ret, frame = cap.read()
|
| 278 |
+
cap.release()
|
| 279 |
+
|
| 280 |
+
if not ret:
|
| 281 |
+
return '', False
|
| 282 |
+
|
| 283 |
+
# 转换为 RGB
|
| 284 |
+
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
| 285 |
+
img = Image.fromarray(frame_rgb)
|
| 286 |
+
|
| 287 |
+
# Crop
|
| 288 |
+
img_cropped = img.crop((crop_x, crop_y, crop_x + crop_size, crop_y + crop_size))
|
| 289 |
+
|
| 290 |
+
# 保存
|
| 291 |
+
video_name = video_path.stem
|
| 292 |
+
img_filename = f"{video_name}-{frame_idx:07d}.png"
|
| 293 |
+
img_path = output_dir / img_filename
|
| 294 |
+
img_cropped.save(img_path)
|
| 295 |
+
|
| 296 |
+
return img_filename, True
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
def adjust_polygon_for_crop(poly: List[float], crop_x: int, crop_y: int) -> List[float]:
|
| 300 |
+
"""调整多边形坐标到 crop 后的坐标系"""
|
| 301 |
+
adjusted = []
|
| 302 |
+
for i in range(0, len(poly), 2):
|
| 303 |
+
adjusted.append(poly[i] - crop_x)
|
| 304 |
+
adjusted.append(poly[i+1] - crop_y)
|
| 305 |
+
return adjusted
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
def merge_annotations(loc_ann: Dict, text_dict: Dict) -> Dict:
|
| 309 |
+
"""
|
| 310 |
+
合并 Localisation 和 Text_Transcription
|
| 311 |
+
|
| 312 |
+
Args:
|
| 313 |
+
loc_ann: Localisation 标注项
|
| 314 |
+
text_dict: Text_Transcription 字典 {video_name: {label_id: text}}
|
| 315 |
+
|
| 316 |
+
Returns:
|
| 317 |
+
合并后的标注: {
|
| 318 |
+
'polygons': [[x1,y1,x2,y2,x3,y3,x4,y4], ...],
|
| 319 |
+
'texts': ['text1', 'text2', ...],
|
| 320 |
+
'ignore': [False, False, ...]
|
| 321 |
+
}
|
| 322 |
+
"""
|
| 323 |
+
video_name = loc_ann.get('videoName', '')
|
| 324 |
+
labels = loc_ann.get('labels') or [] # 处理 None 的情况
|
| 325 |
+
|
| 326 |
+
# 获取该视频的文本字典
|
| 327 |
+
video_texts = text_dict.get(video_name, {})
|
| 328 |
+
|
| 329 |
+
polygons = []
|
| 330 |
+
texts = []
|
| 331 |
+
ignore = []
|
| 332 |
+
|
| 333 |
+
for label in labels:
|
| 334 |
+
label_id = str(label.get('id', ''))
|
| 335 |
+
box2d = label.get('box2d', {})
|
| 336 |
+
|
| 337 |
+
if not box2d:
|
| 338 |
+
continue
|
| 339 |
+
|
| 340 |
+
# 从 box2d 转换为多边形 (4个点)
|
| 341 |
+
x1, y1 = box2d.get('x1', 0), box2d.get('y1', 0)
|
| 342 |
+
x2, y2 = box2d.get('x2', 0), box2d.get('y2', 0)
|
| 343 |
+
|
| 344 |
+
# 转换为多边形格式 [x1,y1,x2,y1,x2,y2,x1,y2]
|
| 345 |
+
polygon = [x1, y1, x2, y1, x2, y2, x1, y2]
|
| 346 |
+
polygons.append(polygon)
|
| 347 |
+
|
| 348 |
+
# 获取文本
|
| 349 |
+
text = video_texts.get(label_id, '')
|
| 350 |
+
texts.append(text)
|
| 351 |
+
|
| 352 |
+
# 判断是否忽略 (空文本或特定标记)
|
| 353 |
+
ignore_flag = (text == '' or text == '###' or text.lower() == 'illegible')
|
| 354 |
+
ignore.append(ignore_flag)
|
| 355 |
+
|
| 356 |
+
return {
|
| 357 |
+
'polygons': polygons,
|
| 358 |
+
'texts': texts,
|
| 359 |
+
'ignore': ignore,
|
| 360 |
+
}
|
| 361 |
+
|
| 362 |
+
|
| 363 |
+
def scale_polygon(polygon: List[float], scale_x: float, scale_y: float) -> List[float]:
|
| 364 |
+
"""缩放多边形坐标"""
|
| 365 |
+
scaled = []
|
| 366 |
+
for i in range(0, len(polygon), 2):
|
| 367 |
+
scaled.append(polygon[i] * scale_x)
|
| 368 |
+
scaled.append(polygon[i+1] * scale_y)
|
| 369 |
+
return scaled
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
def main():
|
| 373 |
+
random.seed(CONFIG['seed'])
|
| 374 |
+
np.random.seed(CONFIG['seed'])
|
| 375 |
+
|
| 376 |
+
root_dir = Path(CONFIG['roadtext_root'])
|
| 377 |
+
output_dir = Path(CONFIG['output_dir'])
|
| 378 |
+
split = CONFIG['split']
|
| 379 |
+
crop_size = CONFIG['crop_size']
|
| 380 |
+
|
| 381 |
+
# 创建输出目录
|
| 382 |
+
gt_dir = output_dir / 'gt'
|
| 383 |
+
gt_dir.mkdir(parents=True, exist_ok=True)
|
| 384 |
+
|
| 385 |
+
print(f"RoadText1K root: {root_dir}")
|
| 386 |
+
print(f"Split: {split}")
|
| 387 |
+
print(f"Output dir: {output_dir}")
|
| 388 |
+
print(f"Crop size: {crop_size}x{crop_size}")
|
| 389 |
+
print(f"Crop strategy: {CONFIG['crop_strategy']}")
|
| 390 |
+
print(f"Min text boxes per crop: {CONFIG['min_text_boxes']}")
|
| 391 |
+
print()
|
| 392 |
+
|
| 393 |
+
# 加载标注
|
| 394 |
+
loc_annotations = load_localisation_annotations(root_dir)
|
| 395 |
+
text_transcriptions = load_text_transcriptions(root_dir)
|
| 396 |
+
|
| 397 |
+
# 获取视频目录
|
| 398 |
+
video_dir = root_dir / 'Videos' / split
|
| 399 |
+
if not video_dir.exists():
|
| 400 |
+
print(f"Error: Video directory not found: {video_dir}")
|
| 401 |
+
return
|
| 402 |
+
|
| 403 |
+
# 获取所有视频文件(支持嵌套子目录结构)
|
| 404 |
+
video_files = []
|
| 405 |
+
for subdir in sorted(video_dir.iterdir()):
|
| 406 |
+
if subdir.is_dir():
|
| 407 |
+
# 子目录中的视频文件
|
| 408 |
+
video_files.extend(subdir.glob('*.mp4'))
|
| 409 |
+
video_files.extend(subdir.glob('*.avi'))
|
| 410 |
+
elif subdir.suffix in ['.mp4', '.avi']:
|
| 411 |
+
# 视频文件直接在 split 目录下
|
| 412 |
+
video_files.append(subdir)
|
| 413 |
+
|
| 414 |
+
video_files = sorted(video_files)
|
| 415 |
+
total_videos = len(video_files)
|
| 416 |
+
print(f"找到 {total_videos} 个视频")
|
| 417 |
+
print()
|
| 418 |
+
|
| 419 |
+
# 处理视频
|
| 420 |
+
new_annotations = {}
|
| 421 |
+
processed_count = 0
|
| 422 |
+
total_frames = 0
|
| 423 |
+
|
| 424 |
+
print("Processing videos...")
|
| 425 |
+
print("Step 1: 找出每个视频中有标注的帧...")
|
| 426 |
+
|
| 427 |
+
# 第一步:找出每个视频中有标注的帧索引,同时记录标注 key
|
| 428 |
+
video_annotated_frames = {} # {video_name: [(frame_num, annotation_key), ...]}
|
| 429 |
+
video_name_to_path = {vp.stem: vp for vp in video_files}
|
| 430 |
+
|
| 431 |
+
for key, ann in loc_annotations.items():
|
| 432 |
+
# 从标注的 name 中提取视频名和帧号
|
| 433 |
+
# 格式可能是: "200_frames/170-0000001.jpg" 或 "test_frames/701-0000001.jpg"
|
| 434 |
+
parts = key.split('/')
|
| 435 |
+
if len(parts) >= 2:
|
| 436 |
+
filename = parts[-1] # "170-0000001.jpg"
|
| 437 |
+
else:
|
| 438 |
+
filename = key # "170-0000001.jpg"
|
| 439 |
+
|
| 440 |
+
# 提取帧号
|
| 441 |
+
if '-' in filename:
|
| 442 |
+
try:
|
| 443 |
+
frame_str = filename.split('-')[-1].split('.')[0]
|
| 444 |
+
frame_num = int(frame_str)
|
| 445 |
+
|
| 446 |
+
# 尝试提取视频名
|
| 447 |
+
video_name_candidate = filename.split('-')[0]
|
| 448 |
+
|
| 449 |
+
# 检查是否在我们的视频列表中
|
| 450 |
+
if video_name_candidate in video_name_to_path:
|
| 451 |
+
if video_name_candidate not in video_annotated_frames:
|
| 452 |
+
video_annotated_frames[video_name_candidate] = []
|
| 453 |
+
# 存储 (frame_num, annotation_key) 以便后续使用
|
| 454 |
+
entry = (frame_num, key)
|
| 455 |
+
if entry not in video_annotated_frames[video_name_candidate]:
|
| 456 |
+
video_annotated_frames[video_name_candidate].append(entry)
|
| 457 |
+
except:
|
| 458 |
+
pass
|
| 459 |
+
|
| 460 |
+
print(f"找到 {len(video_annotated_frames)} 个视频有标注")
|
| 461 |
+
total_annotated_frames = sum(len(frames) for frames in video_annotated_frames.values())
|
| 462 |
+
print(f"总共有 {total_annotated_frames} 帧有标注")
|
| 463 |
+
|
| 464 |
+
# 如果设置了 max_frames,随机选取
|
| 465 |
+
max_frames = CONFIG['max_frames']
|
| 466 |
+
if max_frames is not None and total_annotated_frames > max_frames:
|
| 467 |
+
print(f"\n限制最大帧数为 {max_frames},随机选取中...")
|
| 468 |
+
|
| 469 |
+
# 将所有帧展开为列表 [(video_name, frame_num, annotation_key), ...]
|
| 470 |
+
all_frames = []
|
| 471 |
+
for video_name, frame_list in video_annotated_frames.items():
|
| 472 |
+
for frame_num, ann_key in frame_list:
|
| 473 |
+
all_frames.append((video_name, frame_num, ann_key))
|
| 474 |
+
|
| 475 |
+
# 随机选取
|
| 476 |
+
selected_frames = random.sample(all_frames, max_frames)
|
| 477 |
+
|
| 478 |
+
# 重新组织为 video_annotated_frames 格式
|
| 479 |
+
video_annotated_frames = {}
|
| 480 |
+
for video_name, frame_num, ann_key in selected_frames:
|
| 481 |
+
if video_name not in video_annotated_frames:
|
| 482 |
+
video_annotated_frames[video_name] = []
|
| 483 |
+
video_annotated_frames[video_name].append((frame_num, ann_key))
|
| 484 |
+
|
| 485 |
+
print(f"随机选取 {max_frames} 帧,涉及 {len(video_annotated_frames)} 个视频")
|
| 486 |
+
|
| 487 |
+
crop_size = CONFIG['crop_size']
|
| 488 |
+
crop_strategy = CONFIG['crop_strategy']
|
| 489 |
+
min_box_overlap = CONFIG['min_box_overlap']
|
| 490 |
+
min_text_boxes = CONFIG['min_text_boxes']
|
| 491 |
+
|
| 492 |
+
print()
|
| 493 |
+
print("Step 2: 提取帧并 Crop(保留有效文本框)...")
|
| 494 |
+
print(f" Crop 尺寸: {crop_size}x{crop_size}")
|
| 495 |
+
print(f" Crop 策略: {crop_strategy}")
|
| 496 |
+
print(f" 最小重叠比例: {min_box_overlap:.0%}")
|
| 497 |
+
print(f" 最小文本框数: {min_text_boxes}")
|
| 498 |
+
print()
|
| 499 |
+
|
| 500 |
+
skipped_no_boxes = 0
|
| 501 |
+
|
| 502 |
+
for video_path in tqdm(video_files, desc="Videos"):
|
| 503 |
+
video_name = video_path.stem
|
| 504 |
+
|
| 505 |
+
# 获取该视频有标注的帧信息 [(frame_num, ann_key), ...]
|
| 506 |
+
annotated_frame_info = video_annotated_frames.get(video_name, [])
|
| 507 |
+
|
| 508 |
+
if not annotated_frame_info:
|
| 509 |
+
continue
|
| 510 |
+
|
| 511 |
+
# 获取视频分辨率
|
| 512 |
+
cap = cv2.VideoCapture(str(video_path))
|
| 513 |
+
if not cap.isOpened():
|
| 514 |
+
continue
|
| 515 |
+
orig_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
| 516 |
+
orig_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
| 517 |
+
cap.release()
|
| 518 |
+
|
| 519 |
+
if orig_w == 0 or orig_h == 0:
|
| 520 |
+
continue
|
| 521 |
+
|
| 522 |
+
# 处理每一帧
|
| 523 |
+
for frame_num, ann_key in annotated_frame_info:
|
| 524 |
+
if ann_key not in loc_annotations:
|
| 525 |
+
continue
|
| 526 |
+
|
| 527 |
+
loc_ann = loc_annotations[ann_key]
|
| 528 |
+
if loc_ann is None:
|
| 529 |
+
continue
|
| 530 |
+
|
| 531 |
+
# 合并标注
|
| 532 |
+
merged_ann = merge_annotations(loc_ann, text_transcriptions)
|
| 533 |
+
|
| 534 |
+
if not merged_ann['polygons']:
|
| 535 |
+
skipped_no_boxes += 1
|
| 536 |
+
continue
|
| 537 |
+
|
| 538 |
+
# 找到最佳的 crop 位置
|
| 539 |
+
crop_x, crop_y, valid_indices = find_best_crop_position(
|
| 540 |
+
merged_ann['polygons'],
|
| 541 |
+
orig_w,
|
| 542 |
+
orig_h,
|
| 543 |
+
crop_size,
|
| 544 |
+
crop_strategy,
|
| 545 |
+
min_box_overlap
|
| 546 |
+
)
|
| 547 |
+
|
| 548 |
+
# 检查有效文本框数量
|
| 549 |
+
if len(valid_indices) < min_text_boxes:
|
| 550 |
+
skipped_no_boxes += 1
|
| 551 |
+
continue
|
| 552 |
+
|
| 553 |
+
# 提取并 crop 帧
|
| 554 |
+
img_filename, success = extract_frame_with_crop(
|
| 555 |
+
video_path, frame_num, gt_dir,
|
| 556 |
+
crop_x, crop_y, crop_size
|
| 557 |
+
)
|
| 558 |
+
|
| 559 |
+
if not success:
|
| 560 |
+
continue
|
| 561 |
+
|
| 562 |
+
# 只保留有效的文本框,裁剪并调整坐标
|
| 563 |
+
cropped_polygons = []
|
| 564 |
+
cropped_texts = []
|
| 565 |
+
cropped_ignore = []
|
| 566 |
+
|
| 567 |
+
for i in valid_indices:
|
| 568 |
+
# 裁剪多边形到 crop 区域并转换坐标
|
| 569 |
+
clipped_poly = clip_polygon_to_crop(
|
| 570 |
+
merged_ann['polygons'][i], crop_x, crop_y, crop_size
|
| 571 |
+
)
|
| 572 |
+
cropped_polygons.append(clipped_poly)
|
| 573 |
+
cropped_texts.append(merged_ann['texts'][i])
|
| 574 |
+
cropped_ignore.append(merged_ann['ignore'][i])
|
| 575 |
+
|
| 576 |
+
new_annotations[img_filename] = {
|
| 577 |
+
'polygons': cropped_polygons,
|
| 578 |
+
'texts': cropped_texts,
|
| 579 |
+
'ignore': cropped_ignore,
|
| 580 |
+
'original_name': ann_key,
|
| 581 |
+
'crop_position': [crop_x, crop_y],
|
| 582 |
+
}
|
| 583 |
+
|
| 584 |
+
total_frames += 1
|
| 585 |
+
|
| 586 |
+
processed_count += 1
|
| 587 |
+
|
| 588 |
+
# 保存标注
|
| 589 |
+
ann_output_path = output_dir / 'annotations.json'
|
| 590 |
+
with open(ann_output_path, 'w', encoding='utf-8') as f:
|
| 591 |
+
json.dump(new_annotations, f, indent=2, ensure_ascii=False)
|
| 592 |
+
|
| 593 |
+
# 统计有效文本框
|
| 594 |
+
total_boxes = sum(len(ann['texts']) for ann in new_annotations.values())
|
| 595 |
+
valid_boxes = sum(
|
| 596 |
+
sum(1 for ig in ann['ignore'] if not ig)
|
| 597 |
+
for ann in new_annotations.values()
|
| 598 |
+
)
|
| 599 |
+
|
| 600 |
+
print()
|
| 601 |
+
print("="*60)
|
| 602 |
+
print("完成!")
|
| 603 |
+
print("="*60)
|
| 604 |
+
print(f"处理视频数: {processed_count}")
|
| 605 |
+
print(f"提取帧数: {total_frames}")
|
| 606 |
+
print(f"跳过帧数(文本框不足): {skipped_no_boxes}")
|
| 607 |
+
print(f"有标注的图片数: {len(new_annotations)}")
|
| 608 |
+
print(f"文本框总数: {total_boxes}")
|
| 609 |
+
print(f"有效文本框: {valid_boxes}")
|
| 610 |
+
print(f"每张图平均文本框: {total_boxes/max(1,len(new_annotations)):.1f}")
|
| 611 |
+
print()
|
| 612 |
+
print("输出文件:")
|
| 613 |
+
print(f" GT images ({crop_size}x{crop_size}): {gt_dir}")
|
| 614 |
+
print(f" Annotations: {ann_output_path}")
|
| 615 |
+
print()
|
| 616 |
+
print("下一步:")
|
| 617 |
+
print(f" 1. 用你的方式生成 LR images (如 128x128)")
|
| 618 |
+
print(f" 2. 超分得到 SR images ({crop_size}x{crop_size})")
|
| 619 |
+
print(f" 3. 将 SR images 保存到 {output_dir}/sr")
|
| 620 |
+
print(f" 4. 运行 test_roadtext.py 评估")
|
| 621 |
+
|
| 622 |
+
|
| 623 |
+
if __name__ == '__main__':
|
| 624 |
+
main()
|
| 625 |
+
|
diffusion-dpo-ocr/results/roadtext_eval_results_BSRGAN.json
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"TP": 51,
|
| 3 |
+
"FP": 80,
|
| 4 |
+
"FN": 1920,
|
| 5 |
+
"Precision": 0.3893129770992366,
|
| 6 |
+
"Recall": 0.0258751902587519,
|
| 7 |
+
"F1-Score": 0.04852521408182683,
|
| 8 |
+
"OCR_detections": 131,
|
| 9 |
+
"GT_boxes": 1971,
|
| 10 |
+
"Detection_rate": 0.06646372399797057,
|
| 11 |
+
"Det_Precision": 0.6183206106870229,
|
| 12 |
+
"Det_Recall": 0.0410958904109589,
|
| 13 |
+
"Det_F1": 0.07706945765937202,
|
| 14 |
+
"Det_matched": 81,
|
| 15 |
+
"Text_matched": 51,
|
| 16 |
+
"eval_mode": "end2end"
|
| 17 |
+
}
|
diffusion-dpo-ocr/results/roadtext_eval_results_DP2O-SR.json
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"TP": 314,
|
| 3 |
+
"FP": 1182,
|
| 4 |
+
"FN": 1657,
|
| 5 |
+
"Precision": 0.20989304812834225,
|
| 6 |
+
"Recall": 0.15930999492643327,
|
| 7 |
+
"F1-Score": 0.181136429189501,
|
| 8 |
+
"OCR_detections": 1496,
|
| 9 |
+
"GT_boxes": 1971,
|
| 10 |
+
"Detection_rate": 0.7590055809233891,
|
| 11 |
+
"Det_Precision": 0.5274064171122995,
|
| 12 |
+
"Det_Recall": 0.4003044140030441,
|
| 13 |
+
"Det_F1": 0.4551485434092875,
|
| 14 |
+
"Det_matched": 789,
|
| 15 |
+
"Text_matched": 314,
|
| 16 |
+
"eval_mode": "end2end"
|
| 17 |
+
}
|
diffusion-dpo-ocr/results/roadtext_eval_results_DiT4SR.json
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"TP": 226,
|
| 3 |
+
"FP": 903,
|
| 4 |
+
"FN": 1745,
|
| 5 |
+
"Precision": 0.20017714791851196,
|
| 6 |
+
"Recall": 0.11466260781329274,
|
| 7 |
+
"F1-Score": 0.14580645161290323,
|
| 8 |
+
"OCR_detections": 1129,
|
| 9 |
+
"GT_boxes": 1971,
|
| 10 |
+
"Detection_rate": 0.5728056823947235,
|
| 11 |
+
"Det_Precision": 0.5234720992028343,
|
| 12 |
+
"Det_Recall": 0.2998477929984779,
|
| 13 |
+
"Det_F1": 0.3812903225806451,
|
| 14 |
+
"Det_matched": 591,
|
| 15 |
+
"Text_matched": 226,
|
| 16 |
+
"eval_mode": "end2end"
|
| 17 |
+
}
|
diffusion-dpo-ocr/results/roadtext_eval_results_DiffBIR.json
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"TP": 172,
|
| 3 |
+
"FP": 687,
|
| 4 |
+
"FN": 1799,
|
| 5 |
+
"Precision": 0.20023282887077998,
|
| 6 |
+
"Recall": 0.08726534753932014,
|
| 7 |
+
"F1-Score": 0.1215547703180212,
|
| 8 |
+
"OCR_detections": 859,
|
| 9 |
+
"GT_boxes": 1971,
|
| 10 |
+
"Detection_rate": 0.43581938102486045,
|
| 11 |
+
"Det_Precision": 0.5448195576251456,
|
| 12 |
+
"Det_Recall": 0.2374429223744292,
|
| 13 |
+
"Det_F1": 0.33074204946996466,
|
| 14 |
+
"Det_matched": 468,
|
| 15 |
+
"Text_matched": 172,
|
| 16 |
+
"eval_mode": "end2end"
|
| 17 |
+
}
|
diffusion-dpo-ocr/results/roadtext_eval_results_FaithDiff.json
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"TP": 210,
|
| 3 |
+
"FP": 1068,
|
| 4 |
+
"FN": 1761,
|
| 5 |
+
"Precision": 0.1643192488262911,
|
| 6 |
+
"Recall": 0.106544901065449,
|
| 7 |
+
"F1-Score": 0.12927054478301014,
|
| 8 |
+
"OCR_detections": 1278,
|
| 9 |
+
"GT_boxes": 1971,
|
| 10 |
+
"Detection_rate": 0.6484018264840182,
|
| 11 |
+
"Det_Precision": 0.4890453834115806,
|
| 12 |
+
"Det_Recall": 0.31709791983764585,
|
| 13 |
+
"Det_F1": 0.38473376423514927,
|
| 14 |
+
"Det_matched": 625,
|
| 15 |
+
"Text_matched": 210,
|
| 16 |
+
"eval_mode": "end2end"
|
| 17 |
+
}
|
diffusion-dpo-ocr/results/roadtext_eval_results_Ours.json
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"TP": 233,
|
| 3 |
+
"FP": 1483,
|
| 4 |
+
"FN": 1738,
|
| 5 |
+
"Precision": 0.1357808857808858,
|
| 6 |
+
"Recall": 0.11821410451547437,
|
| 7 |
+
"F1-Score": 0.12639001898562516,
|
| 8 |
+
"OCR_detections": 1716,
|
| 9 |
+
"GT_boxes": 1971,
|
| 10 |
+
"Detection_rate": 0.8706240487062404,
|
| 11 |
+
"Det_Precision": 0.3275058275058275,
|
| 12 |
+
"Det_Recall": 0.28513444951801115,
|
| 13 |
+
"Det_F1": 0.3048548955790616,
|
| 14 |
+
"Det_matched": 562,
|
| 15 |
+
"Text_matched": 233,
|
| 16 |
+
"eval_mode": "end2end"
|
| 17 |
+
}
|
diffusion-dpo-ocr/results/roadtext_eval_results_Real-ESRGAN.json
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"TP": 64,
|
| 3 |
+
"FP": 191,
|
| 4 |
+
"FN": 1907,
|
| 5 |
+
"Precision": 0.25098039215686274,
|
| 6 |
+
"Recall": 0.032470826991374935,
|
| 7 |
+
"F1-Score": 0.05750224618149146,
|
| 8 |
+
"OCR_detections": 255,
|
| 9 |
+
"GT_boxes": 1971,
|
| 10 |
+
"Detection_rate": 0.1293759512937595,
|
| 11 |
+
"Det_Precision": 0.5882352941176471,
|
| 12 |
+
"Det_Recall": 0.076103500761035,
|
| 13 |
+
"Det_F1": 0.1347708894878706,
|
| 14 |
+
"Det_matched": 150,
|
| 15 |
+
"Text_matched": 64,
|
| 16 |
+
"eval_mode": "end2end"
|
| 17 |
+
}
|
diffusion-dpo-ocr/results/roadtext_eval_results_SUPSR.json
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"TP": 219,
|
| 3 |
+
"FP": 1267,
|
| 4 |
+
"FN": 1752,
|
| 5 |
+
"Precision": 0.14737550471063257,
|
| 6 |
+
"Recall": 0.1111111111111111,
|
| 7 |
+
"F1-Score": 0.126699450390512,
|
| 8 |
+
"OCR_detections": 1486,
|
| 9 |
+
"GT_boxes": 1971,
|
| 10 |
+
"Detection_rate": 0.7539320142059868,
|
| 11 |
+
"Det_Precision": 0.4878869448183042,
|
| 12 |
+
"Det_Recall": 0.3678335870116692,
|
| 13 |
+
"Det_F1": 0.41943881978594155,
|
| 14 |
+
"Det_matched": 725,
|
| 15 |
+
"Text_matched": 219,
|
| 16 |
+
"eval_mode": "end2end"
|
| 17 |
+
}
|
diffusion-dpo-ocr/results/roadtext_eval_results_SeeSR.json
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"TP": 0,
|
| 3 |
+
"FP": 0,
|
| 4 |
+
"FN": 0,
|
| 5 |
+
"Precision": 0,
|
| 6 |
+
"Recall": 0,
|
| 7 |
+
"F1-Score": 0,
|
| 8 |
+
"OCR_detections": 0,
|
| 9 |
+
"GT_boxes": 0,
|
| 10 |
+
"Detection_rate": 0,
|
| 11 |
+
"Det_Precision": 0,
|
| 12 |
+
"Det_Recall": 0,
|
| 13 |
+
"Det_F1": 0,
|
| 14 |
+
"Det_matched": 0,
|
| 15 |
+
"Text_matched": 0,
|
| 16 |
+
"eval_mode": "end2end"
|
| 17 |
+
}
|
diffusion-dpo-ocr/results/roadtext_eval_results_StableSR.json
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"TP": 141,
|
| 3 |
+
"FP": 672,
|
| 4 |
+
"FN": 1830,
|
| 5 |
+
"Precision": 0.17343173431734318,
|
| 6 |
+
"Recall": 0.0715372907153729,
|
| 7 |
+
"F1-Score": 0.10129310344827586,
|
| 8 |
+
"OCR_detections": 813,
|
| 9 |
+
"GT_boxes": 1971,
|
| 10 |
+
"Detection_rate": 0.4124809741248097,
|
| 11 |
+
"Det_Precision": 0.5276752767527675,
|
| 12 |
+
"Det_Recall": 0.2176560121765601,
|
| 13 |
+
"Det_F1": 0.3081896551724138,
|
| 14 |
+
"Det_matched": 429,
|
| 15 |
+
"Text_matched": 141,
|
| 16 |
+
"eval_mode": "end2end"
|
| 17 |
+
}
|
diffusion-dpo-ocr/results/roadtext_eval_results_SwinIR.json
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"TP": 66,
|
| 3 |
+
"FP": 182,
|
| 4 |
+
"FN": 1905,
|
| 5 |
+
"Precision": 0.2661290322580645,
|
| 6 |
+
"Recall": 0.0334855403348554,
|
| 7 |
+
"F1-Score": 0.05948625506985128,
|
| 8 |
+
"OCR_detections": 248,
|
| 9 |
+
"GT_boxes": 1971,
|
| 10 |
+
"Detection_rate": 0.12582445459157787,
|
| 11 |
+
"Det_Precision": 0.5806451612903226,
|
| 12 |
+
"Det_Recall": 0.0730593607305936,
|
| 13 |
+
"Det_F1": 0.12978819287967552,
|
| 14 |
+
"Det_matched": 144,
|
| 15 |
+
"Text_matched": 66,
|
| 16 |
+
"eval_mode": "end2end"
|
| 17 |
+
}
|
diffusion-dpo-ocr/results/roadtext_eval_results_gt.json
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"TP": 953,
|
| 3 |
+
"FP": 576,
|
| 4 |
+
"FN": 1018,
|
| 5 |
+
"Precision": 0.6232831916285154,
|
| 6 |
+
"Recall": 0.4835109081684424,
|
| 7 |
+
"F1-Score": 0.5445714285714285,
|
| 8 |
+
"OCR_detections": 1529,
|
| 9 |
+
"GT_boxes": 1971,
|
| 10 |
+
"Detection_rate": 0.7757483510908169,
|
| 11 |
+
"Det_Precision": 0.6487900588620014,
|
| 12 |
+
"Det_Recall": 0.5032978183663115,
|
| 13 |
+
"Det_F1": 0.5668571428571428,
|
| 14 |
+
"Det_matched": 992,
|
| 15 |
+
"Text_matched": 953,
|
| 16 |
+
"eval_mode": "end2end"
|
| 17 |
+
}
|
diffusion-dpo-ocr/results/roadtext_eval_results_sample00.json
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"TP": 0,
|
| 3 |
+
"FP": 0,
|
| 4 |
+
"FN": 0,
|
| 5 |
+
"Precision": 0,
|
| 6 |
+
"Recall": 0,
|
| 7 |
+
"F1-Score": 0,
|
| 8 |
+
"OCR_detections": 0,
|
| 9 |
+
"GT_boxes": 0,
|
| 10 |
+
"Detection_rate": 0,
|
| 11 |
+
"Det_Precision": 0,
|
| 12 |
+
"Det_Recall": 0,
|
| 13 |
+
"Det_F1": 0,
|
| 14 |
+
"Det_matched": 0,
|
| 15 |
+
"Text_matched": 0,
|
| 16 |
+
"eval_mode": "end2end"
|
| 17 |
+
}
|
diffusion-dpo-ocr/results/roadtext_eval_results_zoomlr.json
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"TP": 8,
|
| 3 |
+
"FP": 106,
|
| 4 |
+
"FN": 1963,
|
| 5 |
+
"Precision": 0.07017543859649122,
|
| 6 |
+
"Recall": 0.004058853373921867,
|
| 7 |
+
"F1-Score": 0.007673860911270983,
|
| 8 |
+
"OCR_detections": 114,
|
| 9 |
+
"GT_boxes": 1971,
|
| 10 |
+
"Detection_rate": 0.0578386605783866,
|
| 11 |
+
"Det_Precision": 0.07017543859649122,
|
| 12 |
+
"Det_Recall": 0.004058853373921867,
|
| 13 |
+
"Det_F1": 0.007673860911270983,
|
| 14 |
+
"Det_matched": 8,
|
| 15 |
+
"Text_matched": 8,
|
| 16 |
+
"eval_mode": "end2end"
|
| 17 |
+
}
|
diffusion-dpo-ocr/roadtext_eval_results_output.json
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"TP": 167,
|
| 3 |
+
"FP": 707,
|
| 4 |
+
"FN": 1804,
|
| 5 |
+
"Precision": 0.19107551487414187,
|
| 6 |
+
"Recall": 0.08472856418061897,
|
| 7 |
+
"F1-Score": 0.11739894551845341,
|
| 8 |
+
"OCR_detections": 874,
|
| 9 |
+
"GT_boxes": 1971,
|
| 10 |
+
"Detection_rate": 0.44342973110096395,
|
| 11 |
+
"Det_Precision": 0.540045766590389,
|
| 12 |
+
"Det_Recall": 0.23947234906139014,
|
| 13 |
+
"Det_F1": 0.3318101933216168,
|
| 14 |
+
"Det_matched": 472,
|
| 15 |
+
"Text_matched": 167,
|
| 16 |
+
"eval_mode": "end2end"
|
| 17 |
+
}
|
diffusion-dpo-ocr/test_roadtext.py
ADDED
|
@@ -0,0 +1,514 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
RoadText1K OCR 评估脚本
|
| 3 |
+
评估超分图片在 OCR 任务上的 Precision 和 Recall
|
| 4 |
+
|
| 5 |
+
Metrics:
|
| 6 |
+
- Precision: TP / (TP + FP)
|
| 7 |
+
- Recall: TP / (TP + FN)
|
| 8 |
+
- F1-Score: 2 * Precision * Recall / (Precision + Recall)
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import os
|
| 12 |
+
import json
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
from typing import List, Dict, Tuple
|
| 15 |
+
import difflib
|
| 16 |
+
|
| 17 |
+
import numpy as np
|
| 18 |
+
from PIL import Image
|
| 19 |
+
from tqdm import tqdm
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
# ============================================================================
|
| 23 |
+
# 配置参数 - 请修改这里
|
| 24 |
+
# ============================================================================
|
| 25 |
+
CONFIG = {
|
| 26 |
+
# SR images 目录
|
| 27 |
+
'sr_dir': '/home/wanghongbo06/baipurui/results/RoadText/DreamClear/results/output',
|
| 28 |
+
# 'sr_dir': '/home/wanghongbo06/baipurui/DATA/RoadText1k_patch_crop/SR-Eval/zoomlr',
|
| 29 |
+
|
| 30 |
+
# 标注文件 (prepare_roadtext.py 生成的)
|
| 31 |
+
'annotation_file': '/home/wanghongbo06/baipurui/DATA/RoadText1k_patch_crop/annotations.json',
|
| 32 |
+
|
| 33 |
+
# OCR 引擎选择: 'paddleocr', 'easyocr', 或 'tesseract'
|
| 34 |
+
'ocr_engine': 'paddleocr',
|
| 35 |
+
|
| 36 |
+
# 匹配参数
|
| 37 |
+
'iou_threshold': 0.3, # 检测框 IoU 阈值
|
| 38 |
+
'text_similarity_threshold': 0.3, # 文本相似度阈值 (0 = 只看检测,不看识别)
|
| 39 |
+
|
| 40 |
+
# 评估模式:
|
| 41 |
+
# 'detection_only' - 只评估检测 (忽略文本内容)
|
| 42 |
+
# 'end2end' - 端到端评估 (检测 + 识别)
|
| 43 |
+
# 'recognition_only' - 只评估识别 (在GT框上裁剪后识别)
|
| 44 |
+
'eval_mode': 'end2end',
|
| 45 |
+
|
| 46 |
+
# OCR 配置
|
| 47 |
+
'device': 'gpu', # 'gpu' 或 'cpu'
|
| 48 |
+
|
| 49 |
+
# 调试选项
|
| 50 |
+
'debug_visualize': True, # 是否可视化前几张图的检测框
|
| 51 |
+
'debug_save_dir': './debug_vis', # 可视化保存目录
|
| 52 |
+
|
| 53 |
+
# 输出
|
| 54 |
+
'output': './roadtext_eval_results.json',
|
| 55 |
+
}
|
| 56 |
+
# ============================================================================
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def load_ocr_model(engine='paddleocr', device='gpu'):
|
| 60 |
+
"""加载 OCR 模型"""
|
| 61 |
+
if engine == 'paddleocr':
|
| 62 |
+
try:
|
| 63 |
+
from paddleocr import PaddleOCR
|
| 64 |
+
print("Loading PaddleOCR...")
|
| 65 |
+
ocr = PaddleOCR(
|
| 66 |
+
lang='en',
|
| 67 |
+
device=device,
|
| 68 |
+
)
|
| 69 |
+
return ocr, 'paddleocr'
|
| 70 |
+
except ImportError:
|
| 71 |
+
print("PaddleOCR not found. Install: pip install paddleocr")
|
| 72 |
+
raise
|
| 73 |
+
|
| 74 |
+
elif engine == 'easyocr':
|
| 75 |
+
try:
|
| 76 |
+
import easyocr
|
| 77 |
+
print("Loading EasyOCR...")
|
| 78 |
+
use_gpu = (device == 'gpu')
|
| 79 |
+
reader = easyocr.Reader(['en'], gpu=use_gpu)
|
| 80 |
+
return reader, 'easyocr'
|
| 81 |
+
except ImportError:
|
| 82 |
+
print("EasyOCR not found. Install: pip install easyocr")
|
| 83 |
+
raise
|
| 84 |
+
|
| 85 |
+
else:
|
| 86 |
+
raise ValueError(f"Unsupported OCR engine: {engine}")
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def run_ocr(image_path: str, ocr_model, engine: str) -> List[Dict]:
|
| 90 |
+
"""
|
| 91 |
+
运行 OCR
|
| 92 |
+
|
| 93 |
+
Returns:
|
| 94 |
+
List of dicts with keys: 'polygon', 'text', 'confidence'
|
| 95 |
+
"""
|
| 96 |
+
results = []
|
| 97 |
+
|
| 98 |
+
if engine == 'paddleocr':
|
| 99 |
+
ocr_result = ocr_model.ocr(str(image_path))
|
| 100 |
+
|
| 101 |
+
if ocr_result and ocr_result[0]:
|
| 102 |
+
for line in ocr_result[0]:
|
| 103 |
+
polygon = np.array(line[0]).flatten().tolist() # [[x1,y1],[x2,y2],...] -> [x1,y1,x2,y2,...]
|
| 104 |
+
text = line[1][0]
|
| 105 |
+
confidence = line[1][1]
|
| 106 |
+
|
| 107 |
+
results.append({
|
| 108 |
+
'polygon': polygon,
|
| 109 |
+
'text': text,
|
| 110 |
+
'confidence': confidence,
|
| 111 |
+
})
|
| 112 |
+
|
| 113 |
+
elif engine == 'easyocr':
|
| 114 |
+
ocr_result = ocr_model.readtext(str(image_path))
|
| 115 |
+
|
| 116 |
+
for detection in ocr_result:
|
| 117 |
+
polygon = np.array(detection[0]).flatten().tolist()
|
| 118 |
+
text = detection[1]
|
| 119 |
+
confidence = detection[2]
|
| 120 |
+
|
| 121 |
+
results.append({
|
| 122 |
+
'polygon': polygon,
|
| 123 |
+
'text': text,
|
| 124 |
+
'confidence': confidence,
|
| 125 |
+
})
|
| 126 |
+
|
| 127 |
+
return results
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def polygon_iou(poly1: List[float], poly2: List[float]) -> float:
|
| 131 |
+
"""
|
| 132 |
+
计算两个多边形的 IoU
|
| 133 |
+
poly: [x1,y1,x2,y2,x3,y3,x4,y4]
|
| 134 |
+
"""
|
| 135 |
+
try:
|
| 136 |
+
from shapely.geometry import Polygon
|
| 137 |
+
|
| 138 |
+
# 转换为 Polygon 对象
|
| 139 |
+
p1 = Polygon([(poly1[i], poly1[i+1]) for i in range(0, len(poly1), 2)])
|
| 140 |
+
p2 = Polygon([(poly2[i], poly2[i+1]) for i in range(0, len(poly2), 2)])
|
| 141 |
+
|
| 142 |
+
if not p1.is_valid or not p2.is_valid:
|
| 143 |
+
return 0.0
|
| 144 |
+
|
| 145 |
+
# 计算 IoU
|
| 146 |
+
intersection = p1.intersection(p2).area
|
| 147 |
+
union = p1.union(p2).area
|
| 148 |
+
|
| 149 |
+
if union == 0:
|
| 150 |
+
return 0.0
|
| 151 |
+
|
| 152 |
+
return intersection / union
|
| 153 |
+
|
| 154 |
+
except ImportError:
|
| 155 |
+
# 使用 bbox IoU 近似
|
| 156 |
+
return bbox_iou_from_polygon(poly1, poly2)
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def bbox_iou_from_polygon(poly1: List[float], poly2: List[float]) -> float:
|
| 160 |
+
"""使用 bbox 近似计算 IoU"""
|
| 161 |
+
# 获取 bbox
|
| 162 |
+
x1_min = min(poly1[0::2])
|
| 163 |
+
y1_min = min(poly1[1::2])
|
| 164 |
+
x1_max = max(poly1[0::2])
|
| 165 |
+
y1_max = max(poly1[1::2])
|
| 166 |
+
|
| 167 |
+
x2_min = min(poly2[0::2])
|
| 168 |
+
y2_min = min(poly2[1::2])
|
| 169 |
+
x2_max = max(poly2[0::2])
|
| 170 |
+
y2_max = max(poly2[1::2])
|
| 171 |
+
|
| 172 |
+
# 计算交集
|
| 173 |
+
inter_xmin = max(x1_min, x2_min)
|
| 174 |
+
inter_ymin = max(y1_min, y2_min)
|
| 175 |
+
inter_xmax = min(x1_max, x2_max)
|
| 176 |
+
inter_ymax = min(y1_max, y2_max)
|
| 177 |
+
|
| 178 |
+
if inter_xmax <= inter_xmin or inter_ymax <= inter_ymin:
|
| 179 |
+
return 0.0
|
| 180 |
+
|
| 181 |
+
inter_area = (inter_xmax - inter_xmin) * (inter_ymax - inter_ymin)
|
| 182 |
+
area1 = (x1_max - x1_min) * (y1_max - y1_min)
|
| 183 |
+
area2 = (x2_max - x2_min) * (y2_max - y2_min)
|
| 184 |
+
union_area = area1 + area2 - inter_area
|
| 185 |
+
|
| 186 |
+
return inter_area / union_area if union_area > 0 else 0.0
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
def text_similarity(text1: str, text2: str) -> float:
|
| 190 |
+
"""计算文本相似度"""
|
| 191 |
+
text1 = text1.lower().strip()
|
| 192 |
+
text2 = text2.lower().strip()
|
| 193 |
+
|
| 194 |
+
if text1 == text2:
|
| 195 |
+
return 1.0
|
| 196 |
+
|
| 197 |
+
# 使用编辑距离
|
| 198 |
+
return difflib.SequenceMatcher(None, text1, text2).ratio()
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
def match_detections(
|
| 202 |
+
pred_results: List[Dict],
|
| 203 |
+
gt_annotations: Dict,
|
| 204 |
+
iou_thresh: float = 0.5,
|
| 205 |
+
text_sim_thresh: float = 0.5,
|
| 206 |
+
eval_mode: str = 'end2end',
|
| 207 |
+
) -> Tuple[int, int, int, Dict]:
|
| 208 |
+
"""
|
| 209 |
+
匹配预测和GT
|
| 210 |
+
|
| 211 |
+
Args:
|
| 212 |
+
eval_mode: 'detection_only', 'end2end', 'recognition_only'
|
| 213 |
+
|
| 214 |
+
Returns:
|
| 215 |
+
(TP, FP, FN, details)
|
| 216 |
+
"""
|
| 217 |
+
gt_polygons = gt_annotations['polygons']
|
| 218 |
+
gt_texts = gt_annotations['texts']
|
| 219 |
+
gt_ignore = gt_annotations.get('ignore', [False] * len(gt_texts))
|
| 220 |
+
|
| 221 |
+
matched_gt = set()
|
| 222 |
+
tp = 0
|
| 223 |
+
fp = 0
|
| 224 |
+
|
| 225 |
+
# 详细统计
|
| 226 |
+
details = {
|
| 227 |
+
'det_matched': 0, # 检测匹配数 (IoU 通过)
|
| 228 |
+
'text_matched': 0, # 文本匹配数
|
| 229 |
+
'det_fp': 0, # 检测误检
|
| 230 |
+
}
|
| 231 |
+
|
| 232 |
+
# 遍历预测结果
|
| 233 |
+
for pred in pred_results:
|
| 234 |
+
pred_poly = pred['polygon']
|
| 235 |
+
pred_text = pred['text']
|
| 236 |
+
|
| 237 |
+
best_iou = 0
|
| 238 |
+
best_gt_idx = -1
|
| 239 |
+
|
| 240 |
+
# 找到最佳匹配的 GT
|
| 241 |
+
for gt_idx, (gt_poly, gt_text, ignore) in enumerate(zip(gt_polygons, gt_texts, gt_ignore)):
|
| 242 |
+
if ignore or gt_idx in matched_gt:
|
| 243 |
+
continue
|
| 244 |
+
|
| 245 |
+
iou = polygon_iou(pred_poly, gt_poly)
|
| 246 |
+
|
| 247 |
+
if iou > best_iou:
|
| 248 |
+
best_iou = iou
|
| 249 |
+
best_gt_idx = gt_idx
|
| 250 |
+
|
| 251 |
+
# 判断是否匹配成功
|
| 252 |
+
if best_iou >= iou_thresh and best_gt_idx >= 0:
|
| 253 |
+
details['det_matched'] += 1
|
| 254 |
+
|
| 255 |
+
if eval_mode == 'detection_only':
|
| 256 |
+
# 只看检测,不看文本
|
| 257 |
+
tp += 1
|
| 258 |
+
matched_gt.add(best_gt_idx)
|
| 259 |
+
else:
|
| 260 |
+
# 检查文本相似度
|
| 261 |
+
gt_text = gt_texts[best_gt_idx]
|
| 262 |
+
sim = text_similarity(pred_text, gt_text)
|
| 263 |
+
|
| 264 |
+
if sim >= text_sim_thresh:
|
| 265 |
+
tp += 1
|
| 266 |
+
matched_gt.add(best_gt_idx)
|
| 267 |
+
details['text_matched'] += 1
|
| 268 |
+
else:
|
| 269 |
+
# 根据评估模式决定是否计为 FP
|
| 270 |
+
if eval_mode == 'end2end':
|
| 271 |
+
fp += 1 # 位置对但文字错,计为 FP
|
| 272 |
+
else:
|
| 273 |
+
fp += 1 # 误检
|
| 274 |
+
details['det_fp'] += 1
|
| 275 |
+
|
| 276 |
+
# 计算 FN (未检测到的GT)
|
| 277 |
+
fn = 0
|
| 278 |
+
for gt_idx, ignore in enumerate(gt_ignore):
|
| 279 |
+
if not ignore and gt_idx not in matched_gt:
|
| 280 |
+
fn += 1
|
| 281 |
+
|
| 282 |
+
return tp, fp, fn, details
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
def visualize_detections(
|
| 286 |
+
img_path: Path,
|
| 287 |
+
gt_ann: Dict,
|
| 288 |
+
pred_results: List[Dict],
|
| 289 |
+
save_path: Path,
|
| 290 |
+
):
|
| 291 |
+
"""可视化 GT 和 OCR 检测框对比"""
|
| 292 |
+
import cv2
|
| 293 |
+
|
| 294 |
+
img = cv2.imread(str(img_path))
|
| 295 |
+
if img is None:
|
| 296 |
+
return
|
| 297 |
+
|
| 298 |
+
# 绘制 GT 框 (绿色)
|
| 299 |
+
for i, (poly, text, ignore) in enumerate(zip(
|
| 300 |
+
gt_ann['polygons'], gt_ann['texts'], gt_ann.get('ignore', [False] * len(gt_ann['texts']))
|
| 301 |
+
)):
|
| 302 |
+
if ignore:
|
| 303 |
+
continue
|
| 304 |
+
pts = np.array(poly).reshape(-1, 2).astype(np.int32)
|
| 305 |
+
cv2.polylines(img, [pts], True, (0, 255, 0), 2)
|
| 306 |
+
# 标注 GT 文本
|
| 307 |
+
x, y = int(pts[0][0]), int(pts[0][1]) - 5
|
| 308 |
+
cv2.putText(img, f"GT:{text[:10]}", (x, y), cv2.FONT_HERSHEY_SIMPLEX, 0.4, (0, 255, 0), 1)
|
| 309 |
+
|
| 310 |
+
# 绘制 OCR 检测框 (红色)
|
| 311 |
+
for pred in pred_results:
|
| 312 |
+
poly = pred['polygon']
|
| 313 |
+
text = pred['text']
|
| 314 |
+
pts = np.array(poly).reshape(-1, 2).astype(np.int32)
|
| 315 |
+
cv2.polylines(img, [pts], True, (0, 0, 255), 2)
|
| 316 |
+
# 标注 OCR 文本
|
| 317 |
+
x, y = int(pts[0][0]), int(pts[0][1]) + 15
|
| 318 |
+
cv2.putText(img, f"OCR:{text[:10]}", (x, y), cv2.FONT_HERSHEY_SIMPLEX, 0.4, (0, 0, 255), 1)
|
| 319 |
+
|
| 320 |
+
cv2.imwrite(str(save_path), img)
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
def evaluate_dataset(
|
| 324 |
+
sr_dir: str,
|
| 325 |
+
annotation_file: str,
|
| 326 |
+
ocr_model,
|
| 327 |
+
engine: str,
|
| 328 |
+
debug: bool = False,
|
| 329 |
+
) -> Dict:
|
| 330 |
+
"""评估整个数据集"""
|
| 331 |
+
sr_dir = Path(sr_dir)
|
| 332 |
+
eval_mode = CONFIG.get('eval_mode', 'end2end')
|
| 333 |
+
debug_visualize = CONFIG.get('debug_visualize', False)
|
| 334 |
+
debug_save_dir = Path(CONFIG.get('debug_save_dir', './debug_vis'))
|
| 335 |
+
|
| 336 |
+
if debug_visualize:
|
| 337 |
+
debug_save_dir.mkdir(parents=True, exist_ok=True)
|
| 338 |
+
|
| 339 |
+
# 加载标注
|
| 340 |
+
with open(annotation_file, 'r', encoding='utf-8') as f:
|
| 341 |
+
annotations = json.load(f)
|
| 342 |
+
|
| 343 |
+
# 调试:统计标注信息
|
| 344 |
+
total_gt_boxes = 0
|
| 345 |
+
total_ignored = 0
|
| 346 |
+
for img_name, ann in annotations.items():
|
| 347 |
+
total_gt_boxes += len(ann['texts'])
|
| 348 |
+
total_ignored += sum(ann.get('ignore', [False] * len(ann['texts'])))
|
| 349 |
+
print(f"\n标注统计: 总共 {len(annotations)} 张图片, {total_gt_boxes} 个文本框, {total_ignored} 个被忽略")
|
| 350 |
+
print(f"评估模式: {eval_mode}")
|
| 351 |
+
|
| 352 |
+
total_tp = 0
|
| 353 |
+
total_fp = 0
|
| 354 |
+
total_fn = 0
|
| 355 |
+
|
| 356 |
+
# 汇总详细统计
|
| 357 |
+
total_details = {
|
| 358 |
+
'det_matched': 0,
|
| 359 |
+
'text_matched': 0,
|
| 360 |
+
'det_fp': 0,
|
| 361 |
+
}
|
| 362 |
+
|
| 363 |
+
print("Running OCR evaluation...")
|
| 364 |
+
total_ocr_detections = 0
|
| 365 |
+
debug_count = 0
|
| 366 |
+
vis_count = 0
|
| 367 |
+
|
| 368 |
+
for img_name, gt_ann in tqdm(annotations.items()):
|
| 369 |
+
img_path = sr_dir / img_name
|
| 370 |
+
|
| 371 |
+
if not img_path.exists():
|
| 372 |
+
# 尝试其他扩展名
|
| 373 |
+
img_path = sr_dir / (Path(img_name).stem + '.jpg')
|
| 374 |
+
if not img_path.exists():
|
| 375 |
+
continue
|
| 376 |
+
|
| 377 |
+
# 运行 OCR
|
| 378 |
+
pred_results = run_ocr(img_path, ocr_model, engine)
|
| 379 |
+
total_ocr_detections += len(pred_results)
|
| 380 |
+
|
| 381 |
+
# 调试:打印前3张图片的详细信息
|
| 382 |
+
if debug_count < 3:
|
| 383 |
+
print(f"\n[DEBUG] Image: {img_name}")
|
| 384 |
+
print(f" GT boxes: {len(gt_ann['polygons'])}, ignored: {sum(gt_ann.get('ignore', []))}")
|
| 385 |
+
print(f" OCR detections: {len(pred_results)}")
|
| 386 |
+
if gt_ann['polygons']:
|
| 387 |
+
print(f" GT[0] polygon: {gt_ann['polygons'][0][:4]}... text: '{gt_ann['texts'][0]}'")
|
| 388 |
+
if pred_results:
|
| 389 |
+
print(f" OCR[0] polygon: {pred_results[0]['polygon'][:4]}... text: '{pred_results[0]['text']}'")
|
| 390 |
+
debug_count += 1
|
| 391 |
+
|
| 392 |
+
# 可视化
|
| 393 |
+
if debug_visualize and vis_count < 20:
|
| 394 |
+
vis_path = debug_save_dir / f"vis_{img_name}"
|
| 395 |
+
visualize_detections(img_path, gt_ann, pred_results, vis_path)
|
| 396 |
+
vis_count += 1
|
| 397 |
+
|
| 398 |
+
# 匹配
|
| 399 |
+
tp, fp, fn, details = match_detections(
|
| 400 |
+
pred_results,
|
| 401 |
+
gt_ann,
|
| 402 |
+
iou_thresh=CONFIG['iou_threshold'],
|
| 403 |
+
text_sim_thresh=CONFIG['text_similarity_threshold'],
|
| 404 |
+
eval_mode=eval_mode,
|
| 405 |
+
)
|
| 406 |
+
|
| 407 |
+
total_tp += tp
|
| 408 |
+
total_fp += fp
|
| 409 |
+
total_fn += fn
|
| 410 |
+
|
| 411 |
+
for k, v in details.items():
|
| 412 |
+
total_details[k] += v
|
| 413 |
+
|
| 414 |
+
print(f"\nOCR 总共检测到 {total_ocr_detections} 个文本框")
|
| 415 |
+
|
| 416 |
+
# 计算指标
|
| 417 |
+
precision = total_tp / (total_tp + total_fp) if (total_tp + total_fp) > 0 else 0
|
| 418 |
+
recall = total_tp / (total_tp + total_fn) if (total_tp + total_fn) > 0 else 0
|
| 419 |
+
f1_score = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
|
| 420 |
+
|
| 421 |
+
# 统计 GT 信息 (不含 ignore)
|
| 422 |
+
total_gt_boxes_valid = total_tp + total_fn
|
| 423 |
+
|
| 424 |
+
# 计算纯检测指标 (不考虑文本内容)
|
| 425 |
+
det_precision = total_details['det_matched'] / total_ocr_detections if total_ocr_detections > 0 else 0
|
| 426 |
+
det_recall = total_details['det_matched'] / total_gt_boxes_valid if total_gt_boxes_valid > 0 else 0
|
| 427 |
+
det_f1 = 2 * det_precision * det_recall / (det_precision + det_recall) if (det_precision + det_recall) > 0 else 0
|
| 428 |
+
|
| 429 |
+
return {
|
| 430 |
+
'TP': total_tp,
|
| 431 |
+
'FP': total_fp,
|
| 432 |
+
'FN': total_fn,
|
| 433 |
+
'Precision': precision,
|
| 434 |
+
'Recall': recall,
|
| 435 |
+
'F1-Score': f1_score,
|
| 436 |
+
'OCR_detections': total_ocr_detections,
|
| 437 |
+
'GT_boxes': total_gt_boxes_valid,
|
| 438 |
+
'Detection_rate': total_ocr_detections / total_gt_boxes_valid if total_gt_boxes_valid > 0 else 0,
|
| 439 |
+
# 新增:纯检测指标
|
| 440 |
+
'Det_Precision': det_precision,
|
| 441 |
+
'Det_Recall': det_recall,
|
| 442 |
+
'Det_F1': det_f1,
|
| 443 |
+
'Det_matched': total_details['det_matched'],
|
| 444 |
+
'Text_matched': total_details['text_matched'],
|
| 445 |
+
'eval_mode': eval_mode,
|
| 446 |
+
}
|
| 447 |
+
|
| 448 |
+
|
| 449 |
+
def main():
|
| 450 |
+
# 自动生成 output 路径:根据 sr_dir 最后一个目录名
|
| 451 |
+
sr_dir = Path(CONFIG['sr_dir'])
|
| 452 |
+
baseline_name = sr_dir.name # 获取最后一个目录名,如 'sr', 'gt', 'bicubic' 等
|
| 453 |
+
output_path = Path(f"./roadtext_eval_results_{baseline_name}.json")
|
| 454 |
+
|
| 455 |
+
print("="*60)
|
| 456 |
+
print("RoadText1K OCR Evaluation")
|
| 457 |
+
print("="*60)
|
| 458 |
+
print(f"SR images: {CONFIG['sr_dir']}")
|
| 459 |
+
print(f"Annotations: {CONFIG['annotation_file']}")
|
| 460 |
+
print(f"OCR engine: {CONFIG['ocr_engine']}")
|
| 461 |
+
print(f"IoU threshold: {CONFIG['iou_threshold']}")
|
| 462 |
+
print(f"Text similarity threshold: {CONFIG['text_similarity_threshold']}")
|
| 463 |
+
print(f"Output will be saved to: {output_path}")
|
| 464 |
+
print()
|
| 465 |
+
|
| 466 |
+
# 加载 OCR 模型
|
| 467 |
+
ocr_model, engine = load_ocr_model(
|
| 468 |
+
CONFIG['ocr_engine'],
|
| 469 |
+
device=CONFIG['device']
|
| 470 |
+
)
|
| 471 |
+
|
| 472 |
+
# 评估
|
| 473 |
+
results = evaluate_dataset(
|
| 474 |
+
CONFIG['sr_dir'],
|
| 475 |
+
CONFIG['annotation_file'],
|
| 476 |
+
ocr_model,
|
| 477 |
+
engine,
|
| 478 |
+
)
|
| 479 |
+
|
| 480 |
+
# 打印结果
|
| 481 |
+
print("\n" + "="*60)
|
| 482 |
+
print("EVALUATION RESULTS")
|
| 483 |
+
print("="*60)
|
| 484 |
+
print(f"评估模式: {results.get('eval_mode', 'end2end')}")
|
| 485 |
+
|
| 486 |
+
print("\n[Detection Statistics]")
|
| 487 |
+
print(f" GT text boxes (valid): {results['GT_boxes']}")
|
| 488 |
+
print(f" OCR detections: {results['OCR_detections']}")
|
| 489 |
+
print(f" Detection matched (IoU): {results.get('Det_matched', 'N/A')}")
|
| 490 |
+
print(f" Text matched: {results.get('Text_matched', 'N/A')}")
|
| 491 |
+
|
| 492 |
+
print("\n[Detection-Only Metrics] (只看框位置,不看文字)")
|
| 493 |
+
print(f" Det Precision: {results.get('Det_Precision', 0)*100:.2f}%")
|
| 494 |
+
print(f" Det Recall: {results.get('Det_Recall', 0)*100:.2f}%")
|
| 495 |
+
print(f" Det F1-Score: {results.get('Det_F1', 0)*100:.2f}%")
|
| 496 |
+
|
| 497 |
+
print("\n[End-to-End Metrics] (检测 + 识别)")
|
| 498 |
+
print(f" True Positives: {results['TP']}")
|
| 499 |
+
print(f" False Positives: {results['FP']}")
|
| 500 |
+
print(f" False Negatives: {results['FN']}")
|
| 501 |
+
print(f" Precision: {results['Precision']*100:.2f}%")
|
| 502 |
+
print(f" Recall: {results['Recall']*100:.2f}%")
|
| 503 |
+
print(f" F1-Score: {results['F1-Score']*100:.2f}%")
|
| 504 |
+
|
| 505 |
+
# 保存结果
|
| 506 |
+
output_path.parent.mkdir(parents=True, exist_ok=True)
|
| 507 |
+
with open(output_path, 'w') as f:
|
| 508 |
+
json.dump(results, f, indent=2)
|
| 509 |
+
print(f"\nResults saved to {output_path}")
|
| 510 |
+
|
| 511 |
+
|
| 512 |
+
if __name__ == '__main__':
|
| 513 |
+
main()
|
| 514 |
+
|
diffusion-dpo-ocr/verify_roadtext_annotations.py
ADDED
|
@@ -0,0 +1,223 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
验证 RoadText1K 预处理结果的脚本
|
| 3 |
+
可视化 GT 标注框与 crop 后图像的对应关系,检查坐标是否正确
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import json
|
| 8 |
+
import random
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from typing import Dict, List
|
| 11 |
+
|
| 12 |
+
import cv2
|
| 13 |
+
import numpy as np
|
| 14 |
+
from PIL import Image
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
# ============================================================================
|
| 18 |
+
# 配置
|
| 19 |
+
# ============================================================================
|
| 20 |
+
CONFIG = {
|
| 21 |
+
# GT 图像目录
|
| 22 |
+
'gt_dir': '/home/wanghongbo06/baipurui/DATA/RoadText1k_patch_crop/gt',
|
| 23 |
+
|
| 24 |
+
# 标注文件路径
|
| 25 |
+
'annotation_file': '/home/wanghongbo06/baipurui/DATA/RoadText1k_patch_crop/annotations.json',
|
| 26 |
+
|
| 27 |
+
# 可视化输出目录
|
| 28 |
+
'vis_output_dir': './verify_roadtext_vis',
|
| 29 |
+
|
| 30 |
+
# 可视化图片数量
|
| 31 |
+
'num_samples': 20,
|
| 32 |
+
|
| 33 |
+
# 随机种子
|
| 34 |
+
'seed': 42,
|
| 35 |
+
}
|
| 36 |
+
# ============================================================================
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def draw_annotations(img: np.ndarray, ann: Dict, color=(0, 255, 0), thickness=2) -> np.ndarray:
|
| 40 |
+
"""在图像上绘制标注框"""
|
| 41 |
+
img_vis = img.copy()
|
| 42 |
+
|
| 43 |
+
polygons = ann.get('polygons', [])
|
| 44 |
+
texts = ann.get('texts', [])
|
| 45 |
+
ignores = ann.get('ignore', [False] * len(texts))
|
| 46 |
+
|
| 47 |
+
for i, (poly, text, ignore) in enumerate(zip(polygons, texts, ignores)):
|
| 48 |
+
if ignore:
|
| 49 |
+
box_color = (128, 128, 128) # 灰色 = ignore
|
| 50 |
+
else:
|
| 51 |
+
box_color = color
|
| 52 |
+
|
| 53 |
+
# 绘制多边形
|
| 54 |
+
pts = np.array(poly).reshape(-1, 2).astype(np.int32)
|
| 55 |
+
cv2.polylines(img_vis, [pts], True, box_color, thickness)
|
| 56 |
+
|
| 57 |
+
# 标注文本
|
| 58 |
+
x, y = int(pts[0][0]), int(pts[0][1]) - 5
|
| 59 |
+
if y < 15:
|
| 60 |
+
y = int(pts[0][1]) + 15
|
| 61 |
+
|
| 62 |
+
# 缩短文本显示
|
| 63 |
+
display_text = text[:15] + "..." if len(text) > 15 else text
|
| 64 |
+
cv2.putText(img_vis, display_text, (x, y),
|
| 65 |
+
cv2.FONT_HERSHEY_SIMPLEX, 0.4, box_color, 1)
|
| 66 |
+
|
| 67 |
+
return img_vis
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def verify_single_image(
|
| 71 |
+
img_path: Path,
|
| 72 |
+
ann: Dict,
|
| 73 |
+
save_path: Path,
|
| 74 |
+
) -> Dict:
|
| 75 |
+
"""验证单张图片的标注"""
|
| 76 |
+
img = cv2.imread(str(img_path))
|
| 77 |
+
if img is None:
|
| 78 |
+
return {'error': f'Cannot read image: {img_path}'}
|
| 79 |
+
|
| 80 |
+
h, w = img.shape[:2]
|
| 81 |
+
|
| 82 |
+
# 统计信息
|
| 83 |
+
stats = {
|
| 84 |
+
'img_size': (w, h),
|
| 85 |
+
'num_boxes': len(ann.get('polygons', [])),
|
| 86 |
+
'num_ignored': sum(ann.get('ignore', [])),
|
| 87 |
+
'boxes_in_bounds': 0,
|
| 88 |
+
'boxes_out_of_bounds': 0,
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
# 检查标注框是否在图像范围内
|
| 92 |
+
for poly in ann.get('polygons', []):
|
| 93 |
+
xs = [poly[i] for i in range(0, len(poly), 2)]
|
| 94 |
+
ys = [poly[i] for i in range(1, len(poly), 2)]
|
| 95 |
+
|
| 96 |
+
if min(xs) >= 0 and max(xs) <= w and min(ys) >= 0 and max(ys) <= h:
|
| 97 |
+
stats['boxes_in_bounds'] += 1
|
| 98 |
+
else:
|
| 99 |
+
stats['boxes_out_of_bounds'] += 1
|
| 100 |
+
print(f" [WARNING] Box out of bounds: x=[{min(xs):.1f}, {max(xs):.1f}], y=[{min(ys):.1f}, {max(ys):.1f}], img={w}x{h}")
|
| 101 |
+
|
| 102 |
+
# 绘制标注
|
| 103 |
+
img_vis = draw_annotations(img, ann)
|
| 104 |
+
|
| 105 |
+
# 保存
|
| 106 |
+
cv2.imwrite(str(save_path), img_vis)
|
| 107 |
+
|
| 108 |
+
return stats
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def main():
|
| 112 |
+
random.seed(CONFIG['seed'])
|
| 113 |
+
|
| 114 |
+
gt_dir = Path(CONFIG['gt_dir'])
|
| 115 |
+
ann_file = Path(CONFIG['annotation_file'])
|
| 116 |
+
vis_dir = Path(CONFIG['vis_output_dir'])
|
| 117 |
+
vis_dir.mkdir(parents=True, exist_ok=True)
|
| 118 |
+
|
| 119 |
+
print("=" * 60)
|
| 120 |
+
print("RoadText1K 标注验证脚本")
|
| 121 |
+
print("=" * 60)
|
| 122 |
+
print(f"GT dir: {gt_dir}")
|
| 123 |
+
print(f"Annotations: {ann_file}")
|
| 124 |
+
print(f"Output: {vis_dir}")
|
| 125 |
+
print()
|
| 126 |
+
|
| 127 |
+
# 检查文件是否存在
|
| 128 |
+
if not ann_file.exists():
|
| 129 |
+
print(f"Error: Annotation file not found: {ann_file}")
|
| 130 |
+
return
|
| 131 |
+
|
| 132 |
+
if not gt_dir.exists():
|
| 133 |
+
print(f"Error: GT directory not found: {gt_dir}")
|
| 134 |
+
return
|
| 135 |
+
|
| 136 |
+
# 加载标注
|
| 137 |
+
with open(ann_file, 'r', encoding='utf-8') as f:
|
| 138 |
+
annotations = json.load(f)
|
| 139 |
+
|
| 140 |
+
print(f"总共 {len(annotations)} 张图片有标注")
|
| 141 |
+
|
| 142 |
+
# 全局统计
|
| 143 |
+
total_boxes = 0
|
| 144 |
+
total_ignored = 0
|
| 145 |
+
total_in_bounds = 0
|
| 146 |
+
total_out_of_bounds = 0
|
| 147 |
+
|
| 148 |
+
# 计算所有图片的统计
|
| 149 |
+
for img_name, ann in annotations.items():
|
| 150 |
+
total_boxes += len(ann.get('polygons', []))
|
| 151 |
+
total_ignored += sum(ann.get('ignore', []))
|
| 152 |
+
|
| 153 |
+
# 检查边界
|
| 154 |
+
for poly in ann.get('polygons', []):
|
| 155 |
+
xs = [poly[i] for i in range(0, len(poly), 2)]
|
| 156 |
+
ys = [poly[i] for i in range(1, len(poly), 2)]
|
| 157 |
+
if min(xs) >= 0 and max(xs) <= 512 and min(ys) >= 0 and max(ys) <= 512:
|
| 158 |
+
total_in_bounds += 1
|
| 159 |
+
else:
|
| 160 |
+
total_out_of_bounds += 1
|
| 161 |
+
|
| 162 |
+
print(f"\n全局统计:")
|
| 163 |
+
print(f" 总文本框: {total_boxes}")
|
| 164 |
+
print(f" 忽略框: {total_ignored}")
|
| 165 |
+
print(f" 有效框 (非忽略): {total_boxes - total_ignored}")
|
| 166 |
+
print(f" 框在图像范围内: {total_in_bounds}")
|
| 167 |
+
print(f" 框超出范围: {total_out_of_bounds}")
|
| 168 |
+
|
| 169 |
+
if total_out_of_bounds > 0:
|
| 170 |
+
print(f"\n⚠️ 有 {total_out_of_bounds} 个框超出图像范围!这可能是坐标转换问题。")
|
| 171 |
+
|
| 172 |
+
# 随机选取一些图片进行可视化
|
| 173 |
+
img_names = list(annotations.keys())
|
| 174 |
+
num_samples = min(CONFIG['num_samples'], len(img_names))
|
| 175 |
+
selected = random.sample(img_names, num_samples)
|
| 176 |
+
|
| 177 |
+
print(f"\n随机选取 {num_samples} 张图片进行可视化...")
|
| 178 |
+
|
| 179 |
+
for img_name in selected:
|
| 180 |
+
ann = annotations[img_name]
|
| 181 |
+
img_path = gt_dir / img_name
|
| 182 |
+
|
| 183 |
+
if not img_path.exists():
|
| 184 |
+
# 尝试其他扩展名
|
| 185 |
+
img_path = gt_dir / (Path(img_name).stem + '.jpg')
|
| 186 |
+
|
| 187 |
+
if not img_path.exists():
|
| 188 |
+
print(f" [SKIP] Image not found: {img_name}")
|
| 189 |
+
continue
|
| 190 |
+
|
| 191 |
+
save_path = vis_dir / f"verify_{img_name}"
|
| 192 |
+
stats = verify_single_image(img_path, ann, save_path)
|
| 193 |
+
|
| 194 |
+
if 'error' not in stats:
|
| 195 |
+
print(f" [OK] {img_name}: {stats['num_boxes']} boxes, "
|
| 196 |
+
f"{stats['boxes_out_of_bounds']} out of bounds")
|
| 197 |
+
|
| 198 |
+
print(f"\n可视化结果保存到: {vis_dir}")
|
| 199 |
+
|
| 200 |
+
# 额外检查:打印一些标注样例
|
| 201 |
+
print("\n" + "=" * 60)
|
| 202 |
+
print("标注样例 (前 3 张图):")
|
| 203 |
+
print("=" * 60)
|
| 204 |
+
|
| 205 |
+
for i, (img_name, ann) in enumerate(list(annotations.items())[:3]):
|
| 206 |
+
print(f"\n[{i+1}] {img_name}")
|
| 207 |
+
print(f" crop_position: {ann.get('crop_position', 'N/A')}")
|
| 208 |
+
print(f" num_boxes: {len(ann.get('polygons', []))}")
|
| 209 |
+
|
| 210 |
+
for j, (poly, text, ignore) in enumerate(zip(
|
| 211 |
+
ann.get('polygons', [])[:2], # 只显示前2个
|
| 212 |
+
ann.get('texts', [])[:2],
|
| 213 |
+
ann.get('ignore', [])[:2]
|
| 214 |
+
)):
|
| 215 |
+
xs = [poly[k] for k in range(0, len(poly), 2)]
|
| 216 |
+
ys = [poly[k] for k in range(1, len(poly), 2)]
|
| 217 |
+
print(f" Box {j}: x=[{min(xs):.1f}, {max(xs):.1f}], y=[{min(ys):.1f}, {max(ys):.1f}], "
|
| 218 |
+
f"text='{text[:20]}', ignore={ignore}")
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
if __name__ == '__main__':
|
| 222 |
+
main()
|
| 223 |
+
|
diffusion-dpo-test/DIAGNOSTIC_CHECKLIST.md
ADDED
|
@@ -0,0 +1,297 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 🔍 完整诊断检查清单
|
| 2 |
+
|
| 3 |
+
## 问题描述
|
| 4 |
+
训练了 15 和 60 个 epoch,测试结果**逐像素完全相同**(262144 个 RGB 值一模一样)
|
| 5 |
+
|
| 6 |
+
---
|
| 7 |
+
|
| 8 |
+
## ✅ 已确认正常的部分
|
| 9 |
+
|
| 10 |
+
### 1. 训练动态正常
|
| 11 |
+
- ✅ `acc` 和 `l_acc` 在变化
|
| 12 |
+
- ✅ `loss` 在下降
|
| 13 |
+
- ✅ `Max lora_B Check` 从 `1.04e-05` 增长到 `2.23e-03`(200+ 倍增长)
|
| 14 |
+
|
| 15 |
+
### 2. 代码逻辑正常
|
| 16 |
+
- ✅ `disable_adapter()` 使用了 `with` 语句(已修复)
|
| 17 |
+
- ✅ VAE 编码在 `no_grad()` 中
|
| 18 |
+
- ✅ `ref_pred` 正确 detach
|
| 19 |
+
- ✅ LoRA 键名已清理(去除 `base_model.model.` 前缀)
|
| 20 |
+
- ✅ 优化器在正确的时机创建(dtype 转换之后)
|
| 21 |
+
|
| 22 |
+
### 3. x_embedder 为 0 是正常的
|
| 23 |
+
- ✅ 输入层通常不需要针对 DPO 偏好优化
|
| 24 |
+
- ✅ 其他深层(`single_transformer_blocks.*`)的权重在增长
|
| 25 |
+
- ✅ LoRA 是加法,不是乘法,0 不会阻塞梯度流
|
| 26 |
+
|
| 27 |
+
---
|
| 28 |
+
|
| 29 |
+
## ❓ 需要检查的部分
|
| 30 |
+
|
| 31 |
+
### 🎯 检查 1: Checkpoint 权重是否真的在变化?
|
| 32 |
+
|
| 33 |
+
**这是最关键的检查!**
|
| 34 |
+
|
| 35 |
+
```bash
|
| 36 |
+
# 在训练机器上运行
|
| 37 |
+
cd <训练脚本运行目录> # 即 run.sh 的执行目录
|
| 38 |
+
|
| 39 |
+
# 1. 确认文件存在
|
| 40 |
+
ls -lh results_1202_4/checkpoint-15/lora_train_unet/adapter_model.safetensors
|
| 41 |
+
ls -lh results_1202_4/checkpoint-60/lora_train_unet/adapter_model.safetensors
|
| 42 |
+
|
| 43 |
+
# 2. 对比权重
|
| 44 |
+
python compare_checkpoints.py \
|
| 45 |
+
results_1202_4/checkpoint-15/lora_train_unet/adapter_model.safetensors \
|
| 46 |
+
results_1202_4/checkpoint-60/lora_train_unet/adapter_model.safetensors
|
| 47 |
+
```
|
| 48 |
+
|
| 49 |
+
**预期结果**:
|
| 50 |
+
- ✅ **正常**: "50%+ 的参数在变化",最大差异 > 1e-4
|
| 51 |
+
- ❌ **异常**: "完全相同" 或 "< 10% 参数变化"
|
| 52 |
+
|
| 53 |
+
**如果异常,说明**:
|
| 54 |
+
- 保存逻辑有问题
|
| 55 |
+
- 或者训练根本没生效(但这与 loss 下降矛盾)
|
| 56 |
+
|
| 57 |
+
---
|
| 58 |
+
|
| 59 |
+
### 🎯 检查 2: 测试代码加载的是哪个 checkpoint?
|
| 60 |
+
|
| 61 |
+
**问题**: 测试代码路径和训练输出路径不匹配
|
| 62 |
+
|
| 63 |
+
训练输出:
|
| 64 |
+
```
|
| 65 |
+
results_1202_4/checkpoint-XX/lora_train_unet/adapter_model.safetensors
|
| 66 |
+
```
|
| 67 |
+
|
| 68 |
+
测试代码加载:
|
| 69 |
+
```python
|
| 70 |
+
# diffusion-dpo-test/test.py line 33
|
| 71 |
+
"/home/wanghongbo06/diffusion-dpo-adv/results_1130/checkpoint-50/..."
|
| 72 |
+
```
|
| 73 |
+
|
| 74 |
+
注意 `results_1130` vs `results_1202_4`!
|
| 75 |
+
|
| 76 |
+
**验证方法**:
|
| 77 |
+
```bash
|
| 78 |
+
# 在测试机器上
|
| 79 |
+
# 1. 确认测试代码实际加载的文件
|
| 80 |
+
ls -lh /home/wanghongbo06/diffusion-dpo-adv/results_1130/checkpoint-50/lora_train_unet/adapter_model.safetensors
|
| 81 |
+
ls -lh /home/wanghongbo06/diffusion-dpo-adv/results_1130/checkpoint-500/lora_train_unet/adapter_model.safetensors
|
| 82 |
+
|
| 83 |
+
# 2. 检查文件修改时间(是否是最新训练的)
|
| 84 |
+
stat /home/wanghongbo06/diffusion-dpo-adv/results_1130/checkpoint-50/lora_train_unet/adapter_model.safetensors
|
| 85 |
+
|
| 86 |
+
# 3. 对比这两个文件
|
| 87 |
+
python compare_checkpoints.py \
|
| 88 |
+
/home/wanghongbo06/diffusion-dpo-adv/results_1130/checkpoint-50/lora_train_unet/adapter_model.safetensors \
|
| 89 |
+
/home/wanghongbo06/diffusion-dpo-adv/results_1130/checkpoint-500/lora_train_unet/adapter_model.safetensors
|
| 90 |
+
```
|
| 91 |
+
|
| 92 |
+
**如果这两个文件完全相同**:
|
| 93 |
+
- 说明您测试的不是最新训练的模型
|
| 94 |
+
- 需要更新测试代码的路径,或者将最新的 checkpoint 复制到测试机器
|
| 95 |
+
|
| 96 |
+
---
|
| 97 |
+
|
| 98 |
+
### 🎯 检查 3: 测试代码的 LoRA 融合逻辑
|
| 99 |
+
|
| 100 |
+
测试代码使用了**两次 fuse_lora**:
|
| 101 |
+
|
| 102 |
+
```python
|
| 103 |
+
# 第一次:SR base LoRA
|
| 104 |
+
pipe.load_lora_weights("...pytorch_lora_weights_v2.safetensors", adapter_name="sr")
|
| 105 |
+
pipe.fuse_lora(lora_scale=1.0, adapter_names=["sr"])
|
| 106 |
+
pipe.unload_lora_weights()
|
| 107 |
+
|
| 108 |
+
# 第二次:DPO trained LoRA
|
| 109 |
+
pipe.load_lora_weights("...adapter_model.safetensors", adapter_name="sr2")
|
| 110 |
+
pipe.fuse_lora(lora_scale=1.0, adapter_names=["sr2"])
|
| 111 |
+
pipe.unload_lora_weights()
|
| 112 |
+
```
|
| 113 |
+
|
| 114 |
+
**潜在问题**:
|
| 115 |
+
1. `unload_lora_weights()` 可能清除了刚融合的权重
|
| 116 |
+
2. 第二次 `fuse_lora` 可能没有正确叠加到第一次的结果上
|
| 117 |
+
3. 如果两个 LoRA 的权重完全相同,输出自然也相同
|
| 118 |
+
|
| 119 |
+
**验证方法 A: 只加载 DPO LoRA**
|
| 120 |
+
```python
|
| 121 |
+
# 临时修改 test.py
|
| 122 |
+
pipe = FluxPipeline.from_pretrained(...).to("cuda")
|
| 123 |
+
|
| 124 |
+
# 只加载第一个 LoRA
|
| 125 |
+
pipe.load_lora_weights(
|
| 126 |
+
"/home/wanghongbo06/baipurui/CKPTs/FLUX_SR/pytorch_lora_weights_v2.safetensors",
|
| 127 |
+
adapter_name="sr"
|
| 128 |
+
)
|
| 129 |
+
pipe.fuse_lora(lora_scale=1.0, adapter_names=["sr"])
|
| 130 |
+
pipe.unload_lora_weights()
|
| 131 |
+
|
| 132 |
+
# 生成图像 -> 保存为 result_sr_only.png
|
| 133 |
+
|
| 134 |
+
# 然后只加载第二个 LoRA
|
| 135 |
+
pipe2 = FluxPipeline.from_pretrained(...).to("cuda")
|
| 136 |
+
pipe2.load_lora_weights(
|
| 137 |
+
"/home/wanghongbo06/diffusion-dpo-adv/results_1130/checkpoint-60/lora_train_unet/adapter_model.safetensors",
|
| 138 |
+
adapter_name="dpo"
|
| 139 |
+
)
|
| 140 |
+
pipe2.fuse_lora(lora_scale=1.0, adapter_names=["dpo"])
|
| 141 |
+
|
| 142 |
+
# 生成图像 -> 保存为 result_dpo_only.png
|
| 143 |
+
```
|
| 144 |
+
|
| 145 |
+
如果 `result_dpo_only.png` 在不同 epoch 之间也完全相同,说明问题在 checkpoint 本身。
|
| 146 |
+
|
| 147 |
+
**验证方法 B: 检查融合后的权重**
|
| 148 |
+
```python
|
| 149 |
+
# 在 test.py 中添加调试代码
|
| 150 |
+
import torch
|
| 151 |
+
|
| 152 |
+
# 加载第一个 LoRA 后
|
| 153 |
+
pipe.load_lora_weights("...pytorch_lora_weights_v2.safetensors", adapter_name="sr")
|
| 154 |
+
pipe.fuse_lora(lora_scale=1.0, adapter_names=["sr"])
|
| 155 |
+
|
| 156 |
+
# 保存融合后的 transformer 权重快照
|
| 157 |
+
transformer_weights_after_sr = {
|
| 158 |
+
k: v.clone() for k, v in pipe.transformer.state_dict().items()
|
| 159 |
+
if 'single_transformer_blocks.30' in k
|
| 160 |
+
}
|
| 161 |
+
|
| 162 |
+
pipe.unload_lora_weights()
|
| 163 |
+
|
| 164 |
+
# 加载第二个 LoRA 后
|
| 165 |
+
pipe.load_lora_weights("...adapter_model.safetensors", adapter_name="sr2")
|
| 166 |
+
pipe.fuse_lora(lora_scale=1.0, adapter_names=["sr2"])
|
| 167 |
+
|
| 168 |
+
transformer_weights_after_dpo = {
|
| 169 |
+
k: v.clone() for k, v in pipe.transformer.state_dict().items()
|
| 170 |
+
if 'single_transformer_blocks.30' in k
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
# 对比
|
| 174 |
+
for k in transformer_weights_after_sr.keys():
|
| 175 |
+
diff = (transformer_weights_after_dpo[k] - transformer_weights_after_sr[k]).abs().max()
|
| 176 |
+
print(f"{k}: max_diff = {diff:.6e}")
|
| 177 |
+
```
|
| 178 |
+
|
| 179 |
+
如果所有 diff 都是 0,说明第二个 LoRA 没有被正确应用。
|
| 180 |
+
|
| 181 |
+
---
|
| 182 |
+
|
| 183 |
+
### 🎯 检查 4: Checkpoint 文件的完整性
|
| 184 |
+
|
| 185 |
+
```bash
|
| 186 |
+
# 在训练机器上
|
| 187 |
+
python inspect_safetensor.py results_1202_4/checkpoint-60/lora_train_unet/adapter_model.safetensors
|
| 188 |
+
```
|
| 189 |
+
|
| 190 |
+
**预期结果**:
|
| 191 |
+
- 应该有 100+ 个参数张量
|
| 192 |
+
- `single_transformer_blocks.*` 的 lora_B 应该有非零值
|
| 193 |
+
- 非零参数比例应该 > 50%
|
| 194 |
+
|
| 195 |
+
**如果异常**:
|
| 196 |
+
- 文件损坏
|
| 197 |
+
- 或者保存逻辑有问题
|
| 198 |
+
|
| 199 |
+
---
|
| 200 |
+
|
| 201 |
+
## 🔧 可能的修复方案
|
| 202 |
+
|
| 203 |
+
### 方案 1: 如果 checkpoint 权重没有变化
|
| 204 |
+
**原因**: 保存逻辑有问题
|
| 205 |
+
|
| 206 |
+
**修复**: 检查 `save_model_hook` 中的状态字典获取逻辑
|
| 207 |
+
|
| 208 |
+
```python
|
| 209 |
+
# train_single_lora.py line 672
|
| 210 |
+
full_state_dict = accelerator.get_state_dict(model)
|
| 211 |
+
```
|
| 212 |
+
|
| 213 |
+
可能需要改为:
|
| 214 |
+
```python
|
| 215 |
+
from peft import get_peft_model_state_dict
|
| 216 |
+
full_state_dict = get_peft_model_state_dict(model, adapter_name="train_unet")
|
| 217 |
+
```
|
| 218 |
+
|
| 219 |
+
---
|
| 220 |
+
|
| 221 |
+
### 方案 2: 如果测试代码路径错误
|
| 222 |
+
**原因**: 加载了旧的 checkpoint
|
| 223 |
+
|
| 224 |
+
**修复**: 更新测试代码的路径,或者将最新 checkpoint 复制到测试机器
|
| 225 |
+
|
| 226 |
+
```bash
|
| 227 |
+
# 在训练机器上
|
| 228 |
+
scp -r results_1202_4/checkpoint-60 user@test-machine:/home/wanghongbo06/diffusion-dpo-adv/
|
| 229 |
+
```
|
| 230 |
+
|
| 231 |
+
---
|
| 232 |
+
|
| 233 |
+
### 方案 3: 如果 LoRA 融合有问题
|
| 234 |
+
**原因**: `fuse_lora` 和 `unload_lora_weights` 的交互有问题
|
| 235 |
+
|
| 236 |
+
**修复**: 使用 `set_adapters` 而不是 `fuse_lora`
|
| 237 |
+
|
| 238 |
+
```python
|
| 239 |
+
# 新的测试代码
|
| 240 |
+
pipe = FluxPipeline.from_pretrained(...).to("cuda")
|
| 241 |
+
|
| 242 |
+
# 加载两个 LoRA(不融合)
|
| 243 |
+
pipe.load_lora_weights("...pytorch_lora_weights_v2.safetensors", adapter_name="sr")
|
| 244 |
+
pipe.load_lora_weights("...adapter_model.safetensors", adapter_name="dpo")
|
| 245 |
+
|
| 246 |
+
# 同时启用两个 adapter
|
| 247 |
+
pipe.set_adapters(["sr", "dpo"], adapter_weights=[1.0, 1.0])
|
| 248 |
+
|
| 249 |
+
# 生成图像
|
| 250 |
+
result_img = generate(pipe, ...)
|
| 251 |
+
```
|
| 252 |
+
|
| 253 |
+
---
|
| 254 |
+
|
| 255 |
+
## 📊 诊断流程图
|
| 256 |
+
|
| 257 |
+
```
|
| 258 |
+
开始
|
| 259 |
+
↓
|
| 260 |
+
检查 1: 对比 checkpoint 权重
|
| 261 |
+
├─ 权重有变化 → 继续检查 2
|
| 262 |
+
└─ 权重无变化 → 【问题在训练/保存】→ 方案 1
|
| 263 |
+
↓
|
| 264 |
+
检查 2: 确认测试代码加载的路径
|
| 265 |
+
├─ 路径正确 → 继续检查 3
|
| 266 |
+
└─ 路径错误 → 【问题在测试代码】→ 方案 2
|
| 267 |
+
↓
|
| 268 |
+
检查 3: 验证 LoRA 融合逻辑
|
| 269 |
+
├─ 融合正确 → 继续检查 4
|
| 270 |
+
└─ 融合失败 → 【问题在测试代码】→ 方案 3
|
| 271 |
+
↓
|
| 272 |
+
检查 4: 检查 checkpoint 文件完整性
|
| 273 |
+
├─ 文件完整 → 【未知问题,需要更深入调查】
|
| 274 |
+
└─ 文件损坏 → 【问题在保存】→ 方案 1
|
| 275 |
+
```
|
| 276 |
+
|
| 277 |
+
---
|
| 278 |
+
|
| 279 |
+
## 🚀 立即执行
|
| 280 |
+
|
| 281 |
+
**请在训练机器上运行以下命令**:
|
| 282 |
+
|
| 283 |
+
```bash
|
| 284 |
+
cd <run.sh 的执行目录>
|
| 285 |
+
|
| 286 |
+
# 1. 对比 checkpoint(最重要!)
|
| 287 |
+
python /data2/hongbo.wang/DPO-SR/compare_checkpoints.py \
|
| 288 |
+
results_1202_4/checkpoint-15/lora_train_unet/adapter_model.safetensors \
|
| 289 |
+
results_1202_4/checkpoint-60/lora_train_unet/adapter_model.safetensors
|
| 290 |
+
|
| 291 |
+
# 2. 检查单个 checkpoint
|
| 292 |
+
python /data2/hongbo.wang/DPO-SR/inspect_safetensor.py \
|
| 293 |
+
results_1202_4/checkpoint-60/lora_train_unet/adapter_model.safetensors
|
| 294 |
+
```
|
| 295 |
+
|
| 296 |
+
**请将输出结果发给我,我会根据结果进一步诊断!**
|
| 297 |
+
|
diffusion-dpo-test/DIV2K-val/sobolev-400/0000843-seed-0.png
ADDED
|
diffusion-dpo-test/__pycache__/color_fix.cpython-310.pyc
ADDED
|
Binary file (3.7 kB). View file
|
|
|
diffusion-dpo-test/analyze_lora_magnitude.py
ADDED
|
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
分析 LoRA 权重的实际大小,理解为什么效果差异这么大
|
| 4 |
+
"""
|
| 5 |
+
import os
|
| 6 |
+
os.environ["HF_HOME"] = "/home/wanghongbo06/.cache/huggingface"
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
from safetensors.torch import load_file
|
| 10 |
+
import numpy as np
|
| 11 |
+
|
| 12 |
+
def analyze_lora(path, name=""):
|
| 13 |
+
"""分析单个 LoRA 文件"""
|
| 14 |
+
print(f"\n{'='*80}")
|
| 15 |
+
print(f"分析 LoRA: {name}")
|
| 16 |
+
print(f"路径: {path}")
|
| 17 |
+
print(f"{'='*80}")
|
| 18 |
+
|
| 19 |
+
state_dict = load_file(path)
|
| 20 |
+
|
| 21 |
+
# 分离 lora_A 和 lora_B
|
| 22 |
+
lora_a_keys = [k for k in state_dict.keys() if "lora_A" in k]
|
| 23 |
+
lora_b_keys = [k for k in state_dict.keys() if "lora_B" in k]
|
| 24 |
+
|
| 25 |
+
print(f"\n总参数数: {len(state_dict)}")
|
| 26 |
+
print(f" - lora_A: {len(lora_a_keys)} 个")
|
| 27 |
+
print(f" - lora_B: {len(lora_b_keys)} 个")
|
| 28 |
+
|
| 29 |
+
# 分析 lora_A
|
| 30 |
+
print(f"\n--- lora_A 统计 ---")
|
| 31 |
+
a_means = []
|
| 32 |
+
a_maxs = []
|
| 33 |
+
a_stds = []
|
| 34 |
+
for k in lora_a_keys:
|
| 35 |
+
v = state_dict[k].float()
|
| 36 |
+
a_means.append(v.mean().item())
|
| 37 |
+
a_maxs.append(v.abs().max().item())
|
| 38 |
+
a_stds.append(v.std().item())
|
| 39 |
+
|
| 40 |
+
print(f" Mean of means: {np.mean(a_means):.6e}")
|
| 41 |
+
print(f" Max of maxs: {np.max(a_maxs):.6e}")
|
| 42 |
+
print(f" Mean of stds: {np.mean(a_stds):.6e}")
|
| 43 |
+
|
| 44 |
+
# 分析 lora_B
|
| 45 |
+
print(f"\n--- lora_B 统计 ---")
|
| 46 |
+
b_means = []
|
| 47 |
+
b_maxs = []
|
| 48 |
+
b_stds = []
|
| 49 |
+
b_nonzero_ratio = []
|
| 50 |
+
for k in lora_b_keys:
|
| 51 |
+
v = state_dict[k].float()
|
| 52 |
+
b_means.append(v.mean().item())
|
| 53 |
+
b_maxs.append(v.abs().max().item())
|
| 54 |
+
b_stds.append(v.std().item())
|
| 55 |
+
b_nonzero_ratio.append((v.abs() > 1e-10).float().mean().item())
|
| 56 |
+
|
| 57 |
+
print(f" Mean of means: {np.mean(b_means):.6e}")
|
| 58 |
+
print(f" Max of maxs: {np.max(b_maxs):.6e}")
|
| 59 |
+
print(f" Mean of stds: {np.mean(b_stds):.6e}")
|
| 60 |
+
print(f" Avg non-zero ratio: {np.mean(b_nonzero_ratio)*100:.2f}%")
|
| 61 |
+
|
| 62 |
+
# LoRA 的实际贡献 = A @ B
|
| 63 |
+
# 估算 LoRA 对权重的影响大小
|
| 64 |
+
print(f"\n--- LoRA 影响估算 ---")
|
| 65 |
+
print(f" LoRA 输出 ≈ x @ A.T @ B.T")
|
| 66 |
+
print(f" |A| * |B| ≈ {np.max(a_maxs) * np.max(b_maxs):.6e}")
|
| 67 |
+
|
| 68 |
+
# 找出最大的几个 lora_B
|
| 69 |
+
print(f"\n--- 最大的 5 个 lora_B 层 ---")
|
| 70 |
+
b_with_max = [(k, state_dict[k].float().abs().max().item()) for k in lora_b_keys]
|
| 71 |
+
b_with_max.sort(key=lambda x: -x[1])
|
| 72 |
+
for k, v in b_with_max[:5]:
|
| 73 |
+
print(f" {v:.6e}: {k}")
|
| 74 |
+
|
| 75 |
+
return state_dict
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def compare_loras(path1, path2, name1="LoRA 1", name2="LoRA 2"):
|
| 79 |
+
"""对比两个 LoRA"""
|
| 80 |
+
print(f"\n{'='*80}")
|
| 81 |
+
print(f"对比 {name1} vs {name2}")
|
| 82 |
+
print(f"{'='*80}")
|
| 83 |
+
|
| 84 |
+
sd1 = load_file(path1)
|
| 85 |
+
sd2 = load_file(path2)
|
| 86 |
+
|
| 87 |
+
# 对比 lora_B 的变化
|
| 88 |
+
lora_b_keys = [k for k in sd1.keys() if "lora_B" in k]
|
| 89 |
+
|
| 90 |
+
diffs = []
|
| 91 |
+
for k in lora_b_keys:
|
| 92 |
+
if k in sd2:
|
| 93 |
+
diff = (sd2[k].float() - sd1[k].float()).abs()
|
| 94 |
+
diffs.append({
|
| 95 |
+
'key': k,
|
| 96 |
+
'max_diff': diff.max().item(),
|
| 97 |
+
'mean_diff': diff.mean().item(),
|
| 98 |
+
'val1_max': sd1[k].float().abs().max().item(),
|
| 99 |
+
'val2_max': sd2[k].float().abs().max().item(),
|
| 100 |
+
})
|
| 101 |
+
|
| 102 |
+
# 按差异排序
|
| 103 |
+
diffs.sort(key=lambda x: -x['max_diff'])
|
| 104 |
+
|
| 105 |
+
print(f"\n变化最大的 10 个 lora_B 层:")
|
| 106 |
+
print("-" * 100)
|
| 107 |
+
for d in diffs[:10]:
|
| 108 |
+
print(f" max_diff={d['max_diff']:.6e}, {name1}_max={d['val1_max']:.6e}, {name2}_max={d['val2_max']:.6e}")
|
| 109 |
+
print(f" {d['key']}")
|
| 110 |
+
|
| 111 |
+
# 总体统计
|
| 112 |
+
all_max_diffs = [d['max_diff'] for d in diffs]
|
| 113 |
+
print(f"\n总体统计:")
|
| 114 |
+
print(f" 最大差异: {max(all_max_diffs):.6e}")
|
| 115 |
+
print(f" 平均差异: {np.mean(all_max_diffs):.6e}")
|
| 116 |
+
print(f" 差异 > 1e-4 的层数: {sum(1 for d in all_max_diffs if d > 1e-4)}")
|
| 117 |
+
print(f" 差异 > 1e-5 的层数: {sum(1 for d in all_max_diffs if d > 1e-5)}")
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def compare_with_sr_lora(sr_path, dpo_path):
|
| 121 |
+
"""对比 SR LoRA 和 DPO LoRA 的量级"""
|
| 122 |
+
print(f"\n{'='*80}")
|
| 123 |
+
print(f"对比 SR LoRA 和 DPO LoRA 的量级")
|
| 124 |
+
print(f"{'='*80}")
|
| 125 |
+
|
| 126 |
+
sr_sd = load_file(sr_path)
|
| 127 |
+
dpo_sd = load_file(dpo_path)
|
| 128 |
+
|
| 129 |
+
# SR LoRA 的量级
|
| 130 |
+
sr_b_maxs = []
|
| 131 |
+
for k, v in sr_sd.items():
|
| 132 |
+
if "lora_B" in k or "lora_down" in k: # 不同格式可能用不同命名
|
| 133 |
+
sr_b_maxs.append(v.float().abs().max().item())
|
| 134 |
+
|
| 135 |
+
# DPO LoRA 的量级
|
| 136 |
+
dpo_b_maxs = []
|
| 137 |
+
for k, v in dpo_sd.items():
|
| 138 |
+
if "lora_B" in k:
|
| 139 |
+
dpo_b_maxs.append(v.float().abs().max().item())
|
| 140 |
+
|
| 141 |
+
print(f"\nSR LoRA (lora_B/lora_down):")
|
| 142 |
+
if sr_b_maxs:
|
| 143 |
+
print(f" Max: {max(sr_b_maxs):.6e}")
|
| 144 |
+
print(f" Mean: {np.mean(sr_b_maxs):.6e}")
|
| 145 |
+
else:
|
| 146 |
+
print(f" (没有找到 lora_B 或 lora_down)")
|
| 147 |
+
# 打印所有 key 看看格式
|
| 148 |
+
print(f" SR LoRA keys 示例: {list(sr_sd.keys())[:5]}")
|
| 149 |
+
|
| 150 |
+
print(f"\nDPO LoRA (lora_B):")
|
| 151 |
+
print(f" Max: {max(dpo_b_maxs):.6e}")
|
| 152 |
+
print(f" Mean: {np.mean(dpo_b_maxs):.6e}")
|
| 153 |
+
|
| 154 |
+
if sr_b_maxs and dpo_b_maxs:
|
| 155 |
+
ratio = max(dpo_b_maxs) / max(sr_b_maxs) if max(sr_b_maxs) > 0 else float('inf')
|
| 156 |
+
print(f"\n量级比较:")
|
| 157 |
+
print(f" DPO / SR = {ratio:.4f}")
|
| 158 |
+
if ratio > 1:
|
| 159 |
+
print(f" ⚠️ DPO LoRA 比 SR LoRA 大 {ratio:.1f} 倍!这可能导致效果变差")
|
| 160 |
+
else:
|
| 161 |
+
print(f" DPO LoRA 比 SR LoRA 小 {1/ratio:.1f} 倍")
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
if __name__ == "__main__":
|
| 165 |
+
# 分析 DPO LoRA checkpoints
|
| 166 |
+
ckpt_15 = "/home/wanghongbo06/diffusion-dpo-adv/results_1202_4/checkpoint-15/lora_train_unet/adapter_model.safetensors"
|
| 167 |
+
ckpt_105 = "/home/wanghongbo06/diffusion-dpo-adv/results_1202_4/checkpoint-105/lora_train_unet/adapter_model.safetensors"
|
| 168 |
+
sr_lora = "/home/wanghongbo06/baipurui/CKPTs/FLUX_SR/pytorch_lora_weights_v2.safetensors"
|
| 169 |
+
|
| 170 |
+
# 1. 分析各个 checkpoint
|
| 171 |
+
analyze_lora(ckpt_15, "DPO checkpoint-15")
|
| 172 |
+
analyze_lora(ckpt_105, "DPO checkpoint-105")
|
| 173 |
+
|
| 174 |
+
# 2. 对比 15 vs 105
|
| 175 |
+
compare_loras(ckpt_15, ckpt_105, "ckpt-15", "ckpt-105")
|
| 176 |
+
|
| 177 |
+
# 3. 对比 SR LoRA 和 DPO LoRA 的量级
|
| 178 |
+
compare_with_sr_lora(sr_lora, ckpt_15)
|
| 179 |
+
|
diffusion-dpo-test/check_lora_keys.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""检查 LoRA safetensor 文件的 key 格式"""
|
| 3 |
+
from safetensors.torch import load_file
|
| 4 |
+
import sys
|
| 5 |
+
|
| 6 |
+
def check_keys(path):
|
| 7 |
+
print(f"\n检查文件: {path}")
|
| 8 |
+
print("=" * 80)
|
| 9 |
+
|
| 10 |
+
state_dict = load_file(path)
|
| 11 |
+
|
| 12 |
+
print(f"总共 {len(state_dict)} 个 key\n")
|
| 13 |
+
|
| 14 |
+
print("前 20 个 key:")
|
| 15 |
+
print("-" * 80)
|
| 16 |
+
for i, key in enumerate(sorted(state_dict.keys())[:20]):
|
| 17 |
+
print(f" {key}")
|
| 18 |
+
|
| 19 |
+
print("\n..." if len(state_dict) > 20 else "")
|
| 20 |
+
|
| 21 |
+
# 检查 key 格式
|
| 22 |
+
print("\nKey 格式分析:")
|
| 23 |
+
print("-" * 80)
|
| 24 |
+
|
| 25 |
+
has_transformer_prefix = any(k.startswith("transformer.") for k in state_dict.keys())
|
| 26 |
+
has_base_model_prefix = any(k.startswith("base_model.") for k in state_dict.keys())
|
| 27 |
+
has_lora_A = any("lora_A" in k for k in state_dict.keys())
|
| 28 |
+
has_lora_B = any("lora_B" in k for k in state_dict.keys())
|
| 29 |
+
has_train_unet = any("train_unet" in k for k in state_dict.keys())
|
| 30 |
+
|
| 31 |
+
print(f" 包含 'transformer.' 前缀: {has_transformer_prefix}")
|
| 32 |
+
print(f" 包含 'base_model.' 前缀: {has_base_model_prefix}")
|
| 33 |
+
print(f" 包含 'lora_A': {has_lora_A}")
|
| 34 |
+
print(f" 包含 'lora_B': {has_lora_B}")
|
| 35 |
+
print(f" 包含 'train_unet': {has_train_unet}")
|
| 36 |
+
|
| 37 |
+
# 显示一个完整的 key 示例
|
| 38 |
+
sample_key = list(state_dict.keys())[0]
|
| 39 |
+
print(f"\n示例 key: {sample_key}")
|
| 40 |
+
print(f"示例 shape: {state_dict[sample_key].shape}")
|
| 41 |
+
|
| 42 |
+
# 检查 Diffusers 期望的格式
|
| 43 |
+
print("\n" + "=" * 80)
|
| 44 |
+
print("Diffusers load_lora_weights 期望的 key 格式:")
|
| 45 |
+
print("-" * 80)
|
| 46 |
+
print(" transformer.single_transformer_blocks.0.attn.to_k.lora_A.weight")
|
| 47 |
+
print(" transformer.single_transformer_blocks.0.attn.to_k.lora_B.weight")
|
| 48 |
+
print("\n您的 key 格式:")
|
| 49 |
+
print(f" {sample_key}")
|
| 50 |
+
|
| 51 |
+
# 判断是否兼容
|
| 52 |
+
print("\n" + "=" * 80)
|
| 53 |
+
if has_train_unet:
|
| 54 |
+
print("❌ 问题: 您的 key 包含 '.train_unet.' 后缀!")
|
| 55 |
+
print(" Diffusers 期望: xxx.lora_A.weight")
|
| 56 |
+
print(" 您的格式: xxx.lora_A.train_unet.weight")
|
| 57 |
+
print("\n 这就是 LoRA 无法加载的原因!")
|
| 58 |
+
elif not has_transformer_prefix:
|
| 59 |
+
print("⚠️ 问题: 您的 key 缺少 'transformer.' 前缀!")
|
| 60 |
+
print(" Diffusers 期望: transformer.xxx.lora_A.weight")
|
| 61 |
+
print(f" 您的格式: {sample_key}")
|
| 62 |
+
else:
|
| 63 |
+
print("✅ Key 格式看起来正确")
|
| 64 |
+
|
| 65 |
+
if __name__ == "__main__":
|
| 66 |
+
paths = [
|
| 67 |
+
"/home/wanghongbo06/diffusion-dpo-adv/results_1202_4/checkpoint-15/lora_train_unet/adapter_model.safetensors",
|
| 68 |
+
"/home/wanghongbo06/diffusion-dpo-adv/results_1202_4/checkpoint-60/lora_train_unet/adapter_model.safetensors",
|
| 69 |
+
]
|
| 70 |
+
|
| 71 |
+
for path in paths:
|
| 72 |
+
try:
|
| 73 |
+
check_keys(path)
|
| 74 |
+
except Exception as e:
|
| 75 |
+
print(f"❌ 无法读取 {path}: {e}")
|
| 76 |
+
|
diffusion-dpo-test/color_fix.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
'''
|
| 2 |
+
# --------------------------------------------------------------------------------
|
| 3 |
+
# Color fixed script from Li Yi (https://github.com/pkuliyi2015/sd-webui-stablesr/blob/master/srmodule/colorfix.py)
|
| 4 |
+
# --------------------------------------------------------------------------------
|
| 5 |
+
'''
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from PIL import Image
|
| 9 |
+
from torch import Tensor
|
| 10 |
+
from torch.nn import functional as F
|
| 11 |
+
|
| 12 |
+
from torchvision.transforms import ToTensor, ToPILImage
|
| 13 |
+
|
| 14 |
+
def adain_color_fix(target: Image, source: Image):
|
| 15 |
+
# Convert images to tensors
|
| 16 |
+
to_tensor = ToTensor()
|
| 17 |
+
target_tensor = to_tensor(target).unsqueeze(0)
|
| 18 |
+
source_tensor = to_tensor(source).unsqueeze(0)
|
| 19 |
+
|
| 20 |
+
# Apply adaptive instance normalization
|
| 21 |
+
result_tensor = adaptive_instance_normalization(target_tensor, source_tensor)
|
| 22 |
+
|
| 23 |
+
# Convert tensor back to image
|
| 24 |
+
to_image = ToPILImage()
|
| 25 |
+
result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0))
|
| 26 |
+
|
| 27 |
+
return result_image
|
| 28 |
+
|
| 29 |
+
def wavelet_color_fix(target: Image, source: Image):
|
| 30 |
+
# Convert images to tensors
|
| 31 |
+
to_tensor = ToTensor()
|
| 32 |
+
target_tensor = to_tensor(target).unsqueeze(0)
|
| 33 |
+
source_tensor = to_tensor(source).unsqueeze(0)
|
| 34 |
+
|
| 35 |
+
# Apply wavelet reconstruction
|
| 36 |
+
result_tensor = wavelet_reconstruction(target_tensor, source_tensor)
|
| 37 |
+
|
| 38 |
+
# Convert tensor back to image
|
| 39 |
+
to_image = ToPILImage()
|
| 40 |
+
result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0))
|
| 41 |
+
|
| 42 |
+
return result_image
|
| 43 |
+
|
| 44 |
+
def calc_mean_std(feat: Tensor, eps=1e-5):
|
| 45 |
+
"""Calculate mean and std for adaptive_instance_normalization.
|
| 46 |
+
Args:
|
| 47 |
+
feat (Tensor): 4D tensor.
|
| 48 |
+
eps (float): A small value added to the variance to avoid
|
| 49 |
+
divide-by-zero. Default: 1e-5.
|
| 50 |
+
"""
|
| 51 |
+
size = feat.size()
|
| 52 |
+
assert len(size) == 4, 'The input feature should be 4D tensor.'
|
| 53 |
+
b, c = size[:2]
|
| 54 |
+
feat_var = feat.reshape(b, c, -1).var(dim=2) + eps
|
| 55 |
+
feat_std = feat_var.sqrt().reshape(b, c, 1, 1)
|
| 56 |
+
feat_mean = feat.reshape(b, c, -1).mean(dim=2).reshape(b, c, 1, 1)
|
| 57 |
+
return feat_mean, feat_std
|
| 58 |
+
|
| 59 |
+
def adaptive_instance_normalization(content_feat:Tensor, style_feat:Tensor):
|
| 60 |
+
"""Adaptive instance normalization.
|
| 61 |
+
Adjust the reference features to have the similar color and illuminations
|
| 62 |
+
as those in the degradate features.
|
| 63 |
+
Args:
|
| 64 |
+
content_feat (Tensor): The reference feature.
|
| 65 |
+
style_feat (Tensor): The degradate features.
|
| 66 |
+
"""
|
| 67 |
+
size = content_feat.size()
|
| 68 |
+
style_mean, style_std = calc_mean_std(style_feat)
|
| 69 |
+
content_mean, content_std = calc_mean_std(content_feat)
|
| 70 |
+
normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
|
| 71 |
+
return normalized_feat * style_std.expand(size) + style_mean.expand(size)
|
| 72 |
+
|
| 73 |
+
def wavelet_blur(image: Tensor, radius: int):
|
| 74 |
+
"""
|
| 75 |
+
Apply wavelet blur to the input tensor.
|
| 76 |
+
"""
|
| 77 |
+
# input shape: (1, 3, H, W)
|
| 78 |
+
# convolution kernel
|
| 79 |
+
kernel_vals = [
|
| 80 |
+
[0.0625, 0.125, 0.0625],
|
| 81 |
+
[0.125, 0.25, 0.125],
|
| 82 |
+
[0.0625, 0.125, 0.0625],
|
| 83 |
+
]
|
| 84 |
+
kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device)
|
| 85 |
+
# add channel dimensions to the kernel to make it a 4D tensor
|
| 86 |
+
kernel = kernel[None, None]
|
| 87 |
+
# repeat the kernel across all input channels
|
| 88 |
+
kernel = kernel.repeat(3, 1, 1, 1)
|
| 89 |
+
image = F.pad(image, (radius, radius, radius, radius), mode='replicate')
|
| 90 |
+
# apply convolution
|
| 91 |
+
output = F.conv2d(image, kernel, groups=3, dilation=radius)
|
| 92 |
+
return output
|
| 93 |
+
|
| 94 |
+
def wavelet_decomposition(image: Tensor, levels=5):
|
| 95 |
+
"""
|
| 96 |
+
Apply wavelet decomposition to the input tensor.
|
| 97 |
+
This function only returns the low frequency & the high frequency.
|
| 98 |
+
"""
|
| 99 |
+
high_freq = torch.zeros_like(image)
|
| 100 |
+
for i in range(levels):
|
| 101 |
+
radius = 2 ** i
|
| 102 |
+
low_freq = wavelet_blur(image, radius)
|
| 103 |
+
high_freq += (image - low_freq)
|
| 104 |
+
image = low_freq
|
| 105 |
+
|
| 106 |
+
return high_freq, low_freq
|
| 107 |
+
|
| 108 |
+
def wavelet_reconstruction(content_feat:Tensor, style_feat:Tensor):
|
| 109 |
+
"""
|
| 110 |
+
Apply wavelet decomposition, so that the content will have the same color as the style.
|
| 111 |
+
"""
|
| 112 |
+
# calculate the wavelet decomposition of the content feature
|
| 113 |
+
content_high_freq, content_low_freq = wavelet_decomposition(content_feat)
|
| 114 |
+
del content_low_freq
|
| 115 |
+
# calculate the wavelet decomposition of the style feature
|
| 116 |
+
style_high_freq, style_low_freq = wavelet_decomposition(style_feat)
|
| 117 |
+
del style_high_freq
|
| 118 |
+
# reconstruct the content feature with the style's high frequency
|
| 119 |
+
return content_high_freq + style_low_freq
|
diffusion-dpo-test/compare.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from PIL import Image
|
| 2 |
+
import numpy as np
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def compare_images(img_path1, img_path2, diff_output_path=None):
|
| 6 |
+
"""
|
| 7 |
+
逐像素比较两张 RGB 图像的差异。
|
| 8 |
+
要求两张图分辨率相同、都是 RGB 模式。
|
| 9 |
+
|
| 10 |
+
:param img_path1: 第一张图片路径
|
| 11 |
+
:param img_path2: 第二张图片路径
|
| 12 |
+
:param diff_output_path: 如果不为 None,则保存差异图到该路径
|
| 13 |
+
:return: 一个字典,包含一些差异统计信息
|
| 14 |
+
"""
|
| 15 |
+
# 打开图片并转换为 RGB(避免有的图片是 RGBA/灰度模式)
|
| 16 |
+
img1 = Image.open(img_path1).convert("RGB")
|
| 17 |
+
img2 = Image.open(img_path2).convert("RGB")
|
| 18 |
+
|
| 19 |
+
# 检查分辨率是否一致
|
| 20 |
+
if img1.size != img2.size:
|
| 21 |
+
raise ValueError(f"两张图分辨率不同: {img1.size} vs {img2.size}")
|
| 22 |
+
|
| 23 |
+
# 转为 NumPy 数组,形状为 (H, W, 3)
|
| 24 |
+
arr1 = np.array(img1, dtype=np.int16) # 用 int16 以免减法溢出
|
| 25 |
+
arr2 = np.array(img2, dtype=np.int16)
|
| 26 |
+
|
| 27 |
+
# 逐像素逐通道做差,取绝对值
|
| 28 |
+
diff = np.abs(arr1 - arr2) # (H, W, 3)
|
| 29 |
+
|
| 30 |
+
# 每个像素的“差异强度”可以用 RGB 差分的和或均值来表示
|
| 31 |
+
# 这里用每像素的 RGB 差的平均值
|
| 32 |
+
per_pixel_diff = diff.mean(axis=2) # (H, W)
|
| 33 |
+
|
| 34 |
+
# 一些统计信息
|
| 35 |
+
total_pixels = per_pixel_diff.size
|
| 36 |
+
# 差异为0的像素
|
| 37 |
+
same_pixels = np.sum(per_pixel_diff == 0)
|
| 38 |
+
different_pixels = total_pixels - same_pixels
|
| 39 |
+
|
| 40 |
+
max_diff = float(per_pixel_diff.max()) # 单像素最大平均差值
|
| 41 |
+
mean_diff = float(per_pixel_diff.mean()) # 所有像素平均差值
|
| 42 |
+
diff_ratio = different_pixels / total_pixels # 有差异像素占比
|
| 43 |
+
|
| 44 |
+
stats = {
|
| 45 |
+
"total_pixels": int(total_pixels),
|
| 46 |
+
"same_pixels": int(same_pixels),
|
| 47 |
+
"different_pixels": int(different_pixels),
|
| 48 |
+
"different_ratio": diff_ratio, # 0~1 之间
|
| 49 |
+
"max_diff_per_pixel": max_diff, # 0~255
|
| 50 |
+
"mean_diff_per_pixel": mean_diff
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
# 如果需要输出一张差异图
|
| 54 |
+
if diff_output_path is not None:
|
| 55 |
+
# diff 目前是 0~255 范围内的 RGB 差值,可以直接保存成图像看
|
| 56 |
+
diff_img = np.clip(diff, 0, 255).astype(np.uint8)
|
| 57 |
+
diff_image = Image.fromarray(diff_img, mode="RGB")
|
| 58 |
+
diff_image.save(diff_output_path)
|
| 59 |
+
# 也可以考虑把差异增强一下再保存(例如乘个系数)
|
| 60 |
+
|
| 61 |
+
return stats
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
if __name__ == "__main__":
|
| 65 |
+
img1_path = "./results-test/dpo_scale_ablation/dpo_scale_0.0/0000010-seed-0.png"
|
| 66 |
+
img2_path = "./results-test/dpo_scale_ablation/dpo_scale_1.0/0000010-seed-0.png"
|
| 67 |
+
diff_img_path = "diff.png"
|
| 68 |
+
|
| 69 |
+
stats = compare_images(img1_path, img2_path, diff_output_path=diff_img_path)
|
| 70 |
+
|
| 71 |
+
print("比较结果:")
|
| 72 |
+
for k, v in stats.items():
|
| 73 |
+
print(f"{k}: {v}")
|
diffusion-dpo-test/compare_checkpoints.py
ADDED
|
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""对比两个 checkpoint 的 safetensor 文件,检查权重是否真的在变化"""
|
| 3 |
+
import sys
|
| 4 |
+
from safetensors.torch import load_file
|
| 5 |
+
import torch
|
| 6 |
+
import os
|
| 7 |
+
|
| 8 |
+
# ==============================
|
| 9 |
+
# 在这里手动填写 checkpoint 路径
|
| 10 |
+
# ==============================
|
| 11 |
+
CHECKPOINT_1 = r"/home/wanghongbo06/diffusion-dpo-adv/results_1202_4/checkpoint-15/lora_train_unet/adapter_model.safetensors"
|
| 12 |
+
CHECKPOINT_2 = r"/home/wanghongbo06/diffusion-dpo-adv/results_1202_4/checkpoint-60/lora_train_unet/adapter_model.safetensors"
|
| 13 |
+
# ==============================
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def compare_safetensors(path1, path2):
|
| 17 |
+
print(f"\n{'='*80}")
|
| 18 |
+
print(f"对比两个 checkpoint:")
|
| 19 |
+
print(f" Checkpoint 1: {path1}")
|
| 20 |
+
print(f" Checkpoint 2: {path2}")
|
| 21 |
+
print(f"{'='*80}\n")
|
| 22 |
+
|
| 23 |
+
try:
|
| 24 |
+
state_dict1 = load_file(path1)
|
| 25 |
+
state_dict2 = load_file(path2)
|
| 26 |
+
|
| 27 |
+
# 检查键是否一致
|
| 28 |
+
keys1 = set(state_dict1.keys())
|
| 29 |
+
keys2 = set(state_dict2.keys())
|
| 30 |
+
|
| 31 |
+
if keys1 != keys2:
|
| 32 |
+
print("⚠️ 警告: 两个 checkpoint 的键不一致!")
|
| 33 |
+
print(f" 只在 checkpoint1 中: {keys1 - keys2}")
|
| 34 |
+
print(f" 只在 checkpoint2 中: {keys2 - keys1}")
|
| 35 |
+
return
|
| 36 |
+
|
| 37 |
+
print(f"✅ 两个 checkpoint 都有 {len(keys1)} 个参数张量\n")
|
| 38 |
+
|
| 39 |
+
# 统计差异
|
| 40 |
+
identical_count = 0
|
| 41 |
+
different_count = 0
|
| 42 |
+
max_diff_info = None
|
| 43 |
+
max_diff = 0
|
| 44 |
+
|
| 45 |
+
layer_diffs = {}
|
| 46 |
+
|
| 47 |
+
for key in sorted(keys1):
|
| 48 |
+
tensor1 = state_dict1[key]
|
| 49 |
+
tensor2 = state_dict2[key]
|
| 50 |
+
|
| 51 |
+
diff = (tensor2 - tensor1).float()
|
| 52 |
+
abs_diff = diff.abs()
|
| 53 |
+
|
| 54 |
+
max_abs_diff = abs_diff.max().item()
|
| 55 |
+
mean_abs_diff = abs_diff.mean().item()
|
| 56 |
+
|
| 57 |
+
if max_abs_diff == 0:
|
| 58 |
+
identical_count += 1
|
| 59 |
+
else:
|
| 60 |
+
different_count += 1
|
| 61 |
+
|
| 62 |
+
if max_abs_diff > max_diff:
|
| 63 |
+
max_diff = max_abs_diff
|
| 64 |
+
max_diff_info = {
|
| 65 |
+
'key': key,
|
| 66 |
+
'max_diff': max_abs_diff,
|
| 67 |
+
'mean_diff': mean_abs_diff,
|
| 68 |
+
'tensor1_max': tensor1.float().abs().max().item(),
|
| 69 |
+
'tensor2_max': tensor2.float().abs().max().item(),
|
| 70 |
+
}
|
| 71 |
+
|
| 72 |
+
if '.lora_B.' in key:
|
| 73 |
+
layer_name = key.split('.lora_B.')[0]
|
| 74 |
+
if layer_name not in layer_diffs:
|
| 75 |
+
layer_diffs[layer_name] = {
|
| 76 |
+
'max_diff': max_abs_diff,
|
| 77 |
+
'mean_diff': mean_abs_diff,
|
| 78 |
+
'key': key
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
print(f"差异统计:")
|
| 82 |
+
print(f" 完全相同的参数: {identical_count} / {len(keys1)} ({identical_count/len(keys1)*100:.2f}%)")
|
| 83 |
+
print(f" 有变化的参数: {different_count} / {len(keys1)} ({different_count/len(keys1)*100:.2f}%)")
|
| 84 |
+
print()
|
| 85 |
+
|
| 86 |
+
if max_diff_info:
|
| 87 |
+
print(f"最大权重变化:")
|
| 88 |
+
print(f" 层: {max_diff_info['key']}")
|
| 89 |
+
print(f" 最大绝对差异: {max_diff_info['max_diff']:.6e}")
|
| 90 |
+
print(f" 平均绝对差异: {max_diff_info['mean_diff']:.6e}")
|
| 91 |
+
print(f" Checkpoint1 最大值: {max_diff_info['tensor1_max']:.6e}")
|
| 92 |
+
print(f" Checkpoint2 最大值: {max_diff_info['tensor2_max']:.6e}")
|
| 93 |
+
print()
|
| 94 |
+
|
| 95 |
+
key_layers = ['x_embedder', 'transformer_blocks.0', 'transformer_blocks.9',
|
| 96 |
+
'single_transformer_blocks.30', 'proj_out']
|
| 97 |
+
|
| 98 |
+
print("关键层的 lora_B 权重变化:")
|
| 99 |
+
print("-" * 80)
|
| 100 |
+
for layer_prefix in key_layers:
|
| 101 |
+
matching = [k for k in layer_diffs.keys() if layer_prefix in k]
|
| 102 |
+
if matching:
|
| 103 |
+
for layer_name in matching[:2]:
|
| 104 |
+
info = layer_diffs[layer_name]
|
| 105 |
+
print(f"\n层: {layer_name}")
|
| 106 |
+
print(f" 最大差异: {info['max_diff']:.6e}")
|
| 107 |
+
print(f" 平均差异: {info['mean_diff']:.6e}")
|
| 108 |
+
|
| 109 |
+
key = info['key']
|
| 110 |
+
t1 = state_dict1[key].float()
|
| 111 |
+
t2 = state_dict2[key].float()
|
| 112 |
+
print(f" Checkpoint1: mean={t1.mean():.6e}, max={t1.abs().max():.6e}")
|
| 113 |
+
print(f" Checkpoint2: mean={t2.mean():.6e}, max={t2.abs().max():.6e}")
|
| 114 |
+
|
| 115 |
+
print("\n" + "="*80)
|
| 116 |
+
print("lora_B 权重变化最大的前 10 个层:")
|
| 117 |
+
print("-" * 80)
|
| 118 |
+
sorted_layers = sorted(layer_diffs.items(), key=lambda x: x[1]['max_diff'], reverse=True)
|
| 119 |
+
|
| 120 |
+
for i, (layer_name, info) in enumerate(sorted_layers[:10], 1):
|
| 121 |
+
key = info['key']
|
| 122 |
+
t1 = state_dict1[key].float()
|
| 123 |
+
t2 = state_dict2[key].float()
|
| 124 |
+
print(f"\n{i}. {layer_name}")
|
| 125 |
+
print(f" 最大差异: {info['max_diff']:.6e}, 平均差异: {info['mean_diff']:.6e}")
|
| 126 |
+
print(f" Ckpt1: mean={t1.mean():.6e}, max={t1.abs().max():.6e}")
|
| 127 |
+
print(f" Ckpt2: mean={t2.mean():.6e}, max={t2.abs().max():.6e}")
|
| 128 |
+
|
| 129 |
+
print("\n" + "="*80)
|
| 130 |
+
if different_count == 0:
|
| 131 |
+
print("❌ 严重问题: 两个 checkpoint 完全相同,模型没有学习!")
|
| 132 |
+
elif different_count < len(keys1) * 0.1:
|
| 133 |
+
print(f"⚠️ 警告: 只有 {different_count/len(keys1)*100:.2f}% 的参数在变化,可能存在梯度阻塞")
|
| 134 |
+
else:
|
| 135 |
+
print(f"✅ 正常: {different_count/len(keys1)*100:.2f}% 的参数在变化")
|
| 136 |
+
if max_diff < 1e-6:
|
| 137 |
+
print(f"⚠️ 但是: 最大变化只有 {max_diff:.6e},变化幅度可能太小")
|
| 138 |
+
print("="*80)
|
| 139 |
+
|
| 140 |
+
except Exception as e:
|
| 141 |
+
print(f"❌ 错误: {e}")
|
| 142 |
+
import traceback
|
| 143 |
+
traceback.print_exc()
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
if __name__ == "__main__":
|
| 147 |
+
compare_safetensors(CHECKPOINT_1, CHECKPOINT_2)
|
diffusion-dpo-test/data_val/0000009-seed-0.png
ADDED
|
diffusion-dpo-test/data_val/0000010-seed-0.png
ADDED
|
diffusion-dpo-test/fix_lora_keys.py
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
修复已保存的 LoRA checkpoint 的 key 格式
|
| 4 |
+
将 PEFT 格式转换为 Diffusers load_lora_weights 期望的格式
|
| 5 |
+
|
| 6 |
+
PEFT 格式: x_embedder.lora_A.train_unet.weight
|
| 7 |
+
Diffusers 格式: transformer.x_embedder.lora_A.weight
|
| 8 |
+
"""
|
| 9 |
+
import os
|
| 10 |
+
import sys
|
| 11 |
+
from safetensors.torch import load_file, save_file
|
| 12 |
+
|
| 13 |
+
def fix_lora_keys(input_path, output_path=None):
|
| 14 |
+
"""
|
| 15 |
+
修复 LoRA checkpoint 的 key 格式
|
| 16 |
+
|
| 17 |
+
Args:
|
| 18 |
+
input_path: 输入的 safetensor 文件路径
|
| 19 |
+
output_path: 输出路径,默认覆盖原文件(会先备份)
|
| 20 |
+
"""
|
| 21 |
+
if output_path is None:
|
| 22 |
+
output_path = input_path
|
| 23 |
+
|
| 24 |
+
print(f"\n{'='*80}")
|
| 25 |
+
print(f"修复 LoRA Key 格式")
|
| 26 |
+
print(f" 输入: {input_path}")
|
| 27 |
+
print(f" 输出: {output_path}")
|
| 28 |
+
print(f"{'='*80}\n")
|
| 29 |
+
|
| 30 |
+
# 加载原始 state_dict
|
| 31 |
+
state_dict = load_file(input_path)
|
| 32 |
+
print(f"加载了 {len(state_dict)} 个参数\n")
|
| 33 |
+
|
| 34 |
+
# 显示原始 key 格式
|
| 35 |
+
sample_key = list(state_dict.keys())[0]
|
| 36 |
+
print(f"原始 key 格式示例: {sample_key}")
|
| 37 |
+
|
| 38 |
+
# 检查是否需要修复
|
| 39 |
+
needs_fix = False
|
| 40 |
+
if "train_unet" in sample_key:
|
| 41 |
+
needs_fix = True
|
| 42 |
+
print(" ✓ 检测到 '.train_unet.' 后缀,需要移除")
|
| 43 |
+
if not sample_key.startswith("transformer."):
|
| 44 |
+
needs_fix = True
|
| 45 |
+
print(" ✓ 缺少 'transformer.' 前缀,需要添加")
|
| 46 |
+
|
| 47 |
+
if not needs_fix:
|
| 48 |
+
print("\n✅ Key 格式已经正确,无需修复!")
|
| 49 |
+
return
|
| 50 |
+
|
| 51 |
+
# 修复 key 格式
|
| 52 |
+
print("\n开始修复...")
|
| 53 |
+
new_state_dict = {}
|
| 54 |
+
for k, v in state_dict.items():
|
| 55 |
+
new_k = k
|
| 56 |
+
# 1. 移除 base_model.model. 前缀(如果有)
|
| 57 |
+
new_k = new_k.replace("base_model.model.", "")
|
| 58 |
+
# 2. 移除 .train_unet 后缀
|
| 59 |
+
new_k = new_k.replace(".train_unet.", ".")
|
| 60 |
+
# 3. 添加 transformer. 前缀
|
| 61 |
+
if not new_k.startswith("transformer."):
|
| 62 |
+
new_k = "transformer." + new_k
|
| 63 |
+
new_state_dict[new_k] = v
|
| 64 |
+
|
| 65 |
+
# 显示修复后的 key 格式
|
| 66 |
+
new_sample_key = list(new_state_dict.keys())[0]
|
| 67 |
+
print(f"修复后 key 格式示例: {new_sample_key}")
|
| 68 |
+
|
| 69 |
+
# 备份原文件(如果覆盖)
|
| 70 |
+
if output_path == input_path:
|
| 71 |
+
backup_path = input_path + ".backup"
|
| 72 |
+
print(f"\n备份原文件到: {backup_path}")
|
| 73 |
+
os.rename(input_path, backup_path)
|
| 74 |
+
|
| 75 |
+
# 保存修复后的文件
|
| 76 |
+
save_file(new_state_dict, output_path)
|
| 77 |
+
print(f"\n✅ 已保存修复后的文件: {output_path}")
|
| 78 |
+
|
| 79 |
+
# 验证
|
| 80 |
+
print("\n验证修复结果...")
|
| 81 |
+
verify_dict = load_file(output_path)
|
| 82 |
+
verify_key = list(verify_dict.keys())[0]
|
| 83 |
+
|
| 84 |
+
if verify_key.startswith("transformer.") and ".train_unet." not in verify_key:
|
| 85 |
+
print("✅ 验证通过!Key 格式正确")
|
| 86 |
+
else:
|
| 87 |
+
print(f"❌ 验证失败!Key 格式仍有问题: {verify_key}")
|
| 88 |
+
|
| 89 |
+
return new_state_dict
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def fix_checkpoint_dir(checkpoint_dir):
|
| 93 |
+
"""修复整个 checkpoint 目录"""
|
| 94 |
+
lora_dir = os.path.join(checkpoint_dir, "lora_train_unet")
|
| 95 |
+
adapter_path = os.path.join(lora_dir, "adapter_model.safetensors")
|
| 96 |
+
|
| 97 |
+
if os.path.exists(adapter_path):
|
| 98 |
+
fix_lora_keys(adapter_path)
|
| 99 |
+
else:
|
| 100 |
+
print(f"❌ 找不到文件: {adapter_path}")
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
if __name__ == "__main__":
|
| 104 |
+
if len(sys.argv) < 2:
|
| 105 |
+
print("用法:")
|
| 106 |
+
print(" python fix_lora_keys.py <checkpoint_path>")
|
| 107 |
+
print(" python fix_lora_keys.py <checkpoint_dir>")
|
| 108 |
+
print()
|
| 109 |
+
print("示例:")
|
| 110 |
+
print(" python fix_lora_keys.py results_1202_4/checkpoint-15/lora_train_unet/adapter_model.safetensors")
|
| 111 |
+
print(" python fix_lora_keys.py results_1202_4/checkpoint-15")
|
| 112 |
+
print()
|
| 113 |
+
|
| 114 |
+
# 默认修复 results_1202_4 下的所有 checkpoint
|
| 115 |
+
base_dir = "/home/wanghongbo06/diffusion-dpo-adv/results_1202_4"
|
| 116 |
+
if os.path.exists(base_dir):
|
| 117 |
+
print(f"自动扫描 {base_dir} 下的所有 checkpoint...")
|
| 118 |
+
for item in sorted(os.listdir(base_dir)):
|
| 119 |
+
if item.startswith("checkpoint-"):
|
| 120 |
+
ckpt_path = os.path.join(base_dir, item)
|
| 121 |
+
fix_checkpoint_dir(ckpt_path)
|
| 122 |
+
else:
|
| 123 |
+
print(f"默认目录 {base_dir} 不存在")
|
| 124 |
+
else:
|
| 125 |
+
path = sys.argv[1]
|
| 126 |
+
if path.endswith(".safetensors"):
|
| 127 |
+
fix_lora_keys(path)
|
| 128 |
+
elif os.path.isdir(path):
|
| 129 |
+
fix_checkpoint_dir(path)
|
| 130 |
+
else:
|
| 131 |
+
print(f"❌ 无效路径: {path}")
|
| 132 |
+
|
diffusion-dpo-test/inspect_safetensor.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""检查 safetensor 文件的内容"""
|
| 3 |
+
from safetensors.torch import load_file
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
# ==========================================================
|
| 7 |
+
# 在这里手动填写 safetensors 文件路径(可填写多个)
|
| 8 |
+
# 示例:
|
| 9 |
+
# SAFETENSOR_PATHS = [
|
| 10 |
+
# r"/path/to/adapter_model1.safetensors",
|
| 11 |
+
# r"/path/to/adapter_model2.safetensors",
|
| 12 |
+
# ]
|
| 13 |
+
# ==========================================================
|
| 14 |
+
SAFETENSOR_PATHS = [
|
| 15 |
+
r"/home/wanghongbo06/diffusion-dpo-adv/results_1202_4/checkpoint-15/lora_train_unet/adapter_model.safetensors",
|
| 16 |
+
]
|
| 17 |
+
# ==========================================================
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def inspect_safetensor(path):
|
| 21 |
+
print(f"\n{'='*80}")
|
| 22 |
+
print(f"检查文件: {path}")
|
| 23 |
+
print(f"{'='*80}\n")
|
| 24 |
+
|
| 25 |
+
try:
|
| 26 |
+
state_dict = load_file(path)
|
| 27 |
+
|
| 28 |
+
print(f"总共有 {len(state_dict)} 个参数张量\n")
|
| 29 |
+
|
| 30 |
+
total_params = 0
|
| 31 |
+
zero_params = 0
|
| 32 |
+
non_zero_params = 0
|
| 33 |
+
|
| 34 |
+
layer_stats = {}
|
| 35 |
+
|
| 36 |
+
for key, tensor in state_dict.items():
|
| 37 |
+
num_params = tensor.numel()
|
| 38 |
+
total_params += num_params
|
| 39 |
+
|
| 40 |
+
non_zero_count = (tensor != 0).sum().item()
|
| 41 |
+
zero_count = num_params - non_zero_count
|
| 42 |
+
|
| 43 |
+
if non_zero_count == 0:
|
| 44 |
+
zero_params += num_params
|
| 45 |
+
else:
|
| 46 |
+
non_zero_params += num_params
|
| 47 |
+
|
| 48 |
+
if '.lora_A.' in key or '.lora_B.' in key:
|
| 49 |
+
layer_name = key.split('.lora_')[0]
|
| 50 |
+
if layer_name not in layer_stats:
|
| 51 |
+
layer_stats[layer_name] = {
|
| 52 |
+
'lora_A': None,
|
| 53 |
+
'lora_B': None
|
| 54 |
+
}
|
| 55 |
+
|
| 56 |
+
stats_entry = {
|
| 57 |
+
'mean': tensor.float().mean().item(),
|
| 58 |
+
'std': tensor.float().std().item(),
|
| 59 |
+
'max': tensor.float().max().item(),
|
| 60 |
+
'min': tensor.float().min().item(),
|
| 61 |
+
'non_zero_ratio': non_zero_count / num_params
|
| 62 |
+
}
|
| 63 |
+
|
| 64 |
+
if '.lora_A.' in key:
|
| 65 |
+
layer_stats[layer_name]['lora_A'] = stats_entry
|
| 66 |
+
else:
|
| 67 |
+
layer_stats[layer_name]['lora_B'] = stats_entry
|
| 68 |
+
|
| 69 |
+
print(f"参数统计:")
|
| 70 |
+
print(f" 总参数数: {total_params:,}")
|
| 71 |
+
print(f" 非零参数: {non_zero_params:,} ({non_zero_params/total_params*100:.2f}%)")
|
| 72 |
+
print(f" 零参数: {zero_params:,} ({zero_params/total_params*100:.2f}%)")
|
| 73 |
+
print()
|
| 74 |
+
|
| 75 |
+
key_layers = ['x_embedder', 'transformer_blocks.0', 'transformer_blocks.9',
|
| 76 |
+
'single_transformer_blocks.30', 'proj_out']
|
| 77 |
+
|
| 78 |
+
print("关键层统计:")
|
| 79 |
+
print("-" * 80)
|
| 80 |
+
for layer_name in key_layers:
|
| 81 |
+
matching_layers = [k for k in layer_stats.keys() if layer_name in k]
|
| 82 |
+
if matching_layers:
|
| 83 |
+
for full_layer in matching_layers[:3]:
|
| 84 |
+
stats = layer_stats[full_layer]
|
| 85 |
+
print(f"\n层: {full_layer}")
|
| 86 |
+
if stats['lora_A']:
|
| 87 |
+
print(f" lora_A: mean={stats['lora_A']['mean']:.6e}, "
|
| 88 |
+
f"max={stats['lora_A']['max']:.6e}, "
|
| 89 |
+
f"非零比例={stats['lora_A']['non_zero_ratio']*100:.2f}%")
|
| 90 |
+
if stats['lora_B']:
|
| 91 |
+
print(f" lora_B: mean={stats['lora_B']['mean']:.6e}, "
|
| 92 |
+
f"max={stats['lora_B']['max']:.6e}, "
|
| 93 |
+
f"非零比例={stats['lora_B']['non_zero_ratio']*100:.2f}%")
|
| 94 |
+
|
| 95 |
+
print("\n" + "="*80)
|
| 96 |
+
print("lora_B 权重最大的前 5 个层:")
|
| 97 |
+
print("-" * 80)
|
| 98 |
+
lora_b_layers = [(k, v['lora_B']) for k, v in layer_stats.items() if v['lora_B'] is not None]
|
| 99 |
+
lora_b_layers.sort(key=lambda x: abs(x[1]['max']), reverse=True)
|
| 100 |
+
|
| 101 |
+
for i, (layer_name, stats) in enumerate(lora_b_layers[:5], 1):
|
| 102 |
+
print(f"{i}. {layer_name}")
|
| 103 |
+
print(f" mean={stats['mean']:.6e}, max={stats['max']:.6e}, "
|
| 104 |
+
f"非零比例={stats['non_zero_ratio']*100:.2f}%")
|
| 105 |
+
|
| 106 |
+
except Exception as e:
|
| 107 |
+
print(f"❌ 错误: {e}")
|
| 108 |
+
import traceback
|
| 109 |
+
traceback.print_exc()
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
if __name__ == "__main__":
|
| 113 |
+
# 遍历手动填写的文件路径
|
| 114 |
+
for path in SAFETENSOR_PATHS:
|
| 115 |
+
inspect_safetensor(path)
|
diffusion-dpo-test/metrics.json
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"summary": {
|
| 3 |
+
"avg_inference_time_sec": 23.599896386852894,
|
| 4 |
+
"std_inference_time_sec": 6.796912376816178,
|
| 5 |
+
"min_inference_time_sec": 11.098074495792389,
|
| 6 |
+
"max_inference_time_sec": 28.76287134224549,
|
| 7 |
+
"median_inference_time_sec": 26.921020144131035,
|
| 8 |
+
"p95_inference_time_sec": 28.102163634379394,
|
| 9 |
+
"p99_inference_time_sec": 28.270882021300498,
|
| 10 |
+
"throughput_single_gpu_per_sec": 0.042373067390121394,
|
| 11 |
+
"throughput_parallel_per_sec": 0.08250964031858578,
|
| 12 |
+
"peak_memory_mb": 33811.7392578125,
|
| 13 |
+
"peak_memory_gb": 33.01927661895752,
|
| 14 |
+
"total_images": 100,
|
| 15 |
+
"warmup_images": 8,
|
| 16 |
+
"measured_images": 92,
|
| 17 |
+
"model_load_time_sec": 52.43063974380493,
|
| 18 |
+
"inference_wall_time_sec": 1211.9795894622803,
|
| 19 |
+
"total_time_sec": 1264.4102292060852,
|
| 20 |
+
"num_gpus": 2
|
| 21 |
+
},
|
| 22 |
+
"per_gpu_metrics": {
|
| 23 |
+
"1": {
|
| 24 |
+
"inference_times": [
|
| 25 |
+
27.854881670325994,
|
| 26 |
+
27.943492013029754,
|
| 27 |
+
27.888656420167536,
|
| 28 |
+
27.804196048993617,
|
| 29 |
+
27.924615799915045,
|
| 30 |
+
27.952690774109215,
|
| 31 |
+
27.882505219895393,
|
| 32 |
+
27.896253335289657,
|
| 33 |
+
28.09671812877059,
|
| 34 |
+
27.85268287314102,
|
| 35 |
+
27.742382244672626,
|
| 36 |
+
28.76287134224549,
|
| 37 |
+
28.061799827963114,
|
| 38 |
+
28.054252550937235,
|
| 39 |
+
27.92581091215834,
|
| 40 |
+
27.954700701870024,
|
| 41 |
+
27.952801280654967,
|
| 42 |
+
27.932732206769288,
|
| 43 |
+
28.13525341823697,
|
| 44 |
+
27.90626529790461,
|
| 45 |
+
28.22222373681143,
|
| 46 |
+
27.82774563319981,
|
| 47 |
+
27.953424714971334,
|
| 48 |
+
28.111765013076365,
|
| 49 |
+
27.881029167678207,
|
| 50 |
+
27.89013520721346,
|
| 51 |
+
27.974640298169106,
|
| 52 |
+
28.015457140281796,
|
| 53 |
+
28.10881925234571,
|
| 54 |
+
27.839136187918484,
|
| 55 |
+
27.980245498009026,
|
| 56 |
+
27.941782257054,
|
| 57 |
+
27.858478610869497,
|
| 58 |
+
18.717311061918736,
|
| 59 |
+
11.598434458952397,
|
| 60 |
+
11.579710375983268,
|
| 61 |
+
11.584534626919776,
|
| 62 |
+
11.59768055099994,
|
| 63 |
+
11.199870459735394,
|
| 64 |
+
11.196961861103773,
|
| 65 |
+
11.191290821880102,
|
| 66 |
+
11.286846159026027,
|
| 67 |
+
11.2286821231246,
|
| 68 |
+
11.219772449228913,
|
| 69 |
+
11.226453838404268,
|
| 70 |
+
11.211608635727316
|
| 71 |
+
],
|
| 72 |
+
"warmup_time": 115.76587492786348,
|
| 73 |
+
"peak_memory_mb": 33810.9892578125,
|
| 74 |
+
"allocated_memory_mb": 33196.46240234375,
|
| 75 |
+
"reserved_memory_mb": 34506.0,
|
| 76 |
+
"total_images": 50,
|
| 77 |
+
"avg_inference_time": 23.434121787122898,
|
| 78 |
+
"std_inference_time": 7.309697334208333,
|
| 79 |
+
"throughput": 0.04267281740208,
|
| 80 |
+
"memory_efficiency": 96.20489886496189
|
| 81 |
+
},
|
| 82 |
+
"0": {
|
| 83 |
+
"inference_times": [
|
| 84 |
+
26.802028878591955,
|
| 85 |
+
26.797191261779517,
|
| 86 |
+
26.82992110401392,
|
| 87 |
+
26.958114746958017,
|
| 88 |
+
26.761777761392295,
|
| 89 |
+
26.84792256169021,
|
| 90 |
+
26.746575728990138,
|
| 91 |
+
27.076817566063255,
|
| 92 |
+
26.943395509850234,
|
| 93 |
+
26.85071571683511,
|
| 94 |
+
26.841394792776555,
|
| 95 |
+
27.631777914240956,
|
| 96 |
+
26.833319212775677,
|
| 97 |
+
26.806214389856905,
|
| 98 |
+
26.8738283761777,
|
| 99 |
+
26.895515635609627,
|
| 100 |
+
26.940260547678918,
|
| 101 |
+
27.09848231682554,
|
| 102 |
+
26.858372538816184,
|
| 103 |
+
26.918541864026338,
|
| 104 |
+
26.94270030176267,
|
| 105 |
+
26.95117418281734,
|
| 106 |
+
26.760037765838206,
|
| 107 |
+
26.82909763790667,
|
| 108 |
+
26.831056244205683,
|
| 109 |
+
26.92349842423573,
|
| 110 |
+
26.80812106281519,
|
| 111 |
+
26.730416806880385,
|
| 112 |
+
27.080423870123923,
|
| 113 |
+
27.055579679086804,
|
| 114 |
+
27.27255716919899,
|
| 115 |
+
27.117452350910753,
|
| 116 |
+
26.751979127991945,
|
| 117 |
+
26.876946676988155,
|
| 118 |
+
26.70633071102202,
|
| 119 |
+
26.0279257488437,
|
| 120 |
+
25.012685468886048,
|
| 121 |
+
11.098074495792389,
|
| 122 |
+
11.139604906085879,
|
| 123 |
+
11.128950875252485,
|
| 124 |
+
11.235403429716825,
|
| 125 |
+
11.146971406880766,
|
| 126 |
+
11.146004664245993,
|
| 127 |
+
11.10717493435368,
|
| 128 |
+
11.114083136897534,
|
| 129 |
+
11.114445879124105
|
| 130 |
+
],
|
| 131 |
+
"warmup_time": 111.6877530622296,
|
| 132 |
+
"peak_memory_mb": 33811.7392578125,
|
| 133 |
+
"allocated_memory_mb": 33197.21240234375,
|
| 134 |
+
"reserved_memory_mb": 34506.0,
|
| 135 |
+
"total_images": 50,
|
| 136 |
+
"avg_inference_time": 23.76567098658289,
|
| 137 |
+
"std_inference_time": 6.237739828068353,
|
| 138 |
+
"throughput": 0.042077499118983785,
|
| 139 |
+
"memory_efficiency": 96.20707239999928
|
| 140 |
+
}
|
| 141 |
+
}
|
| 142 |
+
}
|
diffusion-dpo-test/results-test/DIV2K-val-epoch10/0000263-seed-0.png
ADDED
|
diffusion-dpo-test/results-test/DIV2K-val-epoch10/0000463-seed-0.png
ADDED
|
diffusion-dpo-test/results-test/DIV2K-val-epoch10/0000563-seed-0.png
ADDED
|
diffusion-dpo-test/results-test/DIV2K-val-epoch10/0000763-seed-0.png
ADDED
|
diffusion-dpo-test/results-test/DIV2K-val-epoch10/0000863-seed-0.png
ADDED
|
diffusion-dpo-test/results-test/DrealSR/sony_160_x4.png
ADDED
|
diffusion-dpo-test/results-test/DrealSR/sony_189_x4.png
ADDED
|
diffusion-dpo-test/src/flux/__pycache__/block.cpython-310.pyc
ADDED
|
Binary file (5.96 kB). View file
|
|
|
diffusion-dpo-test/src/flux/__pycache__/block.cpython-311.pyc
ADDED
|
Binary file (14.6 kB). View file
|
|
|
diffusion-dpo-test/src/flux/__pycache__/condition.cpython-310.pyc
ADDED
|
Binary file (3.32 kB). View file
|
|
|
diffusion-dpo-test/src/flux/__pycache__/condition.cpython-311.pyc
ADDED
|
Binary file (5.83 kB). View file
|
|
|
diffusion-dpo-test/src/flux/__pycache__/generate.cpython-310.pyc
ADDED
|
Binary file (5.79 kB). View file
|
|
|
diffusion-dpo-test/src/flux/__pycache__/generate.cpython-311.pyc
ADDED
|
Binary file (11.1 kB). View file
|
|
|
diffusion-dpo-test/src/flux/__pycache__/lora_controller.cpython-310.pyc
ADDED
|
Binary file (3 kB). View file
|
|
|
diffusion-dpo-test/src/flux/__pycache__/lora_controller.cpython-311.pyc
ADDED
|
Binary file (5.13 kB). View file
|
|
|
diffusion-dpo-test/src/flux/__pycache__/pipeline_tools.cpython-310.pyc
ADDED
|
Binary file (1.42 kB). View file
|
|
|
diffusion-dpo-test/src/flux/__pycache__/pipeline_tools.cpython-311.pyc
ADDED
|
Binary file (2.56 kB). View file
|
|
|
diffusion-dpo-test/src/flux/__pycache__/transformer.cpython-310.pyc
ADDED
|
Binary file (4.26 kB). View file
|
|
|