aszgod commited on
Commit
75bd0f6
·
verified ·
1 Parent(s): 847020d

Delete content_analyser.py

Browse files
Files changed (1) hide show
  1. content_analyser.py +0 -144
content_analyser.py DELETED
@@ -1,144 +0,0 @@
1
- from functools import lru_cache
2
- from typing import List
3
-
4
- import numpy
5
- from tqdm import tqdm
6
-
7
- from facefusion import inference_manager, state_manager, wording
8
- from facefusion.download import conditional_download_hashes, conditional_download_sources, resolve_download_url
9
- from facefusion.filesystem import resolve_relative_path
10
- from facefusion.thread_helper import conditional_thread_semaphore
11
- from facefusion.types import Detection, DownloadScope, Fps, InferencePool, ModelOptions, ModelSet, Score, VisionFrame
12
- from facefusion.vision import detect_video_fps, fit_frame, read_image, read_video_frame
13
-
14
- STREAM_COUNTER = 0
15
-
16
-
17
- @lru_cache(maxsize = None)
18
- def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
19
- return\
20
- {
21
- 'yolo_nsfw':
22
- {
23
- 'hashes':
24
- {
25
- 'content_analyser':
26
- {
27
- 'url': resolve_download_url('models-3.2.0', 'yolo_11m_nsfw.hash'),
28
- 'path': resolve_relative_path('../.assets/models/yolo_11m_nsfw.hash')
29
- }
30
- },
31
- 'sources':
32
- {
33
- 'content_analyser':
34
- {
35
- 'url': resolve_download_url('models-3.2.0', 'yolo_11m_nsfw.onnx'),
36
- 'path': resolve_relative_path('../.assets/models/yolo_11m_nsfw.onnx')
37
- }
38
- },
39
- 'size': (640, 640)
40
- }
41
- }
42
-
43
-
44
- def get_inference_pool() -> InferencePool:
45
- model_names = [ 'yolo_nsfw' ]
46
- model_source_set = get_model_options().get('sources')
47
-
48
- return inference_manager.get_inference_pool(__name__, model_names, model_source_set)
49
-
50
-
51
- def clear_inference_pool() -> None:
52
- model_names = [ 'yolo_nsfw' ]
53
- inference_manager.clear_inference_pool(__name__, model_names)
54
-
55
-
56
- def get_model_options() -> ModelOptions:
57
- return create_static_model_set('full').get('yolo_nsfw')
58
-
59
-
60
- def pre_check() -> bool:
61
- model_hash_set = get_model_options().get('hashes')
62
- model_source_set = get_model_options().get('sources')
63
-
64
- return conditional_download_hashes(model_hash_set) and conditional_download_sources(model_source_set)
65
-
66
-
67
- def analyse_stream(vision_frame : VisionFrame, video_fps : Fps) -> bool:
68
- global STREAM_COUNTER
69
-
70
- STREAM_COUNTER = STREAM_COUNTER + 1
71
- if STREAM_COUNTER % int(video_fps) == 0:
72
- return analyse_frame(vision_frame)
73
- return False
74
-
75
-
76
- def analyse_frame(vision_frame : VisionFrame) -> bool:
77
- nsfw_scores = detect_nsfw(vision_frame)
78
-
79
- return len(nsfw_scores) > 0
80
-
81
-
82
- @lru_cache(maxsize = None)
83
- def analyse_image(image_path : str) -> bool:
84
- vision_frame = read_image(image_path)
85
- return analyse_frame(vision_frame)
86
-
87
-
88
- @lru_cache(maxsize = None)
89
- def analyse_video(video_path : str, trim_frame_start : int, trim_frame_end : int) -> bool:
90
- video_fps = detect_video_fps(video_path)
91
- frame_range = range(trim_frame_start, trim_frame_end)
92
- rate = 0.0
93
- total = 0
94
- counter = 0
95
-
96
- with tqdm(total = len(frame_range), desc = wording.get('analysing'), unit = 'frame', ascii = ' =', disable = state_manager.get_item('log_level') in [ 'warn', 'error' ]) as progress:
97
-
98
- for frame_number in frame_range:
99
- if frame_number % int(video_fps) == 0:
100
- vision_frame = read_video_frame(video_path, frame_number)
101
- total += 1
102
- if analyse_frame(vision_frame):
103
- counter += 1
104
- if counter > 0 and total > 0:
105
- rate = counter / total * 100
106
- progress.set_postfix(rate = rate)
107
- progress.update()
108
-
109
- return rate > 10.0
110
-
111
-
112
- def detect_nsfw(vision_frame : VisionFrame) -> List[Score]:
113
- nsfw_scores = []
114
- model_size = get_model_options().get('size')
115
- temp_vision_frame = fit_frame(vision_frame, model_size)
116
- detect_vision_frame = prepare_detect_frame(temp_vision_frame)
117
- detection = forward(detect_vision_frame)
118
- detection = numpy.squeeze(detection).T
119
- nsfw_scores_raw = numpy.amax(detection[:, 4:], axis = 1)
120
- keep_indices = numpy.where(nsfw_scores_raw > 0.2)[0]
121
-
122
- if numpy.any(keep_indices):
123
- nsfw_scores_raw = nsfw_scores_raw[keep_indices]
124
- nsfw_scores = nsfw_scores_raw.ravel().tolist()
125
-
126
- return nsfw_scores
127
-
128
-
129
- def forward(vision_frame : VisionFrame) -> Detection:
130
- content_analyser = get_inference_pool().get('content_analyser')
131
-
132
- with conditional_thread_semaphore():
133
- detection = content_analyser.run(None,
134
- {
135
- 'input': vision_frame
136
- })
137
-
138
- return detection
139
-
140
-
141
- def prepare_detect_frame(temp_vision_frame : VisionFrame) -> VisionFrame:
142
- detect_vision_frame = temp_vision_frame / 255.0
143
- detect_vision_frame = numpy.expand_dims(detect_vision_frame.transpose(2, 0, 1), axis = 0).astype(numpy.float32)
144
- return detect_vision_frame