chenchaoyun
committed on
Commit
·
8a25edd
1
Parent(s):
d11ff01
fix
Browse files- api_routes.py +7 -7
- test/celebrity_crawler.py +227 -0
- test/dow_img.py +24 -0
- test/howcuteami.py +202 -0
- test/import_history_images.py +162 -0
- test/test_deepface.py +38 -0
- test/test_main.http +11 -0
- test/test_rvm_infer.py +46 -0
- test/test_score.py +26 -0
- test/test_score_adjustment_demo.py +30 -0
- test/test_sky.py +15 -0
api_routes.py
CHANGED
|
@@ -731,7 +731,7 @@ async def _refresh_celebrity_cache(sample_image_path: str,
|
|
| 731 |
img_path=sample_image_path,
|
| 732 |
db_path=db_path,
|
| 733 |
model_name="ArcFace",
|
| 734 |
-
detector_backend="
|
| 735 |
distance_metric="cosine",
|
| 736 |
enforce_detection=True,
|
| 737 |
silent=True,
|
|
@@ -748,7 +748,7 @@ async def _refresh_celebrity_cache(sample_image_path: str,
|
|
| 748 |
img_path=sample_image_path,
|
| 749 |
db_path=db_path,
|
| 750 |
model_name="ArcFace",
|
| 751 |
-
detector_backend="
|
| 752 |
distance_metric="cosine",
|
| 753 |
enforce_detection=True,
|
| 754 |
silent=True,
|
|
@@ -768,7 +768,7 @@ async def _refresh_celebrity_cache(sample_image_path: str,
|
|
| 768 |
img_path=sample_image_path,
|
| 769 |
db_path=db_path,
|
| 770 |
model_name="ArcFace",
|
| 771 |
-
detector_backend="
|
| 772 |
distance_metric="cosine",
|
| 773 |
enforce_detection=True,
|
| 774 |
silent=True,
|
|
@@ -3905,7 +3905,7 @@ async def match_celebrity_face(
|
|
| 3905 |
img_path=temp_path,
|
| 3906 |
db_path=db_path,
|
| 3907 |
model_name="ArcFace",
|
| 3908 |
-
detector_backend="
|
| 3909 |
distance_metric="cosine",
|
| 3910 |
enforce_detection=True,
|
| 3911 |
silent=True,
|
|
@@ -4224,7 +4224,7 @@ async def face_similarity_verification(
|
|
| 4224 |
img1_path=original_path1,
|
| 4225 |
img2_path=original_path2,
|
| 4226 |
model_name="ArcFace",
|
| 4227 |
-
detector_backend="
|
| 4228 |
distance_metric="cosine"
|
| 4229 |
)
|
| 4230 |
logger.info(
|
|
@@ -4240,7 +4240,7 @@ async def face_similarity_verification(
|
|
| 4240 |
img1_path=original_path1,
|
| 4241 |
img2_path=original_path2,
|
| 4242 |
model_name="ArcFace",
|
| 4243 |
-
detector_backend="
|
| 4244 |
distance_metric="cosine"
|
| 4245 |
)
|
| 4246 |
logger.info(
|
|
@@ -4262,7 +4262,7 @@ async def face_similarity_verification(
|
|
| 4262 |
img1_path=original_path1,
|
| 4263 |
img2_path=original_path2,
|
| 4264 |
model_name="ArcFace",
|
| 4265 |
-
detector_backend="
|
| 4266 |
distance_metric="cosine"
|
| 4267 |
)
|
| 4268 |
logger.info(
|
|
|
|
| 731 |
img_path=sample_image_path,
|
| 732 |
db_path=db_path,
|
| 733 |
model_name="ArcFace",
|
| 734 |
+
detector_backend="yolov11n",
|
| 735 |
distance_metric="cosine",
|
| 736 |
enforce_detection=True,
|
| 737 |
silent=True,
|
|
|
|
| 748 |
img_path=sample_image_path,
|
| 749 |
db_path=db_path,
|
| 750 |
model_name="ArcFace",
|
| 751 |
+
detector_backend="yolov11n",
|
| 752 |
distance_metric="cosine",
|
| 753 |
enforce_detection=True,
|
| 754 |
silent=True,
|
|
|
|
| 768 |
img_path=sample_image_path,
|
| 769 |
db_path=db_path,
|
| 770 |
model_name="ArcFace",
|
| 771 |
+
detector_backend="yolov11n",
|
| 772 |
distance_metric="cosine",
|
| 773 |
enforce_detection=True,
|
| 774 |
silent=True,
|
|
|
|
| 3905 |
img_path=temp_path,
|
| 3906 |
db_path=db_path,
|
| 3907 |
model_name="ArcFace",
|
| 3908 |
+
detector_backend="yolov11n",
|
| 3909 |
distance_metric="cosine",
|
| 3910 |
enforce_detection=True,
|
| 3911 |
silent=True,
|
|
|
|
| 4224 |
img1_path=original_path1,
|
| 4225 |
img2_path=original_path2,
|
| 4226 |
model_name="ArcFace",
|
| 4227 |
+
detector_backend="yolov11n",
|
| 4228 |
distance_metric="cosine"
|
| 4229 |
)
|
| 4230 |
logger.info(
|
|
|
|
| 4240 |
img1_path=original_path1,
|
| 4241 |
img2_path=original_path2,
|
| 4242 |
model_name="ArcFace",
|
| 4243 |
+
detector_backend="yolov11n",
|
| 4244 |
distance_metric="cosine"
|
| 4245 |
)
|
| 4246 |
logger.info(
|
|
|
|
| 4262 |
img1_path=original_path1,
|
| 4263 |
img2_path=original_path2,
|
| 4264 |
model_name="ArcFace",
|
| 4265 |
+
detector_backend="yolov11n",
|
| 4266 |
distance_metric="cosine"
|
| 4267 |
)
|
| 4268 |
logger.info(
|
test/celebrity_crawler.py
ADDED
|
@@ -0,0 +1,227 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import time
|
| 2 |
+
from io import BytesIO
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
|
| 5 |
+
import requests
|
| 6 |
+
from PIL import Image
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class CelebrityCrawler:
    """Downloads celebrity photos from Baidu/Bing image search into a local folder."""

    def __init__(self, output_dir="celebrity_images"):
        self.output_dir = output_dir
        # Pretend to be a desktop browser so the search engines serve results.
        self.headers = {
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
        }
        Path(output_dir).mkdir(parents=True, exist_ok=True)

    def read_celebrities_from_txt(self, file_path):
        """Parse the roster file and return a list of {'name', 'profession'} dicts.

        Each line is either "name,profession" or just "name"; blank lines and
        lines starting with '#' are ignored.  A missing profession defaults
        to "明星".
        """
        roster = []
        with open(file_path, 'r', encoding='utf-8') as fh:
            for raw in fh:
                entry = raw.strip()
                if not entry or entry.startswith('#'):
                    continue
                pieces = entry.split(',')
                roster.append({
                    'name': pieces[0].strip(),
                    'profession': pieces[1].strip() if len(pieces) > 1 else "明星",
                })
        return roster

    def search_bing_images(self, celebrity_name, max_images=20):
        """Scrape Bing image search; return up to *max_images* image URLs."""
        search_url = "https://www.bing.com/images/search"
        params = {
            'q': celebrity_name + " 明星",
            'first': 0,
            'count': max_images,
        }
        try:
            resp = requests.get(search_url, params=params, headers=self.headers,
                                timeout=10)
            resp.raise_for_status()
            # Crude HTML scrape: media URLs appear as murl":"..." in the page.
            import re
            found = re.findall(r'murl":"(.*?)"', resp.text)
            return found[:max_images]
        except Exception as e:
            print(f"搜索 {celebrity_name} 时出错: {e}")
            return []

    def search_baidu_images(self, celebrity_name, max_images=20):
        """Query Baidu's JSON image-search endpoint; return thumbnail URLs."""
        search_url = "https://image.baidu.com/search/acjson"
        params = {
            'tn': 'resultjson_com',
            'word': celebrity_name + " 明星",
            'pn': 0,
            'rn': max_images,
            'ie': 'utf-8',
        }
        try:
            resp = requests.get(search_url, params=params, headers=self.headers,
                                timeout=10)
            resp.raise_for_status()
            payload = resp.json()
            found = [item['thumbURL'] for item in payload.get('data', [])
                     if 'thumbURL' in item]
            return found[:max_images]
        except Exception as e:
            print(f"搜索 {celebrity_name} 时出错: {e}")
            return []

    def download_image(self, url, save_path):
        """Fetch one image URL and store it as JPEG; return True on success.

        Images smaller than 100x100 pixels are rejected as thumbnails/icons.
        """
        try:
            resp = requests.get(url, headers=self.headers, timeout=15)
            resp.raise_for_status()

            # Decoding with Pillow doubles as validation of the payload.
            picture = Image.open(BytesIO(resp.content))
            width, height = picture.size
            if width < 100 or height < 100:
                return False

            picture.convert('RGB').save(save_path, 'JPEG', quality=95)
            return True
        except Exception as e:
            print(f" 下载失败: {str(e)[:50]}")
            return False

    def crawl_celebrity_images(self, celebrity, max_images=20,
                               search_engine='baidu'):
        """Download up to *max_images* photos for one celebrity; return the count."""
        name = celebrity['name']
        print(f"\n正在爬取: {name} ({celebrity['profession']})")

        # NOTE(review): all celebrities currently share the top-level output
        # directory; files are distinguished by the name prefix only.
        target_dir = Path(self.output_dir)
        target_dir.mkdir(parents=True, exist_ok=True)

        # Over-fetch URLs (2x) because some downloads will fail the size filter.
        if search_engine == 'baidu':
            candidate_urls = self.search_baidu_images(name, max_images * 2)
        else:
            candidate_urls = self.search_bing_images(name, max_images * 2)

        if not candidate_urls:
            print(f" 未找到 {name} 的图片")
            return 0

        print(f" 找到 {len(candidate_urls)} 个图片链接")

        downloaded = 0
        for position, url in enumerate(candidate_urls):
            if downloaded >= max_images:
                break

            save_path = target_dir / f"{name}_{position + 1:03d}.jpg"

            # Resuming a previous run: count existing files without re-downloading.
            if save_path.exists():
                downloaded += 1
                continue

            print(f" 下载 {position + 1}/{len(candidate_urls)}...", end=' ')
            if self.download_image(url, save_path):
                downloaded += 1
                print("✓")
            else:
                print("✗")

            # Throttle so we do not hammer the image hosts.
            time.sleep(0.5)

        print(f" 成功下载 {downloaded} 张图片")
        return downloaded

    def crawl_all(self, txt_file, max_images_per_celebrity=20,
                  search_engine='baidu'):
        """Crawl every celebrity listed in *txt_file* and print a summary."""
        print("=" * 60)
        print("明星照片爬取工具")
        print("=" * 60)

        celebrities = self.read_celebrities_from_txt(txt_file)
        print(f"\n从 {txt_file} 读取到 {len(celebrities)} 位明星")

        total_images = 0
        failed_celebrities = []

        for i, celebrity in enumerate(celebrities, 1):
            print(f"\n[{i}/{len(celebrities)}]", end=' ')

            try:
                count = self.crawl_celebrity_images(
                    celebrity,
                    max_images=max_images_per_celebrity,
                    search_engine=search_engine,
                )
                total_images += count

                if count == 0:
                    failed_celebrities.append(celebrity['name'])

                # Pause every five celebrities to stay polite to the servers.
                if i % 5 == 0:
                    print(f"\n 已完成 {i}/{len(celebrities)}, 休息3秒...")
                    time.sleep(3)

            except Exception as e:
                print(f" 处理 {celebrity['name']} 时出错: {e}")
                failed_celebrities.append(celebrity['name'])

        print("\n" + "=" * 60)
        print("爬取完成!")
        print("=" * 60)
        print(f"总明星数: {len(celebrities)}")
        print(f"成功爬取: {len(celebrities) - len(failed_celebrities)}")
        print(f"失败数量: {len(failed_celebrities)}")
        print(f"总图片数: {total_images}")
        print(f"保存位置: {self.output_dir}")

        if failed_celebrities:
            print(f"\n失败的明星: {', '.join(failed_celebrities)}")
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
# Example usage: build a crawler and pull photos for every celebrity listed
# in a roster text file.  Roster format, one entry per line, e.g.:
#   周杰伦,歌手
#   刘德华,演员
#   范冰冰,演员
if __name__ == "__main__":
    crawler = CelebrityCrawler(output_dir="celebrity_dataset")

    crawler.crawl_all(
        txt_file="celebrity_real_names.txt",  # path to your roster file
        max_images_per_celebrity=1,           # photos to fetch per person
        search_engine='baidu'                 # 'baidu' or 'bing'
    )
|
test/dow_img.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os

import cv2

# Read the source image.  OpenCV does NOT expand "~" in paths — cv2.imread
# would silently return None — so resolve the home directory explicitly.
input_path = os.path.expanduser("~/Pictures/header.png")
img = cv2.imread(input_path)
if img is None:
    raise FileNotFoundError(f"无法读取图片: {input_path}")

# WebP compression quality (0-100; lower = smaller file, worse quality).
quality = 50

# Write the compressed image (the .webp extension selects the WebP encoder).
output_path = os.path.expanduser("~/Pictures/output_small.webp")
cv2.imwrite(
    output_path,
    img,
    [int(cv2.IMWRITE_WEBP_QUALITY), quality],
)


# # Optional variant: downscale to half size before compressing harder.
# # resized = cv2.resize(img, (img.shape[1] // 2, img.shape[0] // 2))
# # cv2.imwrite(output_path, resized, [int(cv2.IMWRITE_WEBP_QUALITY), 40])
|
test/howcuteami.py
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
|
| 2 |
+
import math
|
| 3 |
+
import argparse
|
| 4 |
+
import numpy as np
|
| 5 |
+
import os
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
# detect face
|
| 9 |
+
# Run the OpenCV DNN face detector over *frame* and return square-ified
# bounding boxes [x1, y1, x2, y2] for detections above *conf_threshold*.
def highlightFace(net, frame, conf_threshold=0.95):
    canvas = frame.copy()
    frame_h = canvas.shape[0]
    frame_w = canvas.shape[1]
    # 300x300 BGR blob with the detector's training-time mean subtraction.
    blob = cv2.dnn.blobFromImage(
        canvas, 1.0, (300, 300), [104, 117, 123], True, False
    )

    net.setInput(blob)
    detections = net.forward()

    boxes = []
    for k in range(detections.shape[2]):
        score = detections[0, 0, k, 2]
        if score <= conf_threshold:
            continue
        # Detector outputs are normalized [0, 1]; map back to pixel coords.
        left = int(detections[0, 0, k, 3] * frame_w)
        top = int(detections[0, 0, k, 4] * frame_h)
        right = int(detections[0, 0, k, 5] * frame_w)
        bottom = int(detections[0, 0, k, 6] * frame_h)
        boxes.append(scale([left, top, right, bottom]))

    return boxes
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
# scale current rectangle to box
|
| 34 |
+
# Expand a rectangle [x1, y1, x2, y2] symmetrically into a square around its
# center.  The shorter side is padded on both ends; note coordinates can go
# negative for boxes near the image border.
def scale(box):
    box_w = box[2] - box[0]
    box_h = box[3] - box[1]
    side = max(box_w, box_h)
    pad_x = int((side - box_w) / 2)
    pad_y = int((side - box_h) / 2)
    return [box[0] - pad_x, box[1] - pad_y, box[2] + pad_x, box[3] + pad_y]
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
# crop image
|
| 46 |
+
# Return the sub-image covered by box = [x1, y1, x2, y2].
# (NumPy slicing: this is a view into *image*, not a copy.)
def cropImage(image, box):
    x1, y1, x2, y2 = box
    return image[y1:y2, x1:x2]
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
# main
|
| 52 |
+
# ---- Command line ----
parser = argparse.ArgumentParser()
parser.add_argument("-i", "--image", type=str, required=False, help="input image")
args = parser.parse_args()

# ---- Output directory ----
output_dir = "../output"
if not os.path.exists(output_dir):
    os.makedirs(output_dir)

# ---- Model files for the OpenCV DNN pipelines ----
faceProto = "models/opencv_face_detector.pbtxt"
faceModel = "models/opencv_face_detector_uint8.pb"
ageProto = "models/age_googlenet.prototxt"
ageModel = "models/age_googlenet.caffemodel"
genderProto = "models/gender_googlenet.prototxt"
genderModel = "models/gender_googlenet.caffemodel"
beautyProto = "models/beauty_resnet.prototxt"
beautyModel = "models/beauty_resnet.caffemodel"

MODEL_MEAN_VALUES = (104, 117, 123)
# Age buckets emitted by the GoogLeNet age classifier, in output order.
ageList = [
    "(0-2)",
    "(4-6)",
    "(8-12)",
    "(15-20)",
    "(25-32)",
    "(38-43)",
    "(48-53)",
    "(60-100)",
]
genderList = ["Male", "Female"]

# Per-gender annotation colors in BGR order.
gender_colors = {
    "Male": (255, 165, 0),    # orange
    "Female": (255, 0, 255),  # magenta / fuchsia
}

faceNet = cv2.dnn.readNet(faceModel, faceProto)
ageNet = cv2.dnn.readNet(ageModel, ageProto)
genderNet = cv2.dnn.readNet(genderModel, genderProto)
beautyNet = cv2.dnn.readNet(beautyModel, beautyProto)

# ---- Load the input picture ----
image_path = args.image if args.image else "images/charlize.jpg"
frame = cv2.imread(image_path)

if frame is None:
    print(f"无法读取图片: {image_path}")
    exit()

faceBoxes = highlightFace(faceNet, frame)
if not faceBoxes:
    print("No face detected")
    exit()

print(f"检测到 {len(faceBoxes)} 张人脸")

for i, faceBox in enumerate(faceBoxes):
    # Crop the face and normalize to the classifiers' 224x224 input size.
    face = cropImage(frame, faceBox)
    face_resized = cv2.resize(face, (224, 224))

    # Gender classifier.
    blob = cv2.dnn.blobFromImage(
        face_resized, 1.0, (224, 224), MODEL_MEAN_VALUES, swapRB=False
    )
    genderNet.setInput(blob)
    genderPreds = genderNet.forward()
    gender = genderList[genderPreds[0].argmax()]
    print(f"Gender: {gender}")

    # Age classifier (reuses the same blob).
    ageNet.setInput(blob)
    agePreds = ageNet.forward()
    age = ageList[agePreds[0].argmax()]
    print(f"Age: {age[1:-1]} years")

    # Beauty regressor — note the 1/255 scale factor (inputs in [0, 1]).
    blob = cv2.dnn.blobFromImage(
        face_resized, 1.0 / 255, (224, 224), MODEL_MEAN_VALUES, swapRB=False
    )
    beautyNet.setInput(blob)
    beautyPreds = beautyNet.forward()
    beauty = round(2.0 * sum(beautyPreds[0]), 1)
    print(f"Beauty: {beauty}/10.0")

    color = gender_colors[gender]

    # Save the raw face crop.
    face_filename = f"{output_dir}/face_{i+1}.webp"
    cv2.imwrite(face_filename, face, [cv2.IMWRITE_WEBP_QUALITY, 95])
    print(f"人脸图片已保存: {face_filename}")

    # Save a copy annotated with gender / age / beauty.
    face_with_text = face.copy()
    cv2.putText(
        face_with_text, f"{gender}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2
    )
    cv2.putText(
        face_with_text,
        f"{age[1:-1]} years",
        (10, 60),
        cv2.FONT_HERSHEY_SIMPLEX,
        0.7,
        color,
        2,
    )
    cv2.putText(
        face_with_text,
        f"{beauty}/10.0",
        (10, 90),
        cv2.FONT_HERSHEY_SIMPLEX,
        0.7,
        color,
        2,
    )

    annotated_filename = f"{output_dir}/face_{i+1}_annotated.webp"
    cv2.imwrite(annotated_filename, face_with_text, [cv2.IMWRITE_WEBP_QUALITY, 95])
    print(f"标注人脸已保存: {annotated_filename}")

    # Draw the box and a summary caption onto the full frame.
    cv2.rectangle(
        frame,
        (faceBox[0], faceBox[1]),
        (faceBox[2], faceBox[3]),
        color,
        int(round(frame.shape[0] / 400)),
        8,
    )
    cv2.putText(
        frame,
        f"{gender}, {age}, {beauty}",
        (faceBox[0], faceBox[1] - 10),
        cv2.FONT_HERSHEY_SIMPLEX,
        1.25,
        color,
        2,
        cv2.LINE_AA,
    )

# Save the fully annotated frame, then show it until a key is pressed.
result_filename = f"{output_dir}/result_full.webp"
cv2.imwrite(result_filename, frame, [cv2.IMWRITE_WEBP_QUALITY, 95])
print(f"完整结果图片已保存: {result_filename}")

cv2.imshow("howbeautifulami", frame)
cv2.waitKey(0)
cv2.destroyAllWindows()
|
test/import_history_images.py
ADDED
|
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
导入历史图片文件到数据库的脚本
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import asyncio
|
| 7 |
+
import hashlib
|
| 8 |
+
import os
|
| 9 |
+
import sys
|
| 10 |
+
import time
|
| 11 |
+
from datetime import datetime
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
|
| 14 |
+
# 添加项目根目录到Python路径
|
| 15 |
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 16 |
+
|
| 17 |
+
from database import record_image_creation, fetch_records_by_paths
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def calculate_file_hash(file_path):
    """Return the hex MD5 digest of the file at *file_path*.

    Reads in 4 KiB chunks so large files never have to fit in memory.
    """
    digest = hashlib.md5()
    with open(file_path, "rb") as fh:
        while True:
            chunk = fh.read(4096)
            if not chunk:
                break
            digest.update(chunk)
    return digest.hexdigest()
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def infer_category_from_filename(filename):
    """Derive an image category from the "..._<type>.<ext>" naming convention.

    Anime-style outputs are matched by substring first; otherwise the token
    between the last underscore and the next dot is looked up in a fixed
    mapping.  Unknown or unparseable names fall back to 'other'.
    """
    lowered = filename.lower()

    # Anime stylization files embed the marker mid-name rather than as a suffix.
    if '_anime_style_' in lowered:
        return 'anime_style'

    underscore_at = lowered.rfind('_')
    dot_at = lowered.find('.', underscore_at)

    # No "_<type>.<ext>" suffix structure → uncategorizable.
    if underscore_at == -1 or dot_at == -1 or underscore_at >= dot_at:
        return 'other'

    suffix_token = lowered[underscore_at + 1:dot_at]

    known_types = {
        'restore': 'restore',
        'upcolor': 'upcolor',
        'grayscale': 'grayscale',
        'upscale': 'upscale',
        'compress': 'compress',
        'id_photo': 'id_photo',
        'grid': 'grid',
        'rvm': 'rvm',
        'celebrity': 'celebrity',
        'face': 'face',
        'original': 'original'
    }

    return known_types.get(suffix_token, 'other')
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
async def import_history_images(source_dir, nickname="system_import"):
    """Scan *source_dir* for image files and register each one in the database.

    Files whose bare name already exists in the DB are skipped; everything
    else is recorded under *nickname* with a category inferred from the
    filename.  Per-file failures are logged and skipped, never fatal.
    """
    source_path = Path(source_dir)

    if not source_path.exists():
        print(f"错误: 目录 {source_dir} 不存在")
        return

    # Recognized image extensions, matched in both lower and upper case.
    image_extensions = {'.jpg', '.jpeg', '.png', '.webp', '.bmp', '.gif', '.tiff',
                        '.tif'}

    image_files = []
    for ext in image_extensions:
        image_files.extend(source_path.glob(f"*{ext}"))
        image_files.extend(source_path.glob(f"*{ext.upper()}"))

    print(f"找到 {len(image_files)} 个图片文件")

    imported_count = 0
    skipped_count = 0

    for image_path in image_files:
        try:
            file_name = image_path.name

            # De-duplicate against the DB by bare filename.
            records = await fetch_records_by_paths([file_name])
            if file_name in records:
                print(f"跳过已存在的文件: {file_name}")
                skipped_count += 1
                continue

            # Hash is kept in metadata so true duplicates can be found later.
            file_hash = calculate_file_hash(str(image_path))

            category = infer_category_from_filename(file_name)

            await record_image_creation(
                file_path=file_name,  # store the bare filename, not the full path
                nickname=nickname,
                category=category,
                bos_uploaded=False,  # historical files were never pushed to BOS
                score=0.0,           # historical files start unscored
                extra_metadata={
                    "source": "history_import",
                    "original_path": str(image_path),
                    "file_hash": file_hash,
                    "import_time": datetime.now().isoformat()
                }
            )

            imported_count += 1
            print(f"成功导入: {file_name} (类别: {category})")

        except Exception as e:
            print(f"导入文件失败 {image_path.name}: {str(e)}")
            continue

    print(f"\n导入完成!")
    print(f"成功导入: {imported_count} 个文件")
    print(f"跳过: {skipped_count} 个文件")
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
async def main():
    """CLI entry point: validate arguments, run the import, and time it."""
    if len(sys.argv) < 2:
        print("用法: python import_history_images.py <图片目录路径> [昵称]")
        print(
            "示例: python import_history_images.py /Users/chenchaoyun/app/data/images")
        print(
            "示例: python import_history_images.py /Users/chenchaoyun/app/data/images \"历史导入\"")
        sys.exit(1)

    source_directory = sys.argv[1]
    # Optional second argument overrides the default importer nickname.
    nickname = sys.argv[2] if len(sys.argv) > 2 else "system_import"

    print(f"开始导入图片文件...")
    print(f"源目录: {source_directory}")
    print(f"用户昵称: {nickname}")
    print("-" * 50)

    started = time.time()
    await import_history_images(source_directory, nickname)
    finished = time.time()

    print(f"\n总耗时: {finished - started:.2f} 秒")


if __name__ == "__main__":
    asyncio.run(main())
|
test/test_deepface.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
import time
from deepface import DeepFace

images_path = "/Users/chenchaoyun/Pictures/face"

# ========== 2. Face verification: similarity of two images ==========
start_time = time.time()
result_verification = DeepFace.verify(
    img1_path=images_path + "/4.webp",
    img2_path=images_path + "/5.webp",
    model_name="ArcFace",          # embedding model
    detector_backend="yolov11n",   # face detector: retinaface / yolov8 / opencv / ssd / mediapipe
    distance_metric="cosine"       # similarity metric
)
end_time = time.time()
print(f"🕒 人脸比对耗时: {end_time - start_time:.3f} 秒")

# Dump the verification result as pretty JSON.
print(json.dumps(result_verification, ensure_ascii=False, indent=2))


# ========== 1. Face recognition: search the folder as a database ==========
start_time = time.time()
result_recognition = DeepFace.find(
    img_path=images_path + "/1.jpg",   # probe face
    db_path=images_path,               # database directory
    model_name="ArcFace",              # embedding model
    detector_backend="yolov11n",       # face detector
    distance_metric="cosine"           # similarity metric
)
end_time = time.time()
print(f"🕒 人脸识别耗时: {end_time - start_time:.3f} 秒")

# Uncomment to dump the match table:
# df = result_recognition[0]
# print(df.to_json(orient="records", force_ascii=False))
|
test/test_main.http
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Test your FastAPI endpoints
|
| 2 |
+
|
| 3 |
+
GET http://127.0.0.1:8000/
|
| 4 |
+
Accept: application/json
|
| 5 |
+
|
| 6 |
+
###
|
| 7 |
+
|
| 8 |
+
GET http://127.0.0.1:8000/hello/User
|
| 9 |
+
Accept: application/json
|
| 10 |
+
|
| 11 |
+
###
|
test/test_rvm_infer.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
import time

import cv2
import numpy as np
import torch
from torchvision import transforms

device = "cpu"

# Input/output paths.  cv2.imread/imwrite do NOT expand "~" — imread would
# silently return None and the BGR->RGB slice below would crash — so resolve
# the home directory explicitly.
input_path = os.path.expanduser("~/Pictures/face/yang.webp")
output_path = os.path.expanduser("~/Pictures/face/output_alpha.webp")

# Load the pretrained RobustVideoMatting model (resnet50 backbone).
model = torch.hub.load("PeterL1n/RobustVideoMatting", "resnet50").to(device).eval()

# Start timing.
start = time.time()

# Read the image (BGR) and convert to RGB.
bgr = cv2.imread(input_path)
if bgr is None:
    raise FileNotFoundError(f"无法读取图片: {input_path}")
img = bgr[:, :, ::-1].copy()
src = transforms.ToTensor()(img).unsqueeze(0).to(device)

# Single-frame inference; rec holds the (initially empty) recurrent state.
rec = [None] * 4
with torch.no_grad():
    fgr, pha, *rec = model(src, *rec, downsample_ratio=0.25)

# Tensors -> uint8 numpy arrays.
fgr = (fgr[0].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)  # (H,W,3)
pha = (pha[0, 0].cpu().numpy() * 255).astype(np.uint8)  # (H,W)

# Stack foreground + alpha into RGBA ...
rgba = np.dstack((fgr, pha))  # (H,W,4)

# ... then reorder to BGRA for OpenCV and save as WebP (with transparency).
cv2.imwrite(output_path, rgba[:, :, [2, 1, 0, 3]], [cv2.IMWRITE_WEBP_QUALITY, 100])

# Stop timing.
elapsed = time.time() - start

# Console summary.
print(f"✅ RVM 抠图完成 (透明背景)")
print(f" 输入文件: {input_path}")
print(f" 输出文件: {output_path}")
print(f" 耗时: {elapsed:.3f} 秒 (设备: {device})")
|
test/test_score.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import logging
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
from retinaface import RetinaFace
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def default_converter(o):
    """json.dumps ``default=`` hook that coerces NumPy values to plain Python.

    NumPy integers/floats become int/float, arrays become nested lists, and
    anything else unrecognized is stringified instead of raising TypeError.
    """
    conversions = (
        (np.integer, int),
        (np.floating, float),
        (np.ndarray, lambda arr: arr.tolist()),
    )
    for np_type, cast in conversions:
        if isinstance(o, np_type):
            return cast(o)
    return str(o)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
import os

# Logging setup.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# RetinaFace does not expand "~" in file paths, so resolve the home
# directory explicitly before handing the path over.
image_file = os.path.expanduser("~/Downloads/chounan.jpeg")
resp = RetinaFace.detect_faces(image_file)

logger.info(
    "search results: " + json.dumps(resp, ensure_ascii=False, default=default_converter)
)
|
test/test_score_adjustment_demo.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
def adjust_score(score, threshold, gamma):
    """Compress scores below *threshold* toward it by factor *gamma*.

    Scores at or above the threshold pass through unchanged; lower scores are
    mapped to ``threshold - gamma * (threshold - score)``, clamped to
    [0.0, 10.0] and rounded to one decimal place.
    """
    if score >= threshold:
        return score
    lifted = threshold - gamma * (threshold - score)
    clamped = max(0.0, min(10.0, lifted))
    return round(clamped, 1)
|
| 7 |
+
|
| 8 |
+
# Baseline parameters (T=9.0, γ=0.5).
default_threshold = 9.0
default_gamma = 0.5

# Candidate parameters 1 (T=9, γ=0.9).
new_threshold_1 = 9
new_gamma_1 = 0.9

# Candidate parameters 2 (T=9, γ=0.8).
new_threshold_2 = 9
new_gamma_2 = 0.8

# Header row.  (Fixed: the first column previously printed "y=" while the
# other columns printed "γ="; all columns now use the γ symbol.)
print(f"原始分\tT={default_threshold},γ={default_gamma}\tT={new_threshold_1},γ={new_gamma_1}\tT={new_threshold_2},γ={new_gamma_2}")
print("-----\t----------\t----------\t----------")

# Sweep scores from 1.0 to 10.0 in 0.1 steps and tabulate every variant.
for i in range(10, 101):
    score = i / 10.0
    default_adjusted = adjust_score(score, default_threshold, default_gamma)
    new_adjusted_1 = adjust_score(score, new_threshold_1, new_gamma_1)
    new_adjusted_2 = adjust_score(score, new_threshold_2, new_gamma_2)
    # Force one decimal place so the columns stay aligned.
    print(f"{score:.1f}\t\t\t{default_adjusted:.1f}\t\t\t\t\t{new_adjusted_1:.1f}\t\t\t\t\t{new_adjusted_2:.1f}")
|
test/test_sky.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os.path as osp

import cv2
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Sky-replacement pipeline: composites the sky from *sky_image* into
# *scene_image*.
image_skychange = pipeline(Tasks.image_skychange,
                           model='iic/cv_hrnetocr_skychange')
result = image_skychange(
    {'sky_image': '/Users/chenchaoyun/Downloads/sky_image.jpg',
     'scene_image': '/Users/chenchaoyun/Pictures/face/NXEo0zusSaNB2fa232c84898e92ff165e2dfee59cb54.jpg'})

# Write the composited image and report the ACTUAL destination.  (Fixed: the
# old message printed abspath("result.png") even though the file was written
# to ~/Downloads/result.png.)
output_file = '/Users/chenchaoyun/Downloads/result.png'
cv2.imwrite(output_file, result[OutputKeys.OUTPUT_IMG])
print(f'Output written to {osp.abspath(output_file)}')
|