aka7774 commited on
Commit
b842808
·
verified ·
1 Parent(s): a8cb303

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +2 -2
  2. fn.py +17 -8
  3. main.py +1 -1
app.py CHANGED
@@ -4,10 +4,10 @@ import gradio as gr
4
  fn.load_model()
5
 
6
  with gr.Blocks() as demo:
7
- title = gr.Markdown('# Safety Checker')
8
  with gr.Row():
9
  src_image = gr.Image(label="Source", sources="upload", interactive=True, type="pil")
10
- result = gr.Textbox(label="Has NSFW", interactive=False)
11
 
12
  src_image.change(
13
  fn=fn.check,
 
4
  fn.load_model()
5
 
6
  with gr.Blocks() as demo:
7
+ title = gr.Markdown('# Safety Checker 2')
8
  with gr.Row():
9
  src_image = gr.Image(label="Source", sources="upload", interactive=True, type="pil")
10
+ result = gr.Textbox(label="Result", interactive=False)
11
 
12
  src_image.change(
13
  fn=fn.check,
fn.py CHANGED
@@ -47,7 +47,7 @@ def check(image):
47
  image = preprocess_image(image)
48
  except Exception as e:
49
  print(f"画像を読み込めません: {image_path}, エラー: {e}")
50
- return
51
 
52
  img = np.array([image])
53
  prob = ort_sess.run(None, {input_name: img})[0][0] # ONNXモデルからの出力
@@ -57,11 +57,9 @@ def check(image):
57
  max_nsfw_score = max(tag_confidences.get("questionable", 0), tag_confidences.get("explicit", 0))
58
  max_sfw_score = tag_confidences.get("general", 0)
59
 
60
- return max_nsfw_score > max_sfw_score
61
-
62
  # 版権キャラクターの可能性を評価
63
  character_tags_with_probs = []
64
- for i, p in enumerate(prob[4:]):
65
  if p >= thresh and i >= len(general_tags):
66
  tag_index = i - len(general_tags)
67
  if tag_index < len(character_tags):
@@ -69,10 +67,21 @@ def check(image):
69
  prob_percent = round(p * 100, 2) # 確率をパーセンテージに変換
70
  character_tags_with_probs.append((tag_name, f"{prob_percent}%"))
71
 
72
- if character_tags_with_probs:
73
- print(f"版権キャラクター: {character_tags_with_probs}の可能性があります")
74
- else:
75
- print("版権キャラクターの可能性が低いと思われます")
 
 
 
 
 
 
 
 
 
 
 
76
 
77
  def load_model(MODEL_ID = "SmilingWolf/wd-vit-tagger-v3"):
78
  global ort_sess, rating_tags, character_tags, general_tags
 
47
  image = preprocess_image(image)
48
  except Exception as e:
49
  print(f"画像を読み込めません: {image_path}, エラー: {e}")
50
+ return None
51
 
52
  img = np.array([image])
53
  prob = ort_sess.run(None, {input_name: img})[0][0] # ONNXモデルからの出力
 
57
  max_nsfw_score = max(tag_confidences.get("questionable", 0), tag_confidences.get("explicit", 0))
58
  max_sfw_score = tag_confidences.get("general", 0)
59
 
 
 
60
  # 版権キャラクターの可能性を評価
61
  character_tags_with_probs = []
62
+ for i, p in enumerate(prob[4+len(general_tags):]):
63
  if p >= thresh and i >= len(general_tags):
64
  tag_index = i - len(general_tags)
65
  if tag_index < len(character_tags):
 
67
  prob_percent = round(p * 100, 2) # 確率をパーセンテージに変換
68
  character_tags_with_probs.append((tag_name, f"{prob_percent}%"))
69
 
70
+ general_tags_with_probs = {}
71
+ for i, p in enumerate(prob[4:4+len(general_tags)]):
72
+ if p >= thresh:
73
+ tag_index = i - 4
74
+ if tag_index < len(general_tags):
75
+ tag_name = general_tags[tag_index]
76
+ general_tags_with_probs[tag_name] = p
77
+ tags = sorted(general_tags_with_probs.items(), key = lambda x : x[1], reverse=True)
78
+
79
+ return {
80
+ 'has_nsfw': max_nsfw_score > max_sfw_score,
81
+ 'tag_confidences': tag_confidences,
82
+ 'character_tags_with_probs': character_tags_with_probs,
83
+ 'tags': tags,
84
+ }
85
 
86
  def load_model(MODEL_ID = "SmilingWolf/wd-vit-tagger-v3"):
87
  global ort_sess, rating_tags, character_tags, general_tags
main.py CHANGED
@@ -37,5 +37,5 @@ async def check_image(file: UploadFile = Form(...)):
37
 
38
  result = fn.check(Image.open(file_stream))
39
 
40
- return {"has_nsfw": result}
41
 
 
37
 
38
  result = fn.check(Image.open(file_stream))
39
 
40
+ return result
41