aka7774 commited on
Commit
a19411a
·
verified ·
1 Parent(s): de3b741

Update fn.py

Browse files
Files changed (1) hide show
  1. fn.py +2 -2
fn.py CHANGED
@@ -53,7 +53,7 @@ def check(image):
53
  prob = ort_sess.run(None, {input_name: img})[0][0] # ONNXモデルからの出力
54
 
55
  # NSFW/SFW判定
56
- tag_confidences = {tag: prob[i] for i, tag in enumerate(rating_tags)}
57
  max_nsfw_score = max(tag_confidences.get("questionable", 0), tag_confidences.get("explicit", 0))
58
  max_sfw_score = tag_confidences.get("general", 0)
59
 
@@ -73,7 +73,7 @@ def check(image):
73
  tag_index = i - 4
74
  if tag_index < len(general_tags):
75
  tag_name = general_tags[tag_index]
76
- general_tags_with_probs[tag_name] = p
77
  tags = sorted(general_tags_with_probs.items(), key = lambda x : x[1], reverse=True)
78
 
79
  return {
 
53
  prob = ort_sess.run(None, {input_name: img})[0][0] # ONNXモデルからの出力
54
 
55
  # NSFW/SFW判定
56
+ tag_confidences = {tag: float(prob[i]) for i, tag in enumerate(rating_tags)}
57
  max_nsfw_score = max(tag_confidences.get("questionable", 0), tag_confidences.get("explicit", 0))
58
  max_sfw_score = tag_confidences.get("general", 0)
59
 
 
73
  tag_index = i - 4
74
  if tag_index < len(general_tags):
75
  tag_name = general_tags[tag_index]
76
+ general_tags_with_probs[tag_name] = float(p)
77
  tags = sorted(general_tags_with_probs.items(), key = lambda x : x[1], reverse=True)
78
 
79
  return {