Spaces:
Build error
Build error
Upload 3 files
Browse files
app.py
CHANGED
|
@@ -4,10 +4,10 @@ import gradio as gr
|
|
| 4 |
fn.load_model()
|
| 5 |
|
| 6 |
with gr.Blocks() as demo:
|
| 7 |
-
title = gr.Markdown('# Safety Checker')
|
| 8 |
with gr.Row():
|
| 9 |
src_image = gr.Image(label="Source", sources="upload", interactive=True, type="pil")
|
| 10 |
-
result = gr.Textbox(label="
|
| 11 |
|
| 12 |
src_image.change(
|
| 13 |
fn=fn.check,
|
|
|
|
| 4 |
fn.load_model()
|
| 5 |
|
| 6 |
with gr.Blocks() as demo:
|
| 7 |
+
title = gr.Markdown('# Safety Checker 2')
|
| 8 |
with gr.Row():
|
| 9 |
src_image = gr.Image(label="Source", sources="upload", interactive=True, type="pil")
|
| 10 |
+
result = gr.Textbox(label="Result", interactive=False)
|
| 11 |
|
| 12 |
src_image.change(
|
| 13 |
fn=fn.check,
|
fn.py
CHANGED
|
@@ -47,7 +47,7 @@ def check(image):
|
|
| 47 |
image = preprocess_image(image)
|
| 48 |
except Exception as e:
|
| 49 |
print(f"画像を読み込めません: {image_path}, エラー: {e}")
|
| 50 |
-
return
|
| 51 |
|
| 52 |
img = np.array([image])
|
| 53 |
prob = ort_sess.run(None, {input_name: img})[0][0] # ONNXモデルからの出力
|
|
@@ -57,11 +57,9 @@ def check(image):
|
|
| 57 |
max_nsfw_score = max(tag_confidences.get("questionable", 0), tag_confidences.get("explicit", 0))
|
| 58 |
max_sfw_score = tag_confidences.get("general", 0)
|
| 59 |
|
| 60 |
-
return max_nsfw_score > max_sfw_score
|
| 61 |
-
|
| 62 |
# 版権キャラクターの可能性を評価
|
| 63 |
character_tags_with_probs = []
|
| 64 |
-
for i, p in enumerate(prob[4:]):
|
| 65 |
if p >= thresh and i >= len(general_tags):
|
| 66 |
tag_index = i - len(general_tags)
|
| 67 |
if tag_index < len(character_tags):
|
|
@@ -69,10 +67,21 @@ def check(image):
|
|
| 69 |
prob_percent = round(p * 100, 2) # 確率をパーセンテージに変換
|
| 70 |
character_tags_with_probs.append((tag_name, f"{prob_percent}%"))
|
| 71 |
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
|
| 77 |
def load_model(MODEL_ID = "SmilingWolf/wd-vit-tagger-v3"):
|
| 78 |
global ort_sess, rating_tags, character_tags, general_tags
|
|
|
|
| 47 |
image = preprocess_image(image)
|
| 48 |
except Exception as e:
|
| 49 |
print(f"画像を読み込めません: {image_path}, エラー: {e}")
|
| 50 |
+
return None
|
| 51 |
|
| 52 |
img = np.array([image])
|
| 53 |
prob = ort_sess.run(None, {input_name: img})[0][0] # ONNXモデルからの出力
|
|
|
|
| 57 |
max_nsfw_score = max(tag_confidences.get("questionable", 0), tag_confidences.get("explicit", 0))
|
| 58 |
max_sfw_score = tag_confidences.get("general", 0)
|
| 59 |
|
|
|
|
|
|
|
| 60 |
# 版権キャラクターの可能性を評価
|
| 61 |
character_tags_with_probs = []
|
| 62 |
+
for i, p in enumerate(prob[4+len(general_tags):]):
|
| 63 |
if p >= thresh and i >= len(general_tags):
|
| 64 |
tag_index = i - len(general_tags)
|
| 65 |
if tag_index < len(character_tags):
|
|
|
|
| 67 |
prob_percent = round(p * 100, 2) # 確率をパーセンテージに変換
|
| 68 |
character_tags_with_probs.append((tag_name, f"{prob_percent}%"))
|
| 69 |
|
| 70 |
+
general_tags_with_probs = {}
|
| 71 |
+
for i, p in enumerate(prob[4:4+len(general_tags)]):
|
| 72 |
+
if p >= thresh:
|
| 73 |
+
tag_index = i - 4
|
| 74 |
+
if tag_index < len(general_tags):
|
| 75 |
+
tag_name = general_tags[tag_index]
|
| 76 |
+
general_tags_with_probs[tag_name] = p
|
| 77 |
+
tags = sorted(general_tags_with_probs.items(), key = lambda x : x[1], reverse=True)
|
| 78 |
+
|
| 79 |
+
return {
|
| 80 |
+
'has_nsfw': max_nsfw_score > max_sfw_score,
|
| 81 |
+
'tag_confidences': tag_confidences,
|
| 82 |
+
'character_tags_with_probs': character_tags_with_probs,
|
| 83 |
+
'tags': tags,
|
| 84 |
+
}
|
| 85 |
|
| 86 |
def load_model(MODEL_ID = "SmilingWolf/wd-vit-tagger-v3"):
|
| 87 |
global ort_sess, rating_tags, character_tags, general_tags
|
main.py
CHANGED
|
@@ -37,5 +37,5 @@ async def check_image(file: UploadFile = Form(...)):
|
|
| 37 |
|
| 38 |
result = fn.check(Image.open(file_stream))
|
| 39 |
|
| 40 |
-
return
|
| 41 |
|
|
|
|
| 37 |
|
| 38 |
result = fn.check(Image.open(file_stream))
|
| 39 |
|
| 40 |
+
return result
|
| 41 |
|