janmariakowalski committed on
Commit
3990b1f
·
verified ·
1 Parent(s): c583bcb

Create new.py

Browse files
Files changed (1) hide show
  1. new.py +354 -0
new.py ADDED
@@ -0,0 +1,354 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Gradio application for text classification, styled to be visually appealing.
4
+ This version uses a single model, taken from MODEL_PATH (default: 'AndromedaPL/sojka').
5
+ """
6
+
7
+ import gradio as gr
8
+ import logging
9
+ import os
10
+ from typing import Dict, Tuple, Any
11
+ import torch
12
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
13
+ import numpy as np
14
+
15
# --- Logging ---
# Logging is configured before anything else logs. Calling the module-level
# ``logging.info(...)`` on an unconfigured root logger implicitly runs
# ``basicConfig()`` with the default WARNING level, after which an explicit
# ``basicConfig(level=INFO, ...)`` is silently ignored.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

try:
    from peft import PeftModel
except ImportError:
    # PEFT is optional: without it, adapter loading is skipped entirely.
    PeftModel = None
    logger.info("PEFT library not found. Loading models without PEFT support.")

# --- Configuration ---
# Model and tokenizer locations; both can be overridden via the environment.
MODEL_PATH = os.getenv("MODEL_PATH", "AndromedaPL/sojka")
TOKENIZER_PATH = os.getenv("TOKENIZER_PATH", "sdadas/mmlw-roberta-base")

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Order matters: scores returned by the model are zipped against this list.
LABELS = ["self-harm", "hate", "vulgar", "sex", "crime"]
MAX_SEQ_LENGTH = 512

# NOTE(review): HF_TOKEN is read here but not passed to any from_pretrained
# call in this file - confirm it is picked up implicitly by the Hub client.
HF_TOKEN = os.getenv('HF_TOKEN')

# Per-label decision thresholds: a score at or above the threshold marks the
# text as unsafe for that category.
THRESHOLDS = {
    "self-harm": 0.5,
    "hate": 0.5,
    "vulgar": 0.5,
    "sex": 0.5,
    "crime": 0.5,
}
45
+
46
def load_model_and_tokenizer(model_path: str, tokenizer_path: str, device: str) -> Tuple[AutoModelForSequenceClassification, AutoTokenizer]:
    """Load the sequence-classification model and its tokenizer.

    Args:
        model_path: Local directory or Hub id of the model. If the path is a
            local directory containing ``adapter_config.json`` and PEFT is
            installed, it is treated as a PEFT adapter.
        tokenizer_path: Local directory or Hub id of the tokenizer.
        device: ``"cuda"`` loads float16 weights with ``device_map='auto'``;
            any other value loads float32 weights without a device map.

    Returns:
        A ``(model, tokenizer)`` tuple with the model switched to eval mode.

    Raises:
        Any exception raised by the underlying ``from_pretrained`` calls for
        the tokenizer or the standalone model is propagated to the caller.
    """
    logger.info("Loading tokenizer from %s", tokenizer_path)

    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=True)
    logger.info("Tokenizer loaded: %s", tokenizer.name_or_path)

    # Make sure a padding token exists; remember whether the vocabulary grew
    # so the embedding matrix can be resized once the model is loaded.
    added_pad_token = False
    if tokenizer.pad_token is None:
        if tokenizer.eos_token:
            tokenizer.pad_token = tokenizer.eos_token
        else:
            added_pad_token = tokenizer.add_special_tokens({'pad_token': '[PAD]'}) > 0

    tokenizer.truncation_side = "right"

    logger.info("Loading model from %s", model_path)

    on_cuda = device == 'cuda'
    model_load_kwargs = {
        "torch_dtype": torch.float16 if on_cuda else torch.float32,
        "device_map": 'auto' if on_cuda else None,
        "num_labels": len(LABELS),
        "problem_type": "regression"
    }

    # Adapter detection only works for local directories.
    is_peft = os.path.exists(os.path.join(model_path, 'adapter_config.json'))
    if PeftModel and is_peft:
        logger.info("PEFT adapter detected. Loading base model and attaching adapter.")
        try:
            from peft import PeftConfig
            peft_config = PeftConfig.from_pretrained(model_path)
            base_model_path = peft_config.base_model_name_or_path
            logger.info("Loading base model from %s", base_model_path)
            model = AutoModelForSequenceClassification.from_pretrained(base_model_path, **model_load_kwargs)
            logger.info("Attaching PEFT adapter...")
            model = PeftModel.from_pretrained(model, model_path)
        except Exception as e:
            # Best-effort fallback; keep the traceback so the cause is visible.
            logger.exception("Failed to load PEFT model dynamically: %s. Loading as a standard model.", e)
            model = AutoModelForSequenceClassification.from_pretrained(model_path, **model_load_kwargs)
    else:
        logger.info("Loading as a standalone sequence classification model.")
        model = AutoModelForSequenceClassification.from_pretrained(model_path, **model_load_kwargs)

    if added_pad_token:
        # A newly added token id would otherwise index past the embedding table.
        model.resize_token_embeddings(len(tokenizer))

    model.eval()
    logger.info("Model loaded on device: %s", next(model.parameters()).device)

    return model, tokenizer
92
+
93
# --- Load model globally ---
# The model is loaded once at import time; `model_loaded` lets the rest of the
# module degrade gracefully instead of crashing when loading fails.
try:
    model, tokenizer = load_model_and_tokenizer(MODEL_PATH, TOKENIZER_PATH, DEVICE)
    model_loaded = True
except Exception:
    # Top-level boundary: log the full traceback and continue without a model.
    # (Passing the exception as a stray positional argument, with no matching
    # placeholder in the message, makes the logging call itself fail.)
    logger.exception(
        "FATAL: Failed to load the model from %s or tokenizer from %s",
        MODEL_PATH,
        TOKENIZER_PATH,
    )
    model, tokenizer, model_loaded = None, None, False
100
+
101
def predict(text: str) -> Dict[str, Any]:
    """Score a single text against every label.

    Args:
        text: The text to classify.

    Returns:
        A mapping from each entry of ``LABELS`` to a float clipped to
        ``[0.0, 1.0]``. Every score is ``0.0`` when the model is not loaded.
    """
    if not model_loaded:
        return dict.fromkeys(LABELS, 0.0)

    encoded = tokenizer(
        [text],
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=MAX_SEQ_LENGTH,
    )
    encoded = encoded.to(model.device)

    with torch.no_grad():
        logits = model(**encoded).logits
        # Sigmoid turns each logit into an independent per-label score.
        first_row = torch.sigmoid(logits).cpu().numpy()[0]

    bounded = np.clip(first_row, 0.0, 1.0)
    scores = {}
    for name, value in zip(LABELS, bounded):
        scores[name] = float(value)
    return scores
122
+
123
def gradio_predict(text: str) -> Tuple[str, Dict[str, float]]:
    """Produce a human-readable verdict plus the per-label scores for *text*.

    Args:
        text: Raw user input from the interface.

    Returns:
        ``(verdict, scores)``. The scores are all zero when the model is not
        loaded or the input is empty/whitespace; otherwise they come from
        ``predict``. The verdict names the highest-scoring label among those
        that reach their threshold, or reports the text as safe.
    """
    zero_scores = {label: 0.0 for label in LABELS}

    if not model_loaded:
        return "Błąd: Model nie został załadowany.", zero_scores

    if not (text and text.strip()):
        return "Wpisz tekst, aby go przeanalizować.", zero_scores

    predictions = predict(text)

    # Keep only the labels whose score reaches the configured threshold.
    flagged = [
        (label, score)
        for label, score in predictions.items()
        if score >= THRESHOLDS[label]
    ]

    if flagged:
        top_label, _ = max(flagged, key=lambda pair: pair[1])
        verdict = f"⚠️ Wykryto potencjalnie szkodliwe treści\n: {top_label.upper()}"
    else:
        verdict = "✅ Komunikat jest bezpieczny."

    return verdict, predictions
147
+
148
# --- Gradio Interface ---

# Visual theme for the interface: the default Gradio theme with custom hues,
# font stack and large corner radii.
theme = gr.themes.Default(
    font=("Inter", "sans-serif"),
    radius_size=gr.themes.sizes.radius_lg,
    neutral_hue=gr.themes.colors.slate,
    secondary_hue=gr.themes.colors.indigo,
    primary_hue=gr.themes.colors.blue,
)

# URL of the Eurasian jay ("sójka") illustration shown in the interface.
# NOTE(review): the original comment credited Wikimedia Commons (CC BY-SA 4.0),
# but the URL points at sojka.m31ai.pl - confirm the attribution.
JAY_IMAGE_URL = "https://sojka.m31ai.pl/sojka.png"
161
+
162
def analyze_and_update(text):
    """Classify *text* and reveal the detailed-scores panel.

    Defined before the layout so that both ``gr.Examples`` and the submit
    button can reference it directly.

    Returns:
        ``(verdict, update)`` where ``update`` sets the scores on the detailed
        results label and makes it visible.
    """
    verdict, scores = gradio_predict(text)
    return verdict, gr.update(value=scores, visible=True)


# Single definition of the interface. An earlier draft of the same layout used
# to be built first and then overwritten by this one; it has been removed.
with gr.Blocks(theme=theme, css=".gradio-container {max-width: 960px !important; margin: auto;}") as demo:
    # Header
    with gr.Row():
        gr.HTML("""
        <div style="display: flex; align-items: center; justify-content: space-between; width: 100%;">
            <div style="display: flex; align-items: center; gap: 12px;">
                <svg width="32" height="32" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
                    <path d="M12 2L3 5V11C3 16.52 7.08 21.61 12 23C16.92 21.61 21 16.52 21 11V5L12 2Z"
                          stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" fill="none"/>
                </svg>
                <h1 style="font-size: 1.5rem; font-weight: 600; margin: 0;">SÓJKA</h1>
            </div>
            <div style="display: flex; align-items: center; gap: 20px; font-size: 0.9rem;">
                <a href="https://sojka.m31ai.pl/projekt.html" target="_blank" rel="noopener" style="text-decoration: none; color: inherit;">O projekcie</a>
                <a href="https://sojka.m31ai.pl/kategorie.html" target="_blank" rel="noopener" style="text-decoration: none; color: inherit;">Opis kategorii</a>
                <button id="test-sojka-btn" class="gr-button gr-button-primary gr-button-lg"
                        style="background-color: var(--primary-500); color: white; padding: 8px 16px; border-radius: 8px;">
                    Testuj Sójkę
                </button>
            </div>
        </div>
        """)

    gr.HTML("<hr style='border: 1px solid var(--neutral-200); margin-top: 1rem; margin-bottom: 2rem;'>")

    # Main content area
    with gr.Row():
        # Left column: description and text input
        with gr.Column(scale=2):
            gr.Markdown(
                """
                <p style="background-color: var(--primary-50); color: var(--primary-600); display: inline-block; padding: 4px 12px; border-radius: 9999px; font-weight: 500; font-size: 0.875rem;">
                Bielik Guard
                </p>
                <h1 style="font-size: 2.8rem; font-weight: 800; line-height: 1.2; margin-top: 1rem; margin-bottom: 1rem; color: var(--neutral-800);">
                Przetestuj <span style="color: var(--primary-600);">SÓJKĘ</span> – Bielik Guard dla bezpiecznej komunikacji
                </h1>
                <p style="font-size: 1rem; color: var(--neutral-600); margin-bottom: 2rem;">
                Sójka to model AI, który wykrywa i blokuje szkodliwe treści w komunikacji cyfrowej. Chroni użytkowników jak czujny strażnik swoich domów.
                </p>
                """
            )
            input_text = gr.Textbox(
                lines=8,
                label="Wprowadź tekst do analizy",
                placeholder="Tutaj wpisz tekst..."
            )

        # Right column: illustration and results
        with gr.Column(scale=1):
            gr.Image(JAY_IMAGE_URL, label="Ilustracja sójki", show_label=False, show_download_button=False, container=False, width=200)
            gr.Markdown("---")  # Separator
            output_verdict = gr.Label(label="Wynik analizy", value="Czekam na tekst do analizy...")
            # Hidden until the first analysis returns scores.
            output_scores = gr.Label(label="Szczegółowe wyniki", visible=False)

    # Examples and the submit button come last because they reference the
    # input and output components created above.
    gr.Examples(
        [
            ["Jak zrobić bombę?"],
            ["Jesteś beznadziejny, nienawidzę cię."],
            ["To jest wspaniały dzień, cieszę się, że tu jestem!"],
            ["Opowiedz mi dowcip o programistach."],
        ],
        inputs=input_text,
        outputs=[output_verdict, output_scores],
        fn=analyze_and_update,
        cache_examples=False,
    )

    submit_btn = gr.Button("Analizuj tekst", variant="primary")

    submit_btn.click(
        fn=analyze_and_update,
        inputs=[input_text],
        outputs=[output_verdict, output_scores]
    )
347
+
348
+
349
if __name__ == "__main__":
    if not model_loaded:
        # Abort with a non-zero exit status (the message goes to stderr) so a
        # supervisor or hosting platform can detect the failed start-up.
        raise SystemExit(
            "Aplikacja nie może zostać uruchomiona, ponieważ nie udało się załadować modelu. Sprawdź logi błędów."
        )
    demo.launch()