# Exported from a Hugging Face Space (author: Humphreykowl, commit acec099).
# app.py (Fixed AI Fashion Designer for Hugging Face Spaces)
import gradio as gr
import requests
from PIL import Image, ImageDraw, ImageFont
import numpy as np
from sklearn.cluster import KMeans
import time
import random
import os
import torch
import logging
from typing import Dict, List, Tuple, Optional
import json
from datetime import datetime
# Hugging Face specific imports
from transformers import pipeline, BlipProcessor, BlipForConditionalGeneration
from diffusers import StableDiffusionPipeline, ControlNetModel, StableDiffusionControlNetPipeline
import cv2
# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class FashionAnalyzer:
    """Fashion analysis engine: dominant-color extraction, color naming,
    season matching and style-confidence scoring.

    All user-facing labels are Chinese strings; downstream code keys on them,
    so they must stay byte-for-byte stable.
    """

    def __init__(self):
        # Style label -> keyword list; keywords are substring-matched against
        # the lowercased caption in analyze_style_confidence_from_image().
        self.style_keywords = {
            "商务正装": ["suit", "formal", "business", "office", "professional", "tie", "blazer", "西装", "正装", "商务", "dress shirt", "formal wear"],
            "休闲风": ["casual", "relaxed", "comfortable", "everyday", "jeans", "t-shirt", "休闲", "日常", "weekend", "laid-back"],
            "运动风": ["sport", "athletic", "gym", "fitness", "running", "training", "运动", "健身", "activewear", "yoga"],
            "时尚潮流": ["fashion", "trendy", "stylish", "modern", "chic", "designer", "时尚", "潮流", "avant-garde", "runway"],
            "复古风": ["vintage", "retro", "classic", "traditional", "old-fashioned", "复古", "经典", "throwback", "nostalgic"],
            "街头风": ["street", "urban", "hip-hop", "cool", "edgy", "街头", "嘻哈", "streetwear", "grunge"],
            "优雅风": ["elegant", "sophisticated", "graceful", "refined", "classy", "优雅", "高贵", "glamorous", "luxurious"],
            "波西米亚风": ["bohemian", "boho", "free-spirited", "artistic", "flowing", "ethnic", "波西米亚", "民族风"],
            "极简风": ["minimalist", "clean", "simple", "basic", "understated", "极简", "简约", "nordic"]
        }
        # Season label -> representative color names, consumed by get_season_match().
        self.color_palette = {
            "春季色彩": ["粉色", "嫩绿", "天蓝", "柠檬黄", "薰衣草紫"],
            "夏季色彩": ["白色", "海军蓝", "珊瑚色", "薄荷绿", "阳光橙"],
            "秋季色彩": ["棕色", "橙色", "深红", "金黄", "橄榄绿"],
            "冬季色彩": ["黑色", "深灰", "酒红", "深蓝", "银色"]
        }

    def extract_advanced_colors(self, image: "Image.Image", n_colors: int = 5) -> List[Dict]:
        """Extract up to ``n_colors`` dominant colors via K-means clustering.

        Near-black and near-white pixels are filtered out first so garment
        colors are not drowned out by background/shadow. Each returned dict
        has: name, rgb, hex, brightness (0-255 mean), saturation (0-1) and a
        season tag from get_season_match().
        """
        try:
            # Small fixed size keeps K-means fast; convert("RGB") guards
            # against RGBA/grayscale inputs that would break the (N, 3) reshape.
            small = image.convert("RGB").resize((150, 150))
            pixels = np.array(small).reshape(-1, 3)
            # Keep mid-tone pixels only; fall back to all pixels when the
            # mask leaves too few samples (e.g. a mostly black/white image).
            mask = np.all(pixels > 30, axis=1) & np.all(pixels < 225, axis=1)
            filtered = pixels[mask]
            if len(filtered) < 50:
                filtered = pixels
            n_colors = min(n_colors, len(filtered))
            kmeans = KMeans(n_clusters=n_colors, random_state=42, n_init=10)
            kmeans.fit(filtered)
            colors_info = []
            for center in kmeans.cluster_centers_:
                rgb = center.astype(int)
                name = self.rgb_to_advanced_color_name(rgb)
                colors_info.append({
                    "name": name,
                    "rgb": rgb.tolist(),
                    "hex": '#{:02x}{:02x}{:02x}'.format(*rgb),
                    "brightness": round(np.mean(rgb), 2),
                    "saturation": round((np.max(rgb) - np.min(rgb)) / 255.0, 2),
                    "season_match": self.get_season_match(name)
                })
            return colors_info
        except Exception as e:
            logger.error(f"Color extraction failed: {e}")
            # Fix: fallback now mirrors the success schema — it includes
            # "season_match" and uses the same 0-255 brightness scale.
            return [{"name": "未知颜色", "rgb": [128, 128, 128], "hex": "#808080",
                     "brightness": 128.0, "saturation": 0.0, "season_match": "四季通用"}]

    def rgb_to_advanced_color_name(self, rgb: np.ndarray) -> str:
        """Map an RGB triple to a Chinese color name via threshold rules.

        Possible return values include 象牙白/米白色, 墨黑/炭灰, the red/green/
        blue families, yellows/purples/cyans, the gray ladder, and 混合色.
        """
        r, g, b = rgb
        if r > 200 and g > 200 and b > 200:
            return "象牙白" if min(r, g, b) > 240 else "米白色"
        elif r < 50 and g < 50 and b < 50:
            return "墨黑" if max(r, g, b) < 30 else "炭灰"
        elif r > max(g, b) + 30:
            # Red-dominant family, brightest to darkest.
            if r > 180:
                return "鲜红" if g < 100 and b < 100 else "珊瑚红"
            elif r > 120:
                return "深红" if g < 80 and b < 80 else "玫瑰红"
            else:
                return "暗红"
        elif g > max(r, b) + 30:
            # Green-dominant family.
            if g > 180:
                return "翠绿" if r < 100 and b < 100 else "苹果绿"
            elif g > 120:
                return "森林绿" if r < 80 and b < 80 else "橄榄绿"
            else:
                return "深绿"
        elif b > max(r, g) + 30:
            # Blue-dominant family.
            if b > 180:
                return "天蓝" if r < 100 and g < 150 else "钴蓝"
            elif b > 120:
                return "海军蓝" if r < 80 and g < 80 else "宝蓝"
            else:
                return "深蓝"
        elif r > 150 and g > 150 and b < 100:
            return "柠檬黄" if r > 200 and g > 200 else "金黄"
        elif r > 120 and g < 100 and b > 120:
            return "紫罗兰" if r > 150 and b > 150 else "深紫"
        elif g > 120 and b > 120 and r < 100:
            return "青绿" if g > 150 and b > 150 else "青色"
        elif abs(r - g) < 30 and abs(g - b) < 30:
            # Near-achromatic: pick a gray by brightness.
            if r > 150:
                return "浅灰"
            elif r > 100:
                return "中灰"
            else:
                return "深灰"
        else:
            return "混合色"

    def get_season_match(self, color_name: str) -> str:
        """Return the first season whose palette names match ``color_name``
        by substring, or "四季通用" when none do."""
        for season, season_colors in self.color_palette.items():
            if any(season_color in color_name for season_color in season_colors):
                return season
        return "四季通用"

    def analyze_style_confidence_from_image(self, image: "Image.Image", caption: str) -> Dict[str, float]:
        """Score style labels (0-100) from caption keywords plus color cues.

        Returns styles sorted by descending confidence; may be empty when
        neither the caption nor the palette produced a signal.
        """
        caption_lower = caption.lower()
        style_scores = {}
        # 1) Caption keywords: percentage of a style's keyword list present.
        for style, keywords in self.style_keywords.items():
            hits = sum(1 for keyword in keywords if keyword in caption_lower)
            confidence = min(hits / len(keywords) * 100, 100)
            if confidence > 0:
                style_scores[style] = confidence
        # 2) Color cues extracted from the image itself.
        colors_info = self.extract_advanced_colors(image, n_colors=3)
        names = {color["name"] for color in colors_info}
        # Bug fix: the neutral set now uses names rgb_to_advanced_color_name
        # actually produces (墨黑/炭灰/象牙白/米白色, ...); the old literals
        # "黑色"/"白色" could never match, so the bonus rarely fired.
        business_neutrals = {"墨黑", "炭灰", "深灰", "海军蓝", "深蓝", "象牙白", "米白色"}
        if names & business_neutrals:
            style_scores["商务正装"] = style_scores.get("商务正装", 0) + 20
        # Casual cues: light/bright relaxed tones.
        if names & {"天蓝", "苹果绿", "米白色"}:
            style_scores["休闲风"] = style_scores.get("休闲风", 0) + 15
        # High average saturation reads as a fashion-forward palette.
        avg_saturation = np.mean([color["saturation"] for color in colors_info])
        if avg_saturation > 0.6:
            style_scores["时尚潮流"] = style_scores.get("时尚潮流", 0) + 25
        # Cap at 100 after the additive bonuses above.
        if style_scores and max(style_scores.values()) > 100:
            style_scores = {k: min(v, 100) for k, v in style_scores.items()}
        return dict(sorted(style_scores.items(), key=lambda x: x[1], reverse=True))
class ModelManager:
    """Loads and serves the heavy Hugging Face models.

    Three capabilities, each degrading gracefully to None / placeholder
    images when loading fails:
      * BLIP (or a vit-gpt2 pipeline fallback) for image captioning,
      * Stable Diffusion 1.5 for design generation,
      * ControlNet (openpose) + SD 1.5 for pose-guided "3D try-on".
    """

    def __init__(self):
        # All model handles start as None; _load_models() fills in whatever
        # it manages to download. _models_ready gates every generate_* call.
        self.caption_model = None        # BLIP model OR an image-to-text pipeline (fallback path)
        self.caption_processor = None    # BLIP processor; stays None when the pipeline fallback is used
        self.sd_pipeline = None          # text-to-image Stable Diffusion pipeline
        self.controlnet_pipeline = None  # pose-conditioned SD pipeline
        self._load_models()

    def _load_models(self):
        """Load AI models with Hugging Face integration.

        Each model is loaded in its own try/except so one failure does not
        prevent the others from loading; _models_ready is still set to True
        in that case and per-method None checks handle the gaps.
        """
        try:
            logger.info("Loading models from Hugging Face...")
            # 1. BLIP for image captioning.
            try:
                self.caption_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
                self.caption_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
                logger.info("✅ BLIP image captioning model loaded")
            except Exception as e:
                logger.warning(f"Failed to load BLIP model: {e}")
                # Fallback to a smaller captioning model for Spaces hardware.
                try:
                    self.caption_model = pipeline("image-to-text", model="nlpconnect/vit-gpt2-image-captioning")
                    logger.info("✅ Fallback image captioning model loaded")
                except Exception as e2:
                    logger.error(f"Failed to load fallback model: {e2}")
            # 2. Stable Diffusion for design generation.
            try:
                # fp16 on GPU, fp32 on CPU; safety checker disabled to save memory.
                model_id = "runwayml/stable-diffusion-v1-5"
                self.sd_pipeline = StableDiffusionPipeline.from_pretrained(
                    model_id,
                    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                    safety_checker=None,
                    requires_safety_checker=False
                )
                if torch.cuda.is_available():
                    self.sd_pipeline = self.sd_pipeline.to("cuda")
                # Attention slicing trades a little speed for much lower peak memory.
                self.sd_pipeline.enable_attention_slicing()
                logger.info("✅ Stable Diffusion pipeline loaded")
            except Exception as e:
                logger.warning(f"Failed to load SD pipeline: {e}")
            # 3. ControlNet (openpose) for pose-guided try-on rendering.
            try:
                controlnet = ControlNetModel.from_pretrained(
                    "lllyasviel/sd-controlnet-openpose",
                    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
                )
                self.controlnet_pipeline = StableDiffusionControlNetPipeline.from_pretrained(
                    "runwayml/stable-diffusion-v1-5",
                    controlnet=controlnet,
                    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                    safety_checker=None,
                    requires_safety_checker=False
                )
                if torch.cuda.is_available():
                    self.controlnet_pipeline = self.controlnet_pipeline.to("cuda")
                self.controlnet_pipeline.enable_attention_slicing()
                logger.info("✅ ControlNet pipeline loaded for enhanced 3D fitting")
            except Exception as e:
                logger.warning(f"Failed to load ControlNet: {e}")
            self._models_ready = True
            logger.info("All models initialized successfully")
        except Exception as e:
            logger.error(f"Model loading failed: {e}")
            self._models_ready = False

    def generate_caption(self, image: Image.Image) -> str:
        """Generate an image caption, preferring BLIP over the pipeline fallback.

        Returns a Chinese status string (not an exception) on any failure so
        the Gradio callbacks can display it directly.
        """
        try:
            if not self._models_ready:
                return "图像分析暂不可用 - 模型未加载"
            # Method 1: BLIP processor + model (primary path).
            if self.caption_processor and self.caption_model:
                inputs = self.caption_processor(image, return_tensors="pt")
                out = self.caption_model.generate(**inputs, max_length=50)
                caption = self.caption_processor.decode(out[0], skip_special_tokens=True)
                return caption
            # Method 2: image-to-text pipeline fallback (callable object).
            elif hasattr(self.caption_model, '__call__'):
                result = self.caption_model(image)
                if isinstance(result, list) and len(result) > 0:
                    return result[0].get('generated_text', 'Fashion image analysis')
            # Fallback: no captioning model available at all.
            return "时尚服装图片 - 需要进一步分析"
        except Exception as e:
            logger.error(f"Caption generation failed: {e}")
            return f"图像描述生成失败: {str(e)}"

    def generate_image(self, prompt: str, negative_prompt: str = "", **kwargs) -> Optional[Image.Image]:
        """Generate a design image with Stable Diffusion.

        kwargs may override num_inference_steps, guidance_scale, width and
        height. Returns a placeholder image (never raises) on failure.
        """
        try:
            if not self._models_ready or not self.sd_pipeline:
                return self._create_placeholder_image("Stable Diffusion未就绪")
            # Prompt engineering: wrap the caller's prompt in quality boosters.
            enhanced_prompt = f"high quality fashion design, {prompt}, professional photography, detailed, 4k"
            enhanced_negative = f"blurry, low quality, distorted, text, watermark, deformed, ugly, {negative_prompt}"
            # Random seed each call so repeated generations differ.
            with torch.autocast("cuda" if torch.cuda.is_available() else "cpu"):
                result = self.sd_pipeline(
                    prompt=enhanced_prompt,
                    negative_prompt=enhanced_negative,
                    num_inference_steps=kwargs.get('num_inference_steps', 20),  # reduced for Spaces speed
                    guidance_scale=kwargs.get('guidance_scale', 7.5),
                    width=kwargs.get('width', 512),
                    height=kwargs.get('height', 512),
                    generator=torch.Generator().manual_seed(random.randint(0, 2147483647))
                )
            return result.images[0]
        except Exception as e:
            logger.error(f"Image generation failed: {e}")
            return self._create_placeholder_image(f"生成失败: {str(e)[:30]}...")

    def generate_3d_fitting(self, prompt: str, pose_image: Optional[Image.Image] = None) -> Optional[Image.Image]:
        """Render a try-on image, pose-guided via ControlNet when possible.

        Falls back to plain SD with a "3D rendered" prompt when ControlNet or
        the pose image is unavailable; returns a placeholder on total failure.
        """
        try:
            if not self._models_ready:
                return self._create_placeholder_image("3D模型未就绪")
            # Preferred path: ControlNet conditioned on the pose sketch.
            if self.controlnet_pipeline and pose_image:
                enhanced_prompt = f"3D virtual fashion model wearing {prompt}, photorealistic, professional studio lighting, full body, fashion photography, detailed textures"
                with torch.autocast("cuda" if torch.cuda.is_available() else "cpu"):
                    result = self.controlnet_pipeline(
                        prompt=enhanced_prompt,
                        image=pose_image,
                        num_inference_steps=25,
                        guidance_scale=8.0,
                        controlnet_conditioning_scale=1.0
                    )
                return result.images[0]
            # Fallback: regular SD with a 3D-flavoured prompt, taller canvas
            # for a full-body composition.
            elif self.sd_pipeline:
                enhanced_3d_prompt = f"3D rendered fashion model, {prompt}, volumetric lighting, realistic human proportions, fashion photography, professional quality, detailed fabric textures, studio lighting"
                return self.generate_image(
                    prompt=enhanced_3d_prompt,
                    negative_prompt="flat, 2D, cartoon, anime, low quality, distorted proportions, bad anatomy",
                    num_inference_steps=30,
                    guidance_scale=8.5,
                    width=512,
                    height=768  # taller for full body
                )
            return self._create_placeholder_image("3D生成功能不可用")
        except Exception as e:
            logger.error(f"3D fitting generation failed: {e}")
            return self._create_placeholder_image("3D生成失败")

    def _create_placeholder_image(self, text: str) -> Image.Image:
        """Create a 512x512 light-gray image with ``text`` roughly centered.

        Uses a crude fixed character width (8 px) rather than font metrics —
        good enough for error placeholders.
        """
        img = Image.new('RGB', (512, 512), color=(240, 240, 245))
        draw = ImageDraw.Draw(img)
        # Center the text block vertically, then each line horizontally.
        text_lines = text.split('\n')
        total_height = len(text_lines) * 20
        start_y = (512 - total_height) // 2
        for i, line in enumerate(text_lines):
            text_width = len(line) * 8  # approximate character width
            x = (512 - text_width) // 2
            y = start_y + i * 25
            draw.text((x, y), line, fill=(100, 100, 100))
        return img

    def create_pose_reference(self, width: int = 512, height: int = 768) -> Image.Image:
        """Draw a minimal white stick-figure on black as a pose hint.

        NOTE(review): this is a plain skeleton sketch, not a real OpenPose
        keypoint map — confirm it conditions sd-controlnet-openpose usefully.
        """
        img = np.zeros((height, width, 3), dtype=np.uint8)
        # Head
        cv2.circle(img, (width//2, height//6), 30, (255, 255, 255), 2)
        # Torso
        cv2.line(img, (width//2, height//6 + 30), (width//2, height//2), (255, 255, 255), 3)
        # Arms
        cv2.line(img, (width//2, height//3), (width//2 - 60, height//2 - 20), (255, 255, 255), 2)
        cv2.line(img, (width//2, height//3), (width//2 + 60, height//2 - 20), (255, 255, 255), 2)
        # Legs
        cv2.line(img, (width//2, height//2), (width//2 - 30, height - 50), (255, 255, 255), 3)
        cv2.line(img, (width//2, height//2), (width//2 + 30, height - 50), (255, 255, 255), 3)
        return Image.fromarray(img)

    def cleanup(self):
        """Free CUDA cache and run the GC; returns a user-facing status string."""
        try:
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            # Force garbage collection of dropped tensors/intermediates.
            import gc
            gc.collect()
            logger.info("GPU memory and cache cleaned up")
            return "✅ 内存已清理"
        except Exception as e:
            logger.error(f"Cleanup failed: {e}")
            return f"❌ 清理失败: {str(e)}"
# Global singletons shared by all Gradio callbacks below.
# NOTE(review): ModelManager() runs _load_models() at import time, so model
# downloads happen on startup — intentional on Spaces so requests are warm.
fashion_analyzer = FashionAnalyzer()
model_manager = ModelManager()
def upload_and_analyze(image_path):
    """Run the full analysis pipeline on an uploaded image file.

    Returns a 3-tuple (analysis dict, suggestions dict, gr.Radio update)
    matching the Gradio outputs; error cases return the same shape with an
    "错误" entry so the UI never sees an exception.
    """
    if image_path is None:
        return {}, {}, gr.Radio(choices=[])
    try:
        logger.info(f"Analyzing image: {image_path}")
        # Load and normalise the image; a broken file yields a readable error.
        try:
            img = Image.open(image_path).convert('RGB')
        except Exception as e:
            return {"错误": f"无法打开图像文件: {str(e)}"}, {}, gr.Radio(choices=[])
        started = time.time()
        # Captioning, color extraction and combined style scoring.
        caption = model_manager.generate_caption(img)
        logger.info(f"Generated caption: {caption}")
        palette = fashion_analyzer.extract_advanced_colors(img)
        scores = fashion_analyzer.analyze_style_confidence_from_image(img, caption)
        top_style = next(iter(scores), "休闲风")
        # Derived recommendations.
        category = infer_clothing_category(caption)
        scenes = get_enhanced_suitable_scenes(top_style, palette)
        trends = analyze_fashion_trends(top_style, palette)
        elapsed = round(time.time() - started, 2)
        report = {
            "图像描述": caption,
            "主要颜色": [c["name"] for c in palette[:3]],
            "详细色彩分析": palette,
            "风格分析": scores,
            "主要风格": top_style,
            "服装类别": category,
            "适合场景": scenes,
            "流行趋势匹配": trends,
            "图像尺寸": f"{img.width} x {img.height}",
            "分析耗时": f"{elapsed}秒",
            "分析时间": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            "AI模型状态": "✅ 已连接" if model_manager._models_ready else "❌ 离线模式"
        }
        # Suggestions are derived from the report itself.
        suggestions = generate_enhanced_suggestions(report)
        options = list(suggestions.keys())
        logger.info(f"Analysis completed in {elapsed}s")
        preselected = options[0] if options else None
        return report, suggestions, gr.Radio(choices=options, value=preselected)
    except Exception as e:
        logger.error(f"Analysis failed: {e}")
        return {"错误": f"分析过程中出现错误: {str(e)}"}, {}, gr.Radio(choices=[])
def generate_designs(selected_suggestion, analysis_result, progress=gr.Progress()):
    """Generate four design images for the chosen suggestion.

    Prompt content is driven by the earlier analysis (style, colors,
    category). Failed generations are replaced by placeholder images so the
    gallery always shows four entries. Returns (images, gr.Radio update).
    """
    try:
        if not selected_suggestion:
            return [], gr.Radio(choices=[])
        logger.info(f"Generating designs for: {selected_suggestion}")
        progress(0.1, desc="分析设计需求...")
        # Pull analysis fields with the same defaults used elsewhere.
        source = analysis_result or {}
        style = source.get("主要风格", "休闲风")
        palette = source.get("主要颜色", ["蓝色"])
        category = source.get("服装类别", "时尚单品")
        prompts = create_analysis_based_prompts(selected_suggestion, style, palette, category)
        images = []
        labels = []
        total = 4
        for idx in range(total):
            label = f"{selected_suggestion} - 方案{idx+1}"
            try:
                progress(0.2 + (idx / total) * 0.7, desc=f"生成设计方案 {idx+1}/{total}...")
                rendered = model_manager.generate_image(
                    prompt=prompts[idx % len(prompts)],
                    negative_prompt="blurry, low quality, distorted, text, watermark, deformed, ugly",
                    width=512,
                    height=512,
                    num_inference_steps=25,  # optimized for Spaces
                    guidance_scale=7.5
                )
                if rendered:
                    images.append(rendered)
                    labels.append(label)
                    logger.info(f"Generated design {idx+1} based on {style} style")
            except Exception as e:
                logger.error(f"Failed to generate design {idx+1}: {e}")
                # Keep gallery/choice lists aligned with a placeholder entry.
                images.append(model_manager._create_placeholder_image(f"方案{idx+1}\n生成中..."))
                labels.append(label)
        progress(0.95, desc="完成设计生成")
        logger.info(f"Generated {len(images)} designs based on analysis results")
        return images, gr.Radio(choices=labels, value=labels[0] if labels else None)
    except Exception as e:
        logger.error(f"Design generation error: {e}")
        return [], gr.Radio(choices=[])
def create_analysis_based_prompts(suggestion: str, style: str, colors: List[str], category: str) -> List[str]:
    """Build four Stable Diffusion prompts from the analysis results.

    ``suggestion`` is accepted for interface compatibility with callers but
    does not influence the generated prompt text. At most the first two
    colors are used; an empty list falls back to "neutral colors".
    """
    color_desc = ", ".join(colors[:2]) if colors else "neutral colors"
    base = [
        f"professional fashion design, {style} style, {category}, featuring {color_desc}, high quality, detailed, studio photography",
        f"modern {category} design, {style} aesthetic, {color_desc} color palette, innovative cut, premium materials",
        f"elegant {category}, {style} inspiration, {color_desc} tones, contemporary fashion, artistic design",
        f"luxury {category} piece, {style} influence, {color_desc} color scheme, high-end fashion, detailed textures"
    ]
    # First matching marker wins — mirrors the original if/elif chain.
    style_suffixes = (
        ("商务", ", professional, office appropriate, tailored fit"),
        ("休闲", ", comfortable, everyday wear, relaxed fit"),
        ("运动", ", athletic, functional, performance fabric"),
        ("时尚", ", trendy, runway inspired, fashion forward"),
    )
    for marker, suffix in style_suffixes:
        if marker in style:
            return [p + suffix for p in base]
    return base
def generate_3d_fitting(selected_design, progress=gr.Progress()):
    """Render a 3D try-on image for the selected design label.

    Builds a stick-figure pose reference, then delegates to
    ModelManager.generate_3d_fitting. Returns None when nothing is selected,
    or a placeholder image on failure.
    """
    if not selected_design:
        return None
    try:
        logger.info(f"Generating enhanced 3D fitting for: {selected_design}")
        progress(0.1, desc="准备3D试穿环境...")
        pose_ref = model_manager.create_pose_reference()
        progress(0.3, desc="生成人体模型...")
        # The design label itself is embedded into the prompt.
        prompt = f"wearing {selected_design}, fashion model, full body view"
        progress(0.6, desc="应用服装设计...")
        rendered = model_manager.generate_3d_fitting(prompt=prompt, pose_image=pose_ref)
        progress(0.9, desc="完成3D渲染")
        logger.info("Enhanced 3D fitting generated successfully")
        return rendered
    except Exception as e:
        logger.error(f"3D fitting generation error: {e}")
        return model_manager._create_placeholder_image("3D试穿\n生成失败")
# Additional utility functions (keeping the existing ones but fixing issues)
def analyze_fashion_trends(style: str, colors_info: List[Dict]) -> Dict:
    """Map a primary style and color palette onto 2024 trend buckets.

    Only 商务* and 休闲* styles get curated trend/material entries; color
    trends are inferred from the first two palette entries by substring.
    """
    report = {
        "2024流行趋势": [],
        "颜色趋势": [],
        "材质趋势": [],
        "设计元素": []
    }
    # Style-driven trend and material buckets.
    if "商务" in style:
        report["2024流行趋势"] += ["可持续面料", "多功能设计", "性别中性"]
        report["材质趋势"] += ["有机棉", "再生纤维", "功能性面料"]
    elif "休闲" in style:
        report["2024流行趋势"] += ["舒适至上", "居家办公风", "运动休闲"]
        report["材质趋势"] += ["弹性面料", "透气材质", "抗菌纤维"]
    # Color-driven trends from the two most dominant colors.
    leading = [entry["name"] for entry in colors_info[:2]]
    if any("绿" in name for name in leading):
        report["颜色趋势"].append("生态绿色")
    if any("蓝" in name for name in leading):
        report["颜色趋势"].append("经典蓝调")
    # Design elements are constant across styles.
    report["设计元素"] = ["极简线条", "功能细节", "可调节设计", "层次搭配"]
    return report
def get_enhanced_suitable_scenes(style_type: str, colors_info: List[Dict]) -> List[str]:
    """Recommend wearing occasions for a style, enriched with seasonal scenes.

    Args:
        style_type: Primary style label, e.g. "商务正装".
        colors_info: Color dicts from FashionAnalyzer.extract_advanced_colors;
            only the "season_match" field (e.g. "夏季色彩") is read.

    Returns:
        De-duplicated scene names in a stable order.

    Bug fix: season tags are full labels like "春季色彩", so seasons are now
    matched by substring. The previous exact list-membership test
    (``"春季" in seasons``) could never be true, which made all seasonal
    scene additions unreachable.
    """
    base_scenes = {
        "商务正装": ["高级办公环境", "商务会议", "正式谈判", "企业活动", "专业面试"],
        "休闲风": ["咖啡厅约会", "周末购物", "朋友聚会", "公园散步", "居家办公"],
        "运动风": ["健身房训练", "户外跑步", "瑜伽课程", "运动赛事", "休闲运动"],
        "时尚潮流": ["时装周活动", "艺术展开幕", "网红打卡点", "时尚派对", "创意工作环境"],
        "复古风": ["文艺咖啡厅", "古典音乐会", "复古主题派对", "艺术博物馆", "文化活动"],
        "街头风": ["音乐节现场", "街头涂鸦区", "潮流市集", "滑板公园", "创意园区"],
        "优雅风": ["高端晚宴", "歌剧院", "五星酒店", "慈善晚会", "高级社交场合"]
    }
    # Copy so we never mutate the dict literal's value lists in place.
    scenes = list(base_scenes.get(style_type, ["日常生活", "社交活动", "休闲娱乐"]))
    season_tags = [color.get("season_match", "") for color in colors_info]

    def has_season(season: str) -> bool:
        # Substring match against tags such as "春季色彩" / "四季通用".
        return any(season in tag for tag in season_tags)

    # elif chain preserved: only the first matching season contributes.
    if has_season("春季"):
        scenes.extend(["春日踏青", "花园聚会"])
    elif has_season("夏季"):
        scenes.extend(["海滩度假", "夏日音乐节"])
    elif has_season("秋季"):
        scenes.extend(["秋日登山", "文艺展览"])
    elif has_season("冬季"):
        scenes.extend(["冬日聚会", "温暖室内活动"])
    # dict.fromkeys: deterministic, order-preserving de-duplication
    # (the old list(set(...)) returned arbitrary order).
    return list(dict.fromkeys(scenes))
def infer_clothing_category(caption: str) -> str:
    """Guess a Chinese clothing category from a caption via keyword counts.

    Keywords are substring-matched against the lowercased caption; the
    category with the most hits wins (ties go to the earlier category in the
    table below). With no hits at all, the generic "时尚单品" is returned.
    """
    text = caption.lower()
    categories = {
        "连衣裙": ["dress", "gown", "frock", "sundress", "连衣裙", "礼服", "长裙"],
        "上衣": ["shirt", "blouse", "top", "sweater", "hoodie", "cardigan", "衬衫", "上衣", "毛衣"],
        "外套": ["jacket", "coat", "blazer", "cardigan", "outerwear", "外套", "大衣", "夹克"],
        "下装": ["pants", "jeans", "trousers", "skirt", "shorts", "裤子", "短裤", "裙子"],
        "套装": ["suit", "ensemble", "matching set", "套装", "西装", "套服"],
        "配饰": ["accessories", "hat", "bag", "shoes", "jewelry", "scarf", "帽子", "包", "鞋子", "配饰"],
        "内衣": ["underwear", "lingerie", "bra", "内衣", "文胸"],
        "运动装": ["sportswear", "athletic wear", "gym clothes", "运动装", "健身服"],
        "睡衣": ["pajamas", "nightwear", "sleepwear", "睡衣", "家居服"]
    }
    # Track the best-scoring category; strict ">" keeps the first on ties,
    # matching the original max() over insertion order.
    best_category = "时尚单品"
    best_hits = 0
    for category, keywords in categories.items():
        hits = sum(1 for keyword in keywords if keyword in text)
        if hits > best_hits:
            best_category, best_hits = category, hits
    return best_category
def generate_enhanced_suggestions(analysis_result: Dict) -> Dict:
    """Build personalised design-direction suggestions from an analysis dict.

    Args:
        analysis_result: Output of upload_and_analyze; reads "主要风格",
            "主要颜色" and "流行趋势匹配" (all optional).

    Returns:
        Mapping of suggestion title -> one-line Chinese description.
    """
    primary_style = analysis_result.get("主要风格", "休闲风")
    # Robustness fix: a present-but-empty color list used to crash the
    # colors[0] f-strings below with IndexError; `or` applies the same
    # default in that case too.
    colors = analysis_result.get("主要颜色") or ["蓝色"]
    trend_analysis = analysis_result.get("流行趋势匹配", {})
    # Curated suggestion sets for the most common styles.
    style_variations = {
        "商务正装": {
            "现代商务精英": f"融合{primary_style}与现代设计,主打{colors[0]}色调",
            "轻奢商务": f"{primary_style}加入轻奢元素,提升品质感",
            "可持续商务": f"环保理念的{primary_style},使用可持续面料",
            "多功能商务": f"一衣多穿的{primary_style}设计"
        },
        "休闲风": {
            "都市休闲": f"现代都市感的{primary_style},主色调{colors[0]}",
            "舒适至上": f"极致舒适的{primary_style}体验",
            "运动休闲融合": f"{primary_style}与运动元素完美结合",
            "艺术休闲": f"加入艺术元素的{primary_style}"
        },
        "运动风": {
            "专业运动": f"专业级{primary_style},注重功能性",
            "时尚运动": f"{primary_style}融入时尚潮流元素",
            "户外探险": f"户外功能性{primary_style}设计",
            "瑜伽冥想": f"身心和谐的{primary_style}体验"
        }
    }
    if primary_style in style_variations:
        suggestions = dict(style_variations[primary_style])
    else:
        # Generic fallbacks for styles without a curated set.
        suggestions = {
            f"经典{primary_style}": f"保持{primary_style}的经典魅力",
            f"现代{primary_style}": f"{primary_style}与现代元素结合",
            f"个性{primary_style}": f"展现独特个性的{primary_style}",
            f"趋势{primary_style}": f"紧跟2024流行趋势的{primary_style}"
        }
    # Extra suggestion driven by the trend analysis, when present.
    if "可持续面料" in trend_analysis.get("2024流行趋势", []):
        suggestions["环保时尚"] = f"可持续发展理念的{primary_style}设计"
    return suggestions
def create_gradio_interface():
    """Create Gradio interface optimized for Hugging Face Spaces.

    Builds the full Blocks layout (upload/analysis panel, suggestion picker,
    design gallery, 3D try-on tab, system controls) and wires the callbacks
    defined at module level. Returns the un-launched Blocks app.
    """
    # Custom CSS optimized for Spaces (container width, header card, status box).
    custom_css = """
.gradio-container {
font-family: 'Inter', sans-serif;
max-width: 1200px;
margin: 0 auto;
}
.main-header {
text-align: center;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
padding: 20px;
border-radius: 15px;
margin-bottom: 20px;
box-shadow: 0 4px 15px rgba(0,0,0,0.1);
}
.status-box {
background: #f8f9fa;
border: 1px solid #dee2e6;
border-radius: 8px;
padding: 10px;
margin: 5px 0;
}
"""
    with gr.Blocks(
        title="AI时尚设计师 Pro - Hugging Face Spaces",
        theme=gr.themes.Soft(),
        css=custom_css
    ) as demo:
        # Header banner.
        gr.HTML("""
<div class="main-header">
<h1>🎨 AI时尚设计师 Pro</h1>
<p>基于 Hugging Face 的专业时尚设计平台</p>
<p>✨ BLIP图像理解 + Stable Diffusion设计生成 + ControlNet 3D试穿</p>
</div>
""")
        # System status for debugging — evaluated once at UI build time.
        with gr.Row():
            with gr.Column():
                system_info = gr.HTML(f"""
<div class="status-box">
<strong>系统状态:</strong>
CUDA可用: {torch.cuda.is_available()} |
设备: {'GPU' if torch.cuda.is_available() else 'CPU'} |
模型状态: {'✅ 就绪' if model_manager._models_ready else '⏳ 加载中'}
</div>
""")
        # Main interface: left = upload/analyze, right = analysis results.
        with gr.Row():
            # Left panel - Image upload and analysis
            with gr.Column(scale=1):
                gr.Markdown("## 📸 图片分析")
                image_input = gr.Image(
                    type="filepath",
                    label="上传时尚图片",
                    height=300
                )
                with gr.Row():
                    analyze_btn = gr.Button(
                        "🔍 AI智能分析",
                        variant="primary",
                        size="lg"
                    )
                # Analysis stats (filled by enhanced_analysis outputs).
                with gr.Row():
                    with gr.Column():
                        analysis_time = gr.Textbox(label="分析耗时", interactive=False, container=False)
                    with gr.Column():
                        model_status = gr.Textbox(label="AI状态", interactive=False, container=False)
            # Right panel - Results
            with gr.Column(scale=2):
                gr.Markdown("## 📊 分析结果")
                with gr.Tabs():
                    with gr.Tab("🔬 详细分析"):
                        analysis_output = gr.JSON(label="完整分析报告")
                    with gr.Tab("🎨 色彩分析"):
                        color_analysis = gr.DataFrame(
                            headers=["颜色名称", "RGB值", "十六进制", "亮度", "饱和度", "季节匹配"],
                            label="色彩详细信息"
                        )
        # Design suggestions section.
        gr.Markdown("## 💡 个性化设计建议")
        with gr.Row():
            with gr.Column(scale=2):
                suggestions_output = gr.JSON(label="基于AI分析的设计建议")
            with gr.Column(scale=1):
                suggestion_choice = gr.Radio(label="选择设计方向", interactive=True)
                generate_designs_btn = gr.Button("🚀 生成设计方案", variant="primary")
        # Design results: generated gallery + 3D try-on result.
        with gr.Tabs():
            with gr.Tab("🎯 设计方案"):
                designs_gallery = gr.Gallery(
                    label="AI生成的设计方案 (基于图像分析)",
                    columns=2,
                    rows=2,
                    height=400
                )
                design_choice = gr.Radio(label="选择方案进行3D试穿", interactive=True)
                generate_3d_btn = gr.Button("👤 生成3D试穿", variant="primary")
            with gr.Tab("👥 3D试穿效果"):
                fitting_result = gr.Image(label="ControlNet增强3D试穿效果", height=500)
                with gr.Row():
                    gr.Markdown("""
**3D试穿技术说明:**
- 使用 ControlNet + OpenPose 实现精确人体建模
- 基于图像分析结果生成逼真试穿效果
- 支持全身服装展示和细节呈现
""")
        # Performance controls for Spaces.
        with gr.Accordion("⚙️ 系统控制", open=False):
            with gr.Row():
                cleanup_btn = gr.Button("🧹 清理GPU内存", variant="secondary")
                reload_models_btn = gr.Button("🔄 重新加载模型", variant="secondary")
            memory_status = gr.Textbox(label="内存状态", interactive=False)
        # Examples for quick testing — only paths that actually exist.
        gr.Markdown("## 🌟 快速体验")
        examples = [
            ["examples/business_suit.jpg"] if os.path.exists("examples/business_suit.jpg") else None,
            ["examples/casual_wear.jpg"] if os.path.exists("examples/casual_wear.jpg") else None,
            ["examples/sport_outfit.jpg"] if os.path.exists("examples/sport_outfit.jpg") else None,
        ]
        examples = [ex for ex in examples if ex is not None]
        if examples:
            gr.Examples(
                examples=examples,
                inputs=image_input,
                label="点击体验示例"
            )
        # Hidden state for passing the analysis dict between callbacks.
        analysis_state = gr.State({})

        # Event handlers with proper data flow.
        def enhanced_analysis(image_path):
            """Wrap upload_and_analyze, adding the color table and stats outputs."""
            try:
                result = upload_and_analyze(image_path)
                analysis_result, suggestions, suggestion_radio = result
                # NOTE(review): assigning .value only changes the State's
                # default; the per-session update is the analysis_result
                # returned as the final output below — confirm this line is
                # needed at all.
                analysis_state.value = analysis_result
                # Prepare the color-analysis DataFrame rows.
                color_data = []
                if "详细色彩分析" in analysis_result:
                    for color_info in analysis_result["详细色彩分析"]:
                        color_data.append([
                            color_info.get("name", "未知"),
                            str(color_info.get("rgb", [0, 0, 0])),
                            color_info.get("hex", "#000000"),
                            f"{color_info.get('brightness', 0):.2f}",
                            f"{color_info.get('saturation', 0):.2f}",
                            color_info.get("season_match", "未知")
                        ])
                # Extract timing and status for the stat boxes.
                time_taken = analysis_result.get("分析耗时", "未知")
                ai_status = analysis_result.get("AI模型状态", "未知")
                return (
                    analysis_result,
                    suggestions,
                    suggestion_radio,
                    color_data,
                    time_taken,
                    ai_status,
                    analysis_result  # Update state
                )
            except Exception as e:
                logger.error(f"Enhanced analysis failed: {e}")
                error_result = {"错误": f"分析失败: {str(e)}"}
                return error_result, {}, gr.Radio(choices=[]), [], "错误", "❌ 分析失败", {}

        # Bind main analysis event.
        analyze_btn.click(
            fn=enhanced_analysis,
            inputs=[image_input],
            outputs=[
                analysis_output,
                suggestions_output,
                suggestion_choice,
                color_analysis,
                analysis_time,
                model_status,
                analysis_state
            ]
        )
        # Design generation with analysis integration.
        generate_designs_btn.click(
            fn=lambda suggestion, analysis: generate_designs(suggestion, analysis),
            inputs=[suggestion_choice, analysis_state],
            outputs=[designs_gallery, design_choice]
        )
        # 3D fitting generation (module-level generate_3d_fitting).
        generate_3d_btn.click(
            fn=generate_3d_fitting,
            inputs=[design_choice],
            outputs=[fitting_result]
        )
        # System controls.
        cleanup_btn.click(
            fn=model_manager.cleanup,
            inputs=[],
            outputs=[memory_status]
        )

        def reload_models():
            """Re-run __init__ on the existing singleton to reload all models."""
            try:
                model_manager.__init__()  # Reinitialize
                return "✅ 模型重新加载完成"
            except Exception as e:
                return f"❌ 重新加载失败: {str(e)}"

        reload_models_btn.click(
            fn=reload_models,
            inputs=[],
            outputs=[memory_status]
        )
        # Footer with Spaces-specific information.
        gr.Markdown("""
---
### 🚀 Hugging Face Spaces 优化版
**技术栈:**
- 🔤 BLIP: 图像理解与描述生成
- 🎨 Stable Diffusion 1.5: 设计方案生成
- 🏃 ControlNet + OpenPose: 精确3D试穿
- 📊 scikit-learn: 智能色彩分析
**性能优化:**
- ⚡ 针对 Spaces GPU/CPU 环境优化
- 🧠 智能内存管理,避免OOM
- 🔄 自动模型清理和重载
- 📱 响应式界面设计
> 💡 **提示**: 首次运行需要下载模型,请稍等片刻。生成过程中可能需要1-3分钟。
""")
    return demo
def main():
    """Main function optimized for Hugging Face Spaces.

    Logs the environment, builds the Gradio app, configures the request
    queue and launches the server on 0.0.0.0:7860 (the port Spaces expects).
    Re-raises on startup failure so the Space shows the error.
    """
    try:
        # Create necessary directories (example images are optional).
        os.makedirs("examples", exist_ok=True)
        logger.info("Starting AI Fashion Designer Pro on Hugging Face Spaces...")
        logger.info(f"CUDA Available: {torch.cuda.is_available()}")
        logger.info(f"PyTorch Version: {torch.__version__}")
        # Create and launch interface.
        demo = create_gradio_interface()
        # Configure for Spaces.
        # NOTE(review): `concurrency_count` is a Gradio 3.x queue() argument
        # (removed in 4.x) — confirm it matches the pinned gradio version.
        demo.queue(
            concurrency_count=1,  # Limited concurrency for Spaces
            max_size=5,
            api_open=False
        )
        # Launch with Spaces-optimized settings.
        demo.launch(
            server_name="0.0.0.0",
            server_port=7860,
            share=False,  # Spaces handles sharing
            show_error=True,
            quiet=False,
            favicon_path=None
        )
    except Exception as e:
        logger.error(f"Application startup failed: {e}")
        print(f"Error: {e}")
        raise


if __name__ == "__main__":
    main()