File size: 732 Bytes
09b1412
500d27e
09b1412
cbd0142
 
 
 
 
09b1412
338a141
cbd0142
13294ce
 
 
 
 
cbd0142
13294ce
cbd0142
09b1412
338a141
 
 
 
 
 
40a565e
500d27e
cbd0142
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import gradio as gr
from transformers import pipeline

# Text-classification pipeline for the NSFW model.
# `top_k=None` requests the score of every label instead of only the best one.
# It is the documented replacement for the deprecated `return_all_scores=True`,
# which emits a deprecation warning in current transformers releases.
pipe = pipeline(
    task="text-classification",
    model="TostAI/nsfw-text-detection-large",
    top_k=None,
)

def predict(input_txt):
    """Classify *input_txt* and return a label -> score mapping for Gradio.

    Args:
        input_txt: The text entered by the user.

    Returns:
        A dict mapping a human-readable label ('SAFE', 'QUESTIONABLE',
        'UNSAFE') to the classifier's score for that label. A raw label the
        mapping does not know is passed through unchanged instead of raising
        KeyError.
    """
    # The model reports generic 'LABEL_<n>' names; translate them for display.
    label_map = {
        'LABEL_0': 'SAFE',
        'LABEL_1': 'QUESTIONABLE',
        'LABEL_2': 'UNSAFE',
    }
    predictions = pipe(input_txt)
    # For a single string the pipeline typically wraps the per-label list in an
    # outer list ([[{...}, ...]]). That nesting has varied across transformers
    # versions and options, so unwrap only when it is actually present.
    if predictions and isinstance(predictions[0], list):
        predictions = predictions[0]
    # Collapse the list of {'label', 'score'} dicts into one mapping.
    return {
        label_map.get(pred['label'], pred['label']): pred['score']
        for pred in predictions
    }

# Gradio UI: a single free-text input, rendered through `predict` into a
# label/confidence widget.
_input_box = gr.Textbox(label="Input text")
_result_label = gr.Label(label="Result")
gradio_app = gr.Interface(
    fn=predict,
    inputs=_input_box,
    outputs=_result_label,
    title="NSFW Prediction",
)

# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    gradio_app.launch()