Spaces:
Running
Running
| import gradio as gr | |
| import torch | |
| import torch.nn as nn | |
| from torchvision import transforms | |
| from torchvision.models import resnet18, ResNet18_Weights | |
| from PIL import Image | |
# --- Configuration ---------------------------------------------------------
# Class names, positionally matched to the classifier's output scores.
labels = ["Drawings", "Hentai", "Neutral", "Porn", "Sexy"]
# Accent colour (an elegant purple) shared by the banner HTML and the CSS.
theme_color = "#6C5B7B"

# One entry per bullet in the banner: (label, accent colour, short description).
_CATEGORY_LEGEND = (
    ("Drawings", "#4B4453", "艺术绘画作品"),
    ("Hentai", "#845EC2", "二次元成人内容"),
    ("Neutral", "#008F7A", "日常安全内容"),
    ("Porn", "#D65DB1", "露骨成人内容"),
    ("Sexy", "#FF9671", "性感但不露骨内容"),
)

# HTML banner shown at the top of the page, assembled one line at a time;
# the heading is tinted with the theme colour.
description = "\n".join(
    [
        '<div style="padding: 20px; background: linear-gradient(135deg, #f8f9fa 0%, #e9ecef 100%); border-radius: 10px;">',
        f'<h2 style="color: {theme_color}; margin-bottom: 15px;">🎨 NSFW 图片分类器</h2>',
        "<p>该模型使用深度神经网络对图片内容进行分类,支持以下类别:</p>",
        '<ul style="list-style-type: circle; padding-left: 25px;">',
        *(
            f'<li><span style="color: {hue};">{name}</span> - {blurb}</li>'
            for name, hue, blurb in _CATEGORY_LEGEND
        ),
        "</ul>",
        '<p style="margin-top: 15px;">🖼️ 请上传图片或点击下方示例体验</p>',
        "</div>",
    ]
)
# The model definition, preprocessing pipeline and weight loading are defined further below, after the CSS.
| # 高级 CSS 样式 | |
| advanced_css = f""" | |
| .gradio-container {{ | |
| background: linear-gradient(135deg, #f8f9fa 0%, #e9ecef 100%); | |
| min-height: 100vh; | |
| }} | |
| .header-section {{ | |
| background: white; | |
| padding: 2rem; | |
| border-radius: 15px; | |
| box-shadow: 0 4px 6px rgba(0,0,0,0.05); | |
| margin-bottom: 2rem; | |
| }} | |
| .result-card {{ | |
| background: white !important; | |
| padding: 1.5rem !important; | |
| border-radius: 12px !important; | |
| box-shadow: 0 2px 8px rgba(108,91,123,0.1) !important; | |
| }} | |
| .custom-button {{ | |
| background: {theme_color} !important; | |
| color: white !important; | |
| border: none !important; | |
| padding: 12px 28px !important; | |
| border-radius: 25px !important; | |
| transition: all 0.3s ease !important; | |
| }} | |
| .custom-button:hover {{ | |
| transform: translateY(-2px); | |
| box-shadow: 0 4px 12px rgba(108,91,123,0.3) !important; | |
| }} | |
| .upload-box {{ | |
| border: 2px dashed {theme_color} !important; | |
| border-radius: 15px !important; | |
| background: rgba(255,255,255,0.9) !important; | |
| }} | |
| .example-card {{ | |
| cursor: pointer; | |
| transition: all 0.3s ease; | |
| border-radius: 12px; | |
| overflow: hidden; | |
| }} | |
| .example-card:hover {{ | |
| transform: scale(1.02); | |
| box-shadow: 0 4px 12px rgba(108,91,123,0.2); | |
| }} | |
| .prob-bar {{ | |
| height: 8px; | |
| border-radius: 4px; | |
| background: linear-gradient(90deg, {theme_color} 0%, #C06C84 100%); | |
| }} | |
| """ | |
# --- Model definition ------------------------------------------------------
class Classifier(nn.Module):
    """ResNet-18 backbone followed by a small MLP head that outputs 5 logits.

    The 5 outputs are positionally matched to ``labels`` by ``predict``.

    Args:
        pretrained: When True (the default, same as before), the backbone is
            initialised with torchvision's ImageNet weights, which may have to
            be downloaded. Pass False when a full checkpoint is loaded right
            afterwards, because every weight is overwritten anyway.
    """

    def __init__(self, pretrained: bool = True):
        super().__init__()
        weights = ResNet18_Weights.DEFAULT if pretrained else None
        # The backbone keeps its own 1000-way output layer; those 1000 values
        # are the input to the head below.
        self.cnn_layers = resnet18(weights=weights)
        # NOTE(review): there is no activation between the first Linear and the
        # Dropout. Do not insert one: nn.Sequential names its children by
        # index (fc_layers.0 / .2 / .4 hold the Linear weights), so changing the
        # layout would break loading the existing checkpoint.
        self.fc_layers = nn.Sequential(
            nn.Linear(1000, 512),
            nn.Dropout(0.3),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Linear(128, 5),
        )

    def forward(self, x):
        """Return raw (pre-softmax) class scores for a batch of images."""
        x = self.cnn_layers(x)
        x = self.fc_layers(x)
        return x


# --- Pre-processing --------------------------------------------------------
# Resize(224) with an int scales the shorter image side to 224 px while keeping
# the aspect ratio (no crop); the tensor is then normalised with the ImageNet
# per-channel mean/std.
preprocess = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# --- Load model ------------------------------------------------------------
# load_state_dict is strict by default, so it only succeeds if the checkpoint
# supplies every parameter and buffer of the model. The ImageNet backbone
# weights would therefore be overwritten immediately, so skip fetching them.
model = Classifier(pretrained=False)
# NOTE(review): torch.load unpickles the file. It ships with the app, so it is
# trusted here; consider weights_only=True if the installed torch supports it.
model.load_state_dict(torch.load("classify_nsfw_v3.0.pth", map_location="cpu"))
model.eval()  # inference mode: disables Dropout, uses BatchNorm running stats
def predict(image_path):
    """Classify one image file and return per-class probabilities.

    Args:
        image_path: Path of the image on disk (the Gradio ``Image`` input is
            configured with ``type="filepath"``), or None when no image is set.

    Returns:
        A dict mapping each name in ``labels`` to its softmax probability, or
        None when ``image_path`` is None (this leaves the result label empty
        instead of raising).
    """
    if image_path is None:
        # Guard against the analyse button being clicked with an empty upload
        # box; Image.open(None) would raise.
        return None
    # Open in a context manager so the file handle is released. convert()
    # returns a new, fully loaded image that stays usable after the close.
    with Image.open(image_path) as raw:
        img = raw.convert("RGB")
    batch = preprocess(img).unsqueeze(0)  # add a leading batch dimension of 1
    with torch.no_grad():
        probs = torch.nn.functional.softmax(model(batch)[0], dim=0)
    # Pair scores with label names by position rather than a hard-coded range(5).
    return dict(zip(labels, probs.tolist()))
with gr.Blocks(theme=gr.themes.Soft(), css=advanced_css) as demo:
    # Header: page title and the HTML banner describing the categories.
    with gr.Column(elem_classes="header-section"):
        gr.Markdown("# 🎭 智能内容识别系统", elem_id="main-title")
        gr.HTML(description)

    # Main area: input column on the left, result column on the right.
    with gr.Row():
        # Input column: image upload plus action buttons.
        with gr.Column(scale=2):
            upload_box = gr.Image(
                type="filepath",
                label="📤 上传图片",
                elem_id="upload-box",
                elem_classes="upload-box",
                height=400,
            )
            with gr.Row():
                submit_btn = gr.Button(
                    "✨ 开始分析",
                    elem_classes="custom-button",
                    size="lg",
                )
                clear_btn = gr.Button(
                    "🔄 重新上传",
                    variant="secondary",
                    size="lg",
                )
        # Output column: probability label for the top 3 classes.
        with gr.Column(scale=1):
            with gr.Column(elem_classes="result-card"):
                gr.Markdown("### 🔍 分析结果")
                result_display = gr.Label(
                    label="分类概率分布",
                    num_top_classes=3,
                    show_label=False,
                )
                # NOTE(review): no event in this file writes to this element,
                # so the "top class" span always stays empty.
                gr.Markdown("**最高概率类别**: <span id='top-class'></span>", elem_id="dynamic-text")

    # Example images that fill the upload box when clicked.
    with gr.Column():
        gr.Markdown("### 🖼️ 示例图片")
        examples = gr.Examples(
            examples=["./example/anime.jpg", "./example/real.jpg"],
            inputs=upload_box,
            examples_per_page=2,
            label="点击使用示例",
            elem_id="example-gallery",
        )

    # Event wiring.
    # Reset the image AND the previous result, so a stale classification is
    # not left on screen next to an empty upload box.
    clear_btn.click(
        fn=lambda: (None, None),
        inputs=None,
        outputs=[upload_box, result_display],
    )
    submit_btn.click(
        fn=predict,
        inputs=upload_box,
        outputs=result_display,
        api_name="predict",
    )
# Start the web UI only when this file is run as a script, so importing the
# module (e.g. from tests or another app) does not start a server.
if __name__ == "__main__":
    demo.launch()