parthraninga commited on
Commit
c28eb35
Β·
verified Β·
1 Parent(s): b39b691

Upload app_hf.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app_hf.py +176 -0
app_hf.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from inference import ContentClassifierInference
3
+ import os
4
+ import json
5
+ import logging
6
+
7
+ # Setup logging
8
+ logging.basicConfig(level=logging.INFO)
9
+ logger = logging.getLogger(__name__)
10
+
11
+ # Initialize the model
12
+ try:
13
+ model = ContentClassifierInference()
14
+ model_initialized = True
15
+ logger.info("Model initialized successfully")
16
+ except Exception as e:
17
+ logger.error(f"Error initializing model: {e}")
18
+ model_initialized = False
19
+
20
+ def classify_text(text):
21
+ """Classify single text input"""
22
+ if not model_initialized:
23
+ return json.dumps({"error": "Model initialization failed"}, indent=2)
24
+
25
+ if not text or not text.strip():
26
+ return json.dumps({"error": "Please provide valid text input"}, indent=2)
27
+
28
+ try:
29
+ result = model.predict(text.strip())
30
+ logger.info(f"Processed text classification: {result['threat_prediction']}")
31
+ return json.dumps(result, indent=2)
32
+ except Exception as e:
33
+ logger.error(f"Classification error: {e}")
34
+ return json.dumps({"error": str(e)}, indent=2)
35
+
36
+ def classify_batch(text):
37
+ """Classify batch of texts (one per line)"""
38
+ if not model_initialized:
39
+ return json.dumps({"error": "Model initialization failed"}, indent=2)
40
+
41
+ if not text or not text.strip():
42
+ return json.dumps({"error": "Please provide valid text input"}, indent=2)
43
+
44
+ try:
45
+ # Split by newlines for batch processing
46
+ texts = [t.strip() for t in text.split("\n") if t.strip()]
47
+ if not texts:
48
+ return json.dumps({"error": "No valid texts provided"}, indent=2)
49
+
50
+ if len(texts) > 10: # Limit batch size
51
+ return json.dumps({"error": "Batch size limited to 10 texts"}, indent=2)
52
+
53
+ results = model.predict_batch(texts)
54
+ logger.info(f"Processed batch of {len(texts)} texts")
55
+ return json.dumps(results, indent=2)
56
+ except Exception as e:
57
+ logger.error(f"Batch classification error: {e}")
58
+ return json.dumps({"error": str(e)}, indent=2)
59
+
60
+ # API function for external use
61
+ def predict_api(text):
62
+ """API endpoint for programmatic access"""
63
+ if isinstance(text, list):
64
+ return [json.loads(classify_text(t)) for t in text]
65
+ else:
66
+ return json.loads(classify_text(text))
67
+
68
+ # Create Gradio interface
69
+ with gr.Blocks(title="Content Classifier", theme=gr.themes.Soft()) as demo:
70
+ gr.Markdown("# πŸ” Content Classifier")
71
+ gr.Markdown("""
72
+ This tool classifies content as either **safe** or **unsafe** using an ONNX model.
73
+ Perfect for content moderation, safety checks, and automated text analysis.
74
+ """)
75
+
76
+ # Status indicator
77
+ if model_initialized:
78
+ gr.Markdown("βœ… **Model Status**: Ready")
79
+ else:
80
+ gr.Markdown("❌ **Model Status**: Failed to initialize")
81
+
82
+ with gr.Tab("Single Text Classification"):
83
+ with gr.Row():
84
+ with gr.Column():
85
+ text_input = gr.Textbox(
86
+ label="Enter text to classify",
87
+ lines=5,
88
+ placeholder="Type or paste your text here...",
89
+ max_lines=10
90
+ )
91
+ classify_btn = gr.Button("πŸ” Classify", variant="primary", size="lg")
92
+
93
+ # Examples
94
+ gr.Examples(
95
+ examples=[
96
+ ["This is a normal, safe piece of content."],
97
+ ["Hello, how are you doing today?"],
98
+ ["Example text for content classification"]
99
+ ],
100
+ inputs=text_input,
101
+ label="Try these examples:"
102
+ )
103
+
104
+ with gr.Column():
105
+ result_output = gr.JSON(label="Classification Result", show_label=True)
106
+ classify_btn.click(fn=classify_text, inputs=text_input, outputs=result_output)
107
+
108
+ with gr.Tab("Batch Processing"):
109
+ with gr.Row():
110
+ with gr.Column():
111
+ batch_input = gr.Textbox(
112
+ label="Enter multiple texts (one per line)",
113
+ lines=10,
114
+ placeholder="Text 1\nText 2\nText 3\n...(max 10 texts)",
115
+ max_lines=15
116
+ )
117
+ batch_btn = gr.Button("πŸ“‹ Process Batch", variant="primary", size="lg")
118
+
119
+ gr.Markdown("**Note**: Maximum 10 texts per batch")
120
+
121
+ with gr.Column():
122
+ batch_output = gr.JSON(label="Batch Classification Results", show_label=True)
123
+
124
+ batch_btn.click(fn=classify_batch, inputs=batch_input, outputs=batch_output)
125
+
126
+ with gr.Tab("API Documentation"):
127
+ gr.Markdown("""
128
+ ## πŸ”Œ API Usage
129
+
130
+ This Space can be used as an API endpoint for programmatic access.
131
+
132
+ ### Single Text Classification
133
+ ```python
134
+ import requests
135
+
136
+ url = "https://your-space-name.hf.space/predict"
137
+ response = requests.post(url, json={"text": "Your content to classify"})
138
+ result = response.json()
139
+ ```
140
+
141
+ ### Batch Processing
142
+ ```python
143
+ import requests
144
+
145
+ url = "https://your-space-name.hf.space/predict"
146
+ texts = ["Text 1", "Text 2", "Text 3"]
147
+ response = requests.post(url, json={"text": texts})
148
+ results = response.json()
149
+ ```
150
+
151
+ ### Response Format
152
+ ```json
153
+ {
154
+ "is_threat": false,
155
+ "final_confidence": 0.85,
156
+ "threat_prediction": "safe",
157
+ "onnx_prediction": {
158
+ "safe": 0.85,
159
+ "unsafe": 0.15
160
+ },
161
+ "models_used": ["onnx"],
162
+ "raw_predictions": {...}
163
+ }
164
+ ```
165
+
166
+ ### Using with curl
167
+ ```bash
168
+ curl -X POST https://your-space-name.hf.space/predict \\
169
+ -H "Content-Type: application/json" \\
170
+ -d '{"text": "Your content to classify"}'
171
+ ```
172
+ """)
173
+
174
+ # Launch the app
175
+ if __name__ == "__main__":
176
+ demo.launch(server_name="0.0.0.0", server_port=7860, share=False)