Humphreykowl commited on
Commit
60c0c4a
·
verified ·
1 Parent(s): 934e40b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +119 -677
app.py CHANGED
@@ -1,723 +1,165 @@
1
- # app.py - 完整的单文件时尚AI应用
2
- # 无需复杂的模块导入,直接在 Hugging Face Spaces 运行
3
-
4
  import gradio as gr
5
- import torch
6
- import numpy as np
7
- from PIL import Image, ImageDraw, ImageFont
8
- import logging
9
- import time
10
- import random
11
- import os
12
- from datetime import datetime
13
- from typing import Dict, List, Optional
14
-
15
- # 设置日志
16
- logging.basicConfig(level=logging.INFO)
17
- logger = logging.getLogger(__name__)
18
 
19
- # 尝试导入AI模型库
20
- try:
21
- from transformers import (
22
- BlipProcessor, BlipForConditionalGeneration,
23
- CLIPProcessor, CLIPModel,
24
- pipeline
25
- )
26
- from diffusers import (
27
- StableDiffusionPipeline,
28
- ControlNetModel,
29
- StableDiffusionControlNetPipeline
30
- )
31
- from sklearn.cluster import KMeans
32
- import cv2
33
- MODELS_AVAILABLE = True
34
- logger.info("✅ 所有AI库导入成功")
35
- except ImportError as e:
36
- logger.warning(f"⚠️ 部分AI库未安装: {e}")
37
- MODELS_AVAILABLE = False
38
 
39
- # 尝试导入 Hugging Face Spaces GPU 支持
40
- try:
41
- import spaces
42
- SPACES_GPU_AVAILABLE = True
43
- logger.info("✅ Hugging Face Spaces GPU 支持可用")
44
- except ImportError:
45
- SPACES_GPU_AVAILABLE = False
46
- logger.info("ℹ️ 运行在非Spaces环境或无GPU支持")
47
 
48
- class FashionAnalyzer:
49
- """时尚分析引擎"""
50
-
51
- def __init__(self):
52
- self.style_keywords = {
53
- "商务正装": ["suit", "formal", "business", "office", "professional", "blazer", "西装", "正装"],
54
- "休闲风格": ["casual", "relaxed", "comfortable", "jeans", "t-shirt", "休闲", "日常"],
55
- "运动风格": ["sport", "athletic", "gym", "fitness", "training", "运动", "健身"],
56
- "时尚潮流": ["fashion", "trendy", "stylish", "modern", "chic", "时尚", "潮流"],
57
- "复古风格": ["vintage", "retro", "classic", "traditional", "复古", "经典"],
58
- "街头风格": ["street", "urban", "hip-hop", "edgy", "街头", "嘻哈"],
59
- "优雅风格": ["elegant", "sophisticated", "graceful", "classy", "优雅", "高贵"],
60
- "极简风格": ["minimalist", "clean", "simple", "basic", "极简", "简约"]
 
 
 
 
 
 
61
  }
62
 
63
- self.color_names = {
64
- "红色系": ["红色", "玫瑰红", "深红", "暗红", "鲜红"],
65
- "蓝色系": ["蓝色", "天蓝", "海军蓝", "宝石蓝", "钴蓝"],
66
- "绿色系": ["绿色", "翠绿", "森林绿", "橄榄绿", "苹果绿"],
67
- "黑白灰": ["黑色", "白色", "灰色", "象牙白", "珍珠白", "炭黑"],
68
- "暖色调": ["橙色", "黄色", "粉色", "金黄", "柠檬黄"],
69
- "冷色调": ["紫色", "青色", "薄荷绿", "紫罗兰", "青绿"]
70
  }
71
-
72
- def extract_colors(self, image: Image.Image, n_colors: int = 5) -> List[Dict]:
73
- """提取图像主要颜色"""
74
- try:
75
- # 调整图像大小以提高处理速度
76
- image_small = image.resize((100, 100))
77
- img_array = np.array(image_small)
78
- pixels = img_array.reshape(-1, 3)
79
-
80
- # 过滤极端值
81
- mask = np.all(pixels > 20, axis=1) & np.all(pixels < 235, axis=1)
82
- filtered_pixels = pixels[mask]
83
-
84
- if len(filtered_pixels) < 50:
85
- filtered_pixels = pixels
86
-
87
- # 使用 K-means 聚类
88
- if MODELS_AVAILABLE:
89
- kmeans = KMeans(n_clusters=min(n_colors, len(filtered_pixels)),
90
- random_state=42, n_init=10)
91
- kmeans.fit(filtered_pixels)
92
- colors = kmeans.cluster_centers_
93
- labels = kmeans.labels_
94
- else:
95
- # 简化版本:直接采样
96
- colors = []
97
- step = max(1, len(filtered_pixels) // n_colors)
98
- for i in range(0, len(filtered_pixels), step):
99
- if len(colors) < n_colors:
100
- colors.append(filtered_pixels[i])
101
- colors = np.array(colors)
102
- labels = np.zeros(len(colors))
103
-
104
- color_info = []
105
- for i, color in enumerate(colors):
106
- rgb = color.astype(int)
107
- color_name = self.rgb_to_color_name(rgb)
108
- hex_color = '#{:02x}{:02x}{:02x}'.format(*rgb)
109
-
110
- if MODELS_AVAILABLE:
111
- percentage = np.sum(labels == i) / len(labels) * 100
112
- else:
113
- percentage = 100.0 / len(colors)
114
-
115
- color_info.append({
116
- "name": color_name,
117
- "rgb": rgb.tolist(),
118
- "hex": hex_color,
119
- "percentage": round(percentage, 1)
120
- })
121
-
122
- return sorted(color_info, key=lambda x: x["percentage"], reverse=True)
123
-
124
- except Exception as e:
125
- logger.error(f"颜色提取失败: {e}")
126
- return [{"name": "未知颜色", "rgb": [128, 128, 128], "hex": "#808080", "percentage": 100.0}]
127
-
128
- def rgb_to_color_name(self, rgb: np.ndarray) -> str:
129
- """RGB转颜色名称"""
130
- r, g, b = rgb
131
-
132
- if r > 200 and g > 200 and b > 200:
133
- return "象牙白" if min(r, g, b) > 240 else "珍珠白"
134
- elif r < 50 and g < 50 and b < 50:
135
- return "墨黑" if max(r, g, b) < 30 else "炭黑"
136
- elif r > max(g, b) + 30:
137
- return "玫瑰红" if r > 180 else "深红"
138
- elif g > max(r, b) + 30:
139
- return "翠绿" if g > 180 else "森林绿"
140
- elif b > max(r, g) + 30:
141
- return "天蓝" if b > 180 else "海军蓝"
142
- elif r > 150 and g > 150 and b < 100:
143
- return "金黄" if r > 200 else "暖黄"
144
- elif r > 120 and g < 100 and b > 120:
145
- return "紫罗兰" if r > 150 else "深紫"
146
- else:
147
- return "混合色"
148
-
149
- def analyze_style(self, description: str) -> Dict[str, float]:
150
- """分析时尚风格"""
151
- description_lower = description.lower()
152
- style_scores = {}
153
-
154
- for style, keywords in self.style_keywords.items():
155
- score = sum(1 for keyword in keywords if keyword in description_lower)
156
- if score > 0:
157
- confidence = min(score / len(keywords) * 100, 100)
158
- style_scores[style] = round(confidence, 1)
159
-
160
- # 如果没有匹配到任何风格,给一个默认分析
161
- if not style_scores:
162
- style_scores = {"休闲风格": 60.0, "时尚潮流": 40.0}
163
-
164
- return dict(sorted(style_scores.items(), key=lambda x: x[1], reverse=True))
165
-
166
- class ModelManager:
167
- """AI模型管理器"""
168
-
169
- def __init__(self):
170
- self.models_loaded = False
171
- self.blip_processor = None
172
- self.blip_model = None
173
- self.clip_processor = None
174
- self.clip_model = None
175
- self.sd_pipeline = None
176
- self.controlnet_pipeline = None
177
-
178
- if MODELS_AVAILABLE:
179
- self.load_models()
180
-
181
- def load_models(self):
182
- """加载AI模型"""
183
- try:
184
- logger.info("🔄 开始加载AI模型...")
185
-
186
- # 1. 尝试加载 BLIP 图像理解模型
187
- try:
188
- logger.info("加载 BLIP 图像理解模型...")
189
- self.blip_processor = BlipProcessor.from_pretrained(
190
- "Salesforce/blip-image-captioning-base"
191
- )
192
- self.blip_model = BlipForConditionalGeneration.from_pretrained(
193
- "Salesforce/blip-image-captioning-base",
194
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
195
- )
196
-
197
- if torch.cuda.is_available():
198
- self.blip_model = self.blip_model.to("cuda")
199
-
200
- logger.info("✅ BLIP模型加载成功")
201
- except Exception as e:
202
- logger.warning(f"BLIP模型加载失败: {e}")
203
- # 尝试轻量级替代
204
- try:
205
- self.blip_model = pipeline("image-to-text",
206
- model="nlpconnect/vit-gpt2-image-captioning")
207
- logger.info("✅ 轻量级图像理解模型加载成功")
208
- except Exception as e2:
209
- logger.error(f"所有图像理解模型加载失败: {e2}")
210
-
211
- # 2. 尝试加载 Stable Diffusion
212
- try:
213
- logger.info("加载 Stable Diffusion 设计生成模型...")
214
- self.sd_pipeline = StableDiffusionPipeline.from_pretrained(
215
- "runwayml/stable-diffusion-v1-5",
216
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
217
- safety_checker=None,
218
- requires_safety_checker=False
219
- )
220
-
221
- if torch.cuda.is_available():
222
- self.sd_pipeline = self.sd_pipeline.to("cuda")
223
-
224
- self.sd_pipeline.enable_attention_slicing()
225
- logger.info("✅ Stable Diffusion模型加载成功")
226
- except Exception as e:
227
- logger.warning(f"Stable Diffusion加载失败: {e}")
228
-
229
- self.models_loaded = True
230
- logger.info("🎉 模型加载完成")
231
-
232
- except Exception as e:
233
- logger.error(f"模型加载过程出错: {e}")
234
- self.models_loaded = False
235
-
236
- def generate_description(self, image: Image.Image) -> str:
237
- """生成图像描述"""
238
- if not self.models_loaded:
239
- return "AI模型未就绪,使用基础分析"
240
 
241
- try:
242
- # 使用 BLIP 模型
243
- if self.blip_processor and self.blip_model and hasattr(self.blip_model, 'generate'):
244
- inputs = self.blip_processor(image, return_tensors="pt")
245
- if torch.cuda.is_available() and next(self.blip_model.parameters()).is_cuda:
246
- inputs = {k: v.to("cuda") for k, v in inputs.items()}
247
-
248
- with torch.no_grad():
249
- generated_ids = self.blip_model.generate(
250
- **inputs,
251
- max_length=50,
252
- num_beams=3,
253
- do_sample=True,
254
- temperature=0.7
255
- )
256
-
257
- description = self.blip_processor.decode(generated_ids[0], skip_special_tokens=True)
258
- return description
259
-
260
- # 使用 pipeline 模型
261
- elif hasattr(self.blip_model, '__call__'):
262
- result = self.blip_model(image)
263
- if isinstance(result, list) and len(result) > 0:
264
- return result[0].get('generated_text', '时尚服装图像')
265
-
266
- return "时尚服装分析 - 基础模式"
267
-
268
- except Exception as e:
269
- logger.error(f"描述生成失败: {e}")
270
- return f"图像分析完成 - {str(e)[:50]}"
271
-
272
- def generate_fashion_design(self, prompt: str) -> Optional[Image.Image]:
273
- """生成时尚设计"""
274
- if not self.sd_pipeline:
275
- return self.create_placeholder_image("设计生成功能不可用")
276
 
277
- try:
278
- enhanced_prompt = f"high quality fashion design, {prompt}, professional photography, detailed, 4k"
279
- negative_prompt = "blurry, low quality, distorted, text, watermark, deformed, ugly"
280
-
281
- with torch.autocast("cuda" if torch.cuda.is_available() else "cpu"):
282
- result = self.sd_pipeline(
283
- prompt=enhanced_prompt,
284
- negative_prompt=negative_prompt,
285
- num_inference_steps=20,
286
- guidance_scale=7.5,
287
- width=512,
288
- height=512
289
- )
290
-
291
- return result.images[0]
292
-
293
- except Exception as e:
294
- logger.error(f"设计生成失败: {e}")
295
- return self.create_placeholder_image(f"生成失败: {str(e)[:30]}")
296
-
297
- def create_placeholder_image(self, text: str) -> Image.Image:
298
- """创建占位图"""
299
- img = Image.new('RGB', (512, 512), color=(245, 245, 250))
300
- draw = ImageDraw.Draw(img)
301
 
302
- # 计算文本位置
303
- text_lines = text.split('\n')
304
- y_start = (512 - len(text_lines) * 30) // 2
305
-
306
- for i, line in enumerate(text_lines):
307
- text_width = len(line) * 8
308
- x = (512 - text_width) // 2
309
- y = y_start + i * 35
310
- draw.text((x, y), line, fill=(100, 100, 100))
311
-
312
- return img
313
-
314
- def cleanup_memory(self):
315
- """清理内存"""
316
- try:
317
- if torch.cuda.is_available():
318
- torch.cuda.empty_cache()
319
-
320
- import gc
321
- gc.collect()
322
-
323
- return "✅ 内存清理完成"
324
- except Exception as e:
325
- return f"❌ 内存清理失败: {e}"
326
-
327
- # 全局实例
328
- fashion_analyzer = FashionAnalyzer()
329
- model_manager = ModelManager()
330
 
331
- # 主要功能函数
332
- def analyze_fashion_image(image: Image.Image) -> Dict:
333
- """分析时尚图像"""
334
- if image is None:
335
- return {"error": "请上传图像"}
336
-
337
  try:
338
- logger.info("开始时尚图像分析...")
339
- start_time = time.time()
340
-
341
- # 1. 生成图像描述
342
- description = model_manager.generate_description(image)
343
-
344
- # 2. 色彩分析
345
- colors = fashion_analyzer.extract_colors(image, n_colors=5)
346
-
347
- # 3. 风格分析
348
- styles = fashion_analyzer.analyze_style(description)
349
- primary_style = list(styles.keys())[0] if styles else "现代时尚"
350
-
351
- # 4. 综合分析
352
- analysis_time = round(time.time() - start_time, 2)
353
-
354
- result = {
355
- "image_description": description,
356
- "color_analysis": colors,
357
- "style_analysis": styles,
358
- "primary_style": primary_style,
359
- "main_colors": [c["name"] for c in colors[:3]],
360
- "analysis_time": f"{analysis_time}秒",
361
- "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
362
- "ai_model_status": "✅ 已连接" if model_manager.models_loaded else "⚠️ 基础模式"
363
  }
364
 
365
- logger.info(f"分析完成,耗时 {analysis_time}秒")
366
- return result
367
-
368
- except Exception as e:
369
- logger.error(f"分析失败: {e}")
370
- return {"error": f"分析过程出错: {str(e)}"}
371
-
372
- def generate_design_suggestions(analysis_result: Dict) -> Dict[str, str]:
373
- """生成设计建议"""
374
- if "error" in analysis_result:
375
- return {"基础建议": "请先完成图像分析"}
376
-
377
- primary_style = analysis_result.get("primary_style", "现代时尚")
378
- main_colors = analysis_result.get("main_colors", ["经典色"])
379
-
380
- suggestions = {
381
- f"优化{primary_style}": f"保持{primary_style}特色,突出{main_colors[0]}主色调",
382
- f"色彩增强版": f"基于{primary_style}风格,强化{main_colors[0]}和{main_colors[1] if len(main_colors) > 1 else '经典色'}搭配",
383
- f"现代融合": f"将{primary_style}与当代设计元素结合",
384
- f"个性定制": f"专属{primary_style}风格的个性化设计",
385
- f"场景适配": f"适合多种场合的{primary_style}变化"
386
- }
387
-
388
- return suggestions
389
-
390
- def generate_designs(suggestion: str, analysis_result: Dict, progress=gr.Progress()) -> List[Image.Image]:
391
- """生成设计方案"""
392
- if not suggestion or "error" in analysis_result:
393
- return [model_manager.create_placeholder_image("请先选择设计建议")]
394
-
395
- try:
396
- primary_style = analysis_result.get("primary_style", "现代时尚")
397
- main_colors = analysis_result.get("main_colors", ["经典色"])
398
 
399
- designs = []
400
- design_prompts = [
401
- f"{primary_style} style fashion, {main_colors[0]} color, elegant design",
402
- f"modern {primary_style} clothing, {main_colors[0]} and {main_colors[1] if len(main_colors) > 1 else 'neutral'} colors",
403
- f"professional {primary_style} outfit, premium materials, {main_colors[0]} accent",
404
- f"contemporary {primary_style} fashion, artistic design, {main_colors[0]} theme"
405
- ]
406
 
407
- for i, prompt in enumerate(design_prompts):
408
- if progress:
409
- progress(i / len(design_prompts), f"生成设计方案 {i+1}/{len(design_prompts)}")
410
-
411
- design_image = model_manager.generate_fashion_design(prompt)
412
- designs.append(design_image)
 
 
 
 
 
 
413
 
414
- if progress:
415
- progress(1.0, "设计生成完成")
 
416
 
417
- return designs
418
 
419
  except Exception as e:
420
- logger.error(f"设计生成失败: {e}")
421
- return [model_manager.create_placeholder_image(f"生成失败: {str(e)[:30]}")]
422
 
423
- def create_3d_fitting(selected_design: str) -> Image.Image:
424
- """创建3D试穿效果"""
425
- if not selected_design:
426
- return model_manager.create_placeholder_image("请先选择设计方案")
427
-
428
  try:
429
- # 创建3D试穿提示
430
- fitting_prompt = f"3D virtual fashion model wearing {selected_design}, full body view, professional studio lighting, photorealistic"
431
 
432
- # 生成3D试穿图像
433
- fitting_image = model_manager.generate_fashion_design(fitting_prompt)
434
 
435
- return fitting_image if fitting_image else model_manager.create_placeholder_image("3D试穿生成完成")
 
 
 
 
 
 
 
436
 
437
  except Exception as e:
438
- logger.error(f"3D试穿生成失败: {e}")
439
- return model_manager.create_placeholder_image("3D试穿生成失败")
440
 
441
- # Gradio 界面
442
- def create_fashion_interface():
443
- """创建时尚AI界面"""
444
-
445
- custom_css = """
446
- .gradio-container {
447
- max-width: 1200px;
448
- margin: 0 auto;
449
- font-family: 'Inter', sans-serif;
450
- }
451
- .header {
452
- text-align: center;
453
- background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
454
- color: white;
455
- padding: 20px;
456
- border-radius: 15px;
457
- margin-bottom: 20px;
458
- }
459
- .status-info {
460
- background: #f8f9fa;
461
- padding: 10px;
462
- border-radius: 8px;
463
- border-left: 4px solid #28a745;
464
- margin: 10px 0;
465
- }
466
- """
467
 
468
- with gr.Blocks(title="AI时尚设计师", theme=gr.themes.Soft(), css=custom_css) as demo:
469
-
470
- # 头部
471
- gr.HTML(f"""
472
- <div class="header">
473
- <h1>🎨 AI时尚设计师</h1>
474
- <p>智能图像分析 • 个性化设计建议 • AI生成时尚方案</p>
475
- <div class="status-info">
476
- <strong>系统状态:</strong>
477
- {'🟢 GPU可用' if torch.cuda.is_available() else '🟡 CPU模式'} |
478
- {'✅ AI模型就绪' if model_manager.models_loaded else '⚠️ 基础模式'} |
479
- {'🚀 Spaces GPU' if SPACES_GPU_AVAILABLE else '💻 标准环境'}
480
- </div>
481
- </div>
482
- """)
483
-
484
- # 主界面
485
- with gr.Row():
486
- # 左侧:图像上传和分析
487
- with gr.Column(scale=1):
488
- gr.Markdown("## 📸 图像分析")
489
-
490
- image_input = gr.Image(
491
- type="pil",
492
- label="上传时尚图片",
493
- height=300
494
- )
495
-
496
- analyze_btn = gr.Button(
497
- "🔍 AI智能分析",
498
- variant="primary",
499
- size="lg"
500
- )
501
-
502
- # 分析状态
503
- analysis_status = gr.Textbox(
504
- label="分析状态",
505
- value="等待图片上传...",
506
- interactive=False
507
- )
508
-
509
- # 右侧:分析结果
510
- with gr.Column(scale=2):
511
- gr.Markdown("## 📊 分析结果")
512
-
513
- with gr.Tabs():
514
- with gr.Tab("🔍 详细分析"):
515
- analysis_output = gr.JSON(label="AI分析报告")
516
-
517
- with gr.Tab("🎨 色彩分析"):
518
- color_gallery = gr.DataFrame(
519
- headers=["颜色名称", "RGB值", "十六进制", "占比%"],
520
- label="色彩详细信息"
521
- )
522
-
523
- # 设计建议部分
524
- gr.Markdown("## 💡 个性化设计建议")
525
 
526
- with gr.Row():
527
- suggestions_radio = gr.Radio(
528
- label="选择设计方向",
529
- interactive=True
530
- )
531
- generate_btn = gr.Button(
532
- "🚀 生成设计���案",
533
- variant="primary"
534
- )
535
-
536
- # 设计结果
537
- with gr.Tabs():
538
- with gr.Tab("🎯 设计方案"):
539
- designs_gallery = gr.Gallery(
540
- label="AI生成的设计方案",
541
- columns=2,
542
- rows=2,
543
- height=400
544
- )
545
-
546
- design_choice = gr.Radio(
547
- label="选择方案进行3D试穿",
548
- interactive=True
549
- )
550
-
551
- fitting_btn = gr.Button(
552
- "👤 生成3D试穿",
553
- variant="primary"
554
- )
555
-
556
- with gr.Tab("👥 3D试穿"):
557
- fitting_output = gr.Image(
558
- label="3D虚拟试穿效果",
559
- height=500
560
- )
561
 
562
- # 系统控制
563
- with gr.Accordion("⚙️ 系统控制", open=False):
564
- with gr.Row():
565
- cleanup_btn = gr.Button("🧹 清理内存")
566
- memory_status = gr.Textbox(label="内存状态", interactive=False)
567
 
568
- # 状态存储
569
- analysis_state = gr.State({})
 
 
570
 
571
- # 事件处理
572
- def process_analysis(image):
573
- if image is None:
574
- return {}, gr.Radio(choices=[]), [], "❌ 请先上传图片", {}
575
-
576
- # 执行分析
577
- result = analyze_fashion_image(image)
578
-
579
- if "error" in result:
580
- return result, gr.Radio(choices=[]), [], f"❌ {result['error']}", result
581
-
582
- # 生成建议
583
- suggestions = generate_design_suggestions(result)
584
- choices = list(suggestions.keys())
585
-
586
- # 准备色彩数据
587
- color_data = []
588
- for color in result.get("color_analysis", []):
589
- color_data.append([
590
- color["name"],
591
- str(color["rgb"]),
592
- color["hex"],
593
- f"{color['percentage']}%"
594
- ])
595
-
596
- status = f"✅ 分析完成 - 耗时 {result.get('analysis_time', '未知')}"
597
-
598
- return (
599
- result,
600
- gr.Radio(choices=choices, value=choices[0] if choices else None),
601
- color_data,
602
- status,
603
- result
604
- )
605
 
 
606
  analyze_btn.click(
607
- fn=process_analysis,
608
  inputs=[image_input],
609
- outputs=[
610
- analysis_output,
611
- suggestions_radio,
612
- color_gallery,
613
- analysis_status,
614
- analysis_state
615
- ]
616
  )
617
 
618
- # 设计生成
619
- def handle_design_generation(suggestion, analysis_result):
620
- if not suggestion or "error" in analysis_result:
621
- return [], gr.Radio(choices=[])
622
-
623
- designs = generate_designs(suggestion, analysis_result)
624
- choices = [f"{suggestion} - 方案{i+1}" for i in range(len(designs))]
625
-
626
- return designs, gr.Radio(choices=choices, value=choices[0] if choices else None)
627
-
628
- generate_btn.click(
629
- fn=handle_design_generation,
630
- inputs=[suggestions_radio, analysis_state],
631
  outputs=[designs_gallery, design_choice]
632
  )
633
 
634
- # 3D试穿
635
- fitting_btn.click(
636
- fn=create_3d_fitting,
637
  inputs=[design_choice],
638
- outputs=[fitting_output]
639
- )
640
-
641
- # 内存清理
642
- cleanup_btn.click(
643
- fn=model_manager.cleanup_memory,
644
- inputs=[],
645
- outputs=[memory_status]
646
  )
647
-
648
- # 底部信息
649
- gr.Markdown("""
650
- ---
651
- ### 🔧 技术说明
652
-
653
- **AI技术栈:**
654
- - 🔤 BLIP: 专业图像理解与描述生成
655
- - 🎨 Stable Diffusion: 高质量时尚设计生成
656
- - 🧮 K-means聚类: 智能色彩分析
657
- - 📊 风格识别: 多维度时尚风格评估
658
-
659
- **系统特点:**
660
- - ✅ 单文件部署,无模块导入错误
661
- - 🚀 自动GPU/CPU适配
662
- - 🛡️ 完善的错误处理机制
663
- - 📱 响应式用户界面
664
-
665
- > 💡 **使用提示**: 首次运行会下载AI模型,请耐心等待。生成过程可能需要1-3分钟。
666
- """)
667
 
668
  return demo
669
 
670
- # 主函数
671
- def main():
672
- """应用主入口"""
673
- try:
674
- logger.info("🚀 启动AI时尚设计师应用...")
675
- logger.info(f"PyTorch版本: {torch.__version__}")
676
- logger.info(f"CUDA可用: {torch.cuda.is_available()}")
677
- logger.info(f"AI模型库可用: {MODELS_AVAILABLE}")
678
-
679
- # 创建界面
680
- demo = create_fashion_interface()
681
-
682
- # 配置启动参数
683
- demo.queue(
684
- concurrency_count=2 if torch.cuda.is_available() else 1,
685
- max_size=10
686
- )
687
-
688
- # 启动应用
689
- demo.launch(
690
- server_name="0.0.0.0",
691
- server_port=int(os.environ.get("PORT", 7860)),
692
- share=False,
693
- show_error=True,
694
- debug=False
695
- )
696
-
697
- except Exception as e:
698
- logger.error(f"应用启动失败: {e}")
699
- print(f"启动错误: {e}")
700
-
701
- # 创建最小化的错误页面
702
- with gr.Blocks() as error_demo:
703
- gr.HTML(f"""
704
- <div style="text-align: center; padding: 50px;">
705
- <h1>❌ 应用启动失败</h1>
706
- <p><strong>错误信息:</strong> {str(e)}</p>
707
- <p><strong>PyTorch可用:</strong> {torch.__version__ if 'torch' in globals() else '未安装'}</p>
708
- <p><strong>CUDA可用:</strong> {torch.cuda.is_available() if 'torch' in globals() else '未知'}</p>
709
- <hr>
710
- <h2>🛠️ 故障排除建议:</h2>
711
- <ol style="text-align: left; max-width: 600px; margin: 0 auto;">
712
- <li>检查 requirements.txt 中的依赖是否正确安装</li>
713
- <li>确认 Hugging Face Spaces 环境配置</li>
714
- <li>检查 GPU 资源是否可用</li>
715
- <li>查看完整的错误日志</li>
716
- </ol>
717
- </div>
718
- """)
719
-
720
- error_demo.launch()
721
-
722
  if __name__ == "__main__":
723
- main()
 
 
1
+ # app.py (Gradio界面)
 
 
2
  import gradio as gr
3
+ from main import app
4
+ import requests
5
+ from PIL import Image
6
+ import json
 
 
 
 
 
 
 
 
 
7
 
8
+ # 导入模型管理器
9
+ from models.model_manager import ModelManager
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
+ # 初始化模型管理器
12
+ model_manager = ModelManager()
 
 
 
 
 
 
13
 
14
+ def upload_and_analyze(image_path):
15
+ """分析上传的图片"""
16
+ try:
17
+ if image_path is None:
18
+ return {}, {}, []
19
+
20
+ # 打开图片
21
+ image = Image.open(image_path)
22
+
23
+ # 生成图像描述
24
+ caption = model_manager.generate_caption(image)
25
+
26
+ # 模拟风格分析结果
27
+ analysis_result = {
28
+ "图像描述": caption,
29
+ "检测到的颜色": ["蓝色", "白色", "黑色"],
30
+ "风格类型": "休闲风",
31
+ "服装类别": "上衣",
32
+ "适合场景": ["日常", "休闲", "约会"]
33
  }
34
 
35
+ # 生成设计建议
36
+ suggestions = {
37
+ "建议1": "现代简约风格搭配",
38
+ "建议2": "复古经典款式",
39
+ "建议3": "运动休闲风格",
40
+ "建议4": "商务正装风格"
 
41
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
+ # 创建选择选项
44
+ choices = list(suggestions.keys())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
+ return analysis_result, suggestions, gr.Radio(choices=choices, value=choices[0] if choices else None)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
 
48
+ except Exception as e:
49
+ error_result = {"错误": f"分析失败: {str(e)}"}
50
+ return error_result, {}, []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
+ def generate_designs(selected_suggestion):
53
+ """根据选择的建议生成设计"""
 
 
 
 
54
  try:
55
+ if not selected_suggestion:
56
+ return [], gr.Radio(choices=[])
57
+
58
+ # 生成设计图像的提示词
59
+ design_prompts = {
60
+ "建议1": "modern minimalist clothing design, clean lines, neutral colors",
61
+ "建议2": "vintage classic fashion design, retro style, elegant",
62
+ "建议3": "sporty casual wear design, comfortable, athletic",
63
+ "建议4": "business formal attire, professional, sophisticated"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  }
65
 
66
+ prompt = design_prompts.get(selected_suggestion, "fashion design, stylish clothing")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
 
68
+ # 生成设计图像
69
+ design_images = []
70
+ design_choices = []
 
 
 
 
71
 
72
+ for i in range(3): # 生成3个设计
73
+ try:
74
+ image = model_manager.generate_image(
75
+ prompt=f"{prompt}, design {i+1}",
76
+ negative_prompt="blurry, low quality, distorted",
77
+ num_inference_steps=20
78
+ )
79
+ if image:
80
+ design_images.append(image)
81
+ design_choices.append(f"设计方案 {i+1}")
82
+ except Exception as e:
83
+ print(f"生成设计 {i+1} 失败: {e}")
84
 
85
+ # 如果没有成功生成图像,返回空结果
86
+ if not design_images:
87
+ return [], gr.Radio(choices=[])
88
 
89
+ return design_images, gr.Radio(choices=design_choices, value=design_choices[0] if design_choices else None)
90
 
91
  except Exception as e:
92
+ print(f"设计生成错误: {e}")
93
+ return [], gr.Radio(choices=[])
94
 
95
+ def generate_3d_fitting(selected_design):
96
+ """生成3D试穿效果"""
 
 
 
97
  try:
98
+ if not selected_design:
99
+ return None
100
 
101
+ # 生成3D试穿效果的提示词
102
+ fitting_prompt = f"3D fashion fitting, virtual try-on, {selected_design}, realistic human model"
103
 
104
+ # 使用模型生成3D试穿图像
105
+ fitting_image = model_manager.generate_image(
106
+ prompt=fitting_prompt,
107
+ negative_prompt="blurry, distorted, low quality, unrealistic",
108
+ num_inference_steps=25
109
+ )
110
+
111
+ return fitting_image
112
 
113
  except Exception as e:
114
+ print(f"3D试穿生成错误: {e}")
115
+ return None
116
 
117
+ def create_gradio_interface():
118
+ """创建Gradio用户界面"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
 
120
+ with gr.Blocks(title="AI时尚设计师") as demo:
121
+ gr.Markdown("# 🎨 AI时尚设计师")
122
+ gr.Markdown("上传图片,获得专业的服装设计建议和3D试穿效果")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
 
124
+ with gr.Tab("1. 图片上传与分析"):
125
+ image_input = gr.Image(type="filepath", label="上传参考图片")
126
+ analyze_btn = gr.Button("分析风格")
127
+ analysis_output = gr.JSON(label="风格分析结果")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
 
129
+ with gr.Tab("2. 设计建议"):
130
+ suggestions_output = gr.JSON(label="设计建议")
131
+ suggestion_choice = gr.Radio(label="选择设计建议")
132
+ generate_designs_btn = gr.Button("生成样衣设计")
 
133
 
134
+ with gr.Tab("3. 样衣设计"):
135
+ designs_gallery = gr.Gallery(label="样衣设计图")
136
+ design_choice = gr.Radio(label="选择设计")
137
+ generate_3d_btn = gr.Button("生成3D试穿")
138
 
139
+ with gr.Tab("4. 3D试穿"):
140
+ fitting_result = gr.Image(label="3D试穿效果")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
 
142
+ # 事件绑定
143
  analyze_btn.click(
144
+ fn=upload_and_analyze,
145
  inputs=[image_input],
146
+ outputs=[analysis_output, suggestions_output, suggestion_choice]
 
 
 
 
 
 
147
  )
148
 
149
+ generate_designs_btn.click(
150
+ fn=generate_designs,
151
+ inputs=[suggestion_choice],
 
 
 
 
 
 
 
 
 
 
152
  outputs=[designs_gallery, design_choice]
153
  )
154
 
155
+ generate_3d_btn.click(
156
+ fn=generate_3d_fitting,
 
157
  inputs=[design_choice],
158
+ outputs=[fitting_result]
 
 
 
 
 
 
 
159
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
160
 
161
  return demo
162
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
  if __name__ == "__main__":
164
+ demo = create_gradio_interface()
165
+ demo.launch(server_name="0.0.0.0", server_port=7860, share=True)