reginafeles committed on
Commit
f9272e7
·
verified ·
1 Parent(s): 245b58a

Upload 5 files

Browse files
Files changed (5) hide show
  1. Dockerfile +12 -0
  2. app.py +101 -0
  3. model/transformer_classifier.pt +3 -0
  4. static/script.js +104 -0
  5. templates/index.html +79 -0
Dockerfile ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
FROM python:3.10-slim

# Install dependencies before copying the sources so this layer is reused
# from the build cache when only application files change. --no-cache-dir
# keeps pip's download cache out of the image.
RUN pip install --no-cache-dir --upgrade pip \
    && pip install --no-cache-dir flask torch

# app.py appends to feedback.txt and corrections.csv in its working
# directory, so /app must be writable by the user the container runs as.
# NOTE(review): UID 1000 is chosen to match what Hugging Face Docker Spaces
# run containers as (port 7860 suggests that target) - confirm the deployment.
RUN useradd --create-home --uid 1000 appuser \
    && mkdir -p /app \
    && chown appuser:appuser /app

WORKDIR /app

COPY --chown=appuser:appuser . .

# Drop root privileges for the running application.
USER appuser

EXPOSE 7860

CMD ["python", "app.py"]
app.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flask import Flask, render_template, request, jsonify
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import re
5
+ import csv
6
+
7
+
8
class TransformerClassifier(torch.nn.Module):
    """Embedding + Transformer-encoder text classifier.

    Token ids are embedded, run through a stack of encoder layers, and the
    encoder output at sequence position 0 is projected to class logits.
    """

    def __init__(self, vocab_size, num_classes, d_model=64, nhead=4, num_layers=2, dim_feedforward=128):
        super().__init__()
        nn = torch.nn
        # The attribute names below become the state-dict key prefixes, so
        # they must stay in sync with any checkpoint loaded into this model.
        self.embedding = nn.Embedding(vocab_size, d_model)
        layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers=num_layers)
        self.classifier = nn.Linear(d_model, num_classes)

    def forward(self, x):
        """Return class logits for ``x``, a (batch, sequence) tensor of token ids."""
        # The encoder layers are built without batch_first, so the sequence
        # axis has to come first: (batch, seq, dim) -> (seq, batch, dim).
        seq_first = self.embedding(x).permute(1, 0, 2)
        # Position 0 of every sequence serves as its summary vector.
        first_position = self.encoder(seq_first)[0]
        return self.classifier(first_position)
24
+
25
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# The checkpoint bundles the vocabulary and the label mapping with the weights.
# NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
checkpoint = torch.load("model/transformer_classifier.pt", map_location=device)
vocab = checkpoint['vocab']
class_to_idx = checkpoint['class_to_idx']
# Inverse mapping: predicted class index -> label name.
idx_to_class = {index: label for label, index in class_to_idx.items()}

# Model dimensions are derived from the stored vocabulary and label set.
model = TransformerClassifier(len(vocab), len(class_to_idx))
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)
model.eval()  # inference only
35
+
36
def tokenize(text):
    """Split ``text`` into lowercase tokens made of ASCII letters and digits.

    Every character other than ``a-z``, ``0-9`` or whitespace is deleted
    (not replaced by a space) before the text is split on whitespace.
    """
    cleaned = re.sub(r"[^a-z0-9\s]", "", text.lower())
    return cleaned.split()
40
+
41
def encode(text, max_len=32):
    """Convert ``text`` into a list of exactly ``max_len`` vocabulary ids.

    The sequence starts with the ``<CLS>`` token, tokens missing from the
    vocabulary map to ``<UNK>``, short sequences are right-padded with
    ``<PAD>`` and long ones are truncated.
    """
    unk_id = vocab['<UNK>']
    ids = [vocab.get(token, unk_id) for token in ['<CLS>', *tokenize(text)]]

    shortfall = max_len - len(ids)
    if shortfall > 0:
        return ids + [vocab['<PAD>']] * shortfall
    return ids[:max_len]
49
+
50
app = Flask(__name__)


@app.route('/')
def index():
    """Serve the single-page classifier UI."""
    page = render_template('index.html')
    return page
55
+
56
@app.route('/predict', methods=['POST'])
def predict():
    """Classify the posted text and return the label with class probabilities.

    Expects a JSON object ``{"text": <str>}``. On success the response holds
    ``label`` (predicted class name), ``confidence`` (its probability) and
    ``probs`` (probability per class name), with probabilities rounded to
    three decimals. A body that is not a JSON object with a string ``text``
    field yields a 400 response instead of an unhandled exception.
    """
    data = request.get_json(silent=True)
    text = data.get('text') if isinstance(data, dict) else None
    if not isinstance(text, str):
        return jsonify({
            'error': "Request body must be a JSON object with a string 'text' field"
        }), 400

    # Add a leading batch dimension of size one.
    x = torch.tensor(encode(text)).unsqueeze(0).to(device)
    with torch.no_grad():
        probs = F.softmax(model(x), dim=1)[0]

    pred_idx = torch.argmax(probs).item()
    class_probs = probs.tolist()
    return jsonify({
        'label': idx_to_class[pred_idx],
        'confidence': round(class_probs[pred_idx], 3),
        'probs': {idx_to_class[i]: round(p, 3) for i, p in enumerate(class_probs)},
    })
72
@app.route('/feedback', methods=['POST'])
def feedback():
    """Append a "was the prediction right?" record to ``feedback.txt``.

    Expects a JSON object with optional ``text``, ``predicted_label`` and
    ``correct`` fields; missing fields are recorded as ``None``. A body that
    is not a JSON object yields a 400 response instead of an unhandled
    exception.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'error': 'Request body must be a JSON object'}), 400

    text = data.get('text')
    predicted_label = data.get('predicted_label')
    correct = data.get('correct')

    # Explicit UTF-8, as the /correction handler already uses, so non-ASCII
    # text is stored the same way regardless of the process locale.
    with open('feedback.txt', 'a', encoding='utf-8') as f:
        f.write(f'Text: {text}\nPredicted Label: {predicted_label}\nCorrect: {correct}\n\n')

    return jsonify({'status': 'Feedback received'})
83
+
84
@app.route('/correction', methods=['POST'])
def correction():
    """Append a user-supplied label correction to ``corrections.csv``.

    Expects a JSON object with optional ``text``, ``predicted_label`` and
    ``true_label`` fields; missing fields are written as empty cells. A body
    that is not a JSON object yields a 400 response instead of an unhandled
    exception.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'error': 'Request body must be a JSON object'}), 400

    row = [data.get('text'), data.get('predicted_label'), data.get('true_label')]

    # newline='' lets the csv module control line endings itself.
    with open('corrections.csv', mode='a', newline='', encoding='utf-8') as file:
        writer = csv.writer(file)
        # Position 0 right after opening in append mode means the file is
        # new or empty, so the header row has to be written first.
        if file.tell() == 0:
            writer.writerow(['Text', 'Predicted Label', 'True Label'])
        writer.writerow(row)

    return jsonify({'status': 'received'})
98
+
99
+
100
if __name__ == '__main__':
    # Listen on all interfaces so the port exposed by the Dockerfile (7860)
    # is reachable from outside the container.
    app.run(port=7860, host='0.0.0.0')
model/transformer_classifier.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:60667a05c642292a96edf702830ea85be1fb49af56bf606dddeabe594020ff5c
3
+ size 3036310
static/script.js ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
// Wires up the classifier page: prediction requests, feedback/correction
// submission and the dark-theme switch.
document.addEventListener('DOMContentLoaded', () => {
    const btn = document.getElementById('submit-btn');
    const input = document.getElementById('text-input');
    const resultBox = document.getElementById('result-box');
    const feedbackSection = document.getElementById('feedback-section');
    const correctionForm = document.getElementById('correction-form');
    const probsContainer = document.getElementById('probs-container');
    const probsBars = document.getElementById('probs-bars');

    // Text and label of the most recent successful prediction. Feedback and
    // corrections are built from this rather than re-parsed from the result
    // box, whose content may be an error/placeholder message, and rather than
    // re-read from the textarea, which the user may have edited since.
    let lastPrediction = null;

    // POST `payload` as JSON and return the parsed JSON response.
    // Rejects on network failure or on a non-2xx status.
    async function postJson(url, payload) {
        const response = await fetch(url, {
            method: 'POST',
            headers: { 'Content-Type': 'application/json' },
            body: JSON.stringify(payload)
        });
        if (!response.ok) {
            throw new Error(`${url} responded with HTTP ${response.status}`);
        }
        return response.json();
    }

    // Report whether the last prediction was correct. No-op without a prediction.
    async function sendFeedback(isCorrect) {
        if (!lastPrediction) {
            return;
        }
        await postJson('/feedback', {
            text: lastPrediction.text,
            predicted_label: lastPrediction.label,
            correct: isCorrect
        });
    }

    // Render one progress bar per class probability.
    function renderProbabilities(probs) {
        probsContainer.style.display = 'block';
        probsBars.innerHTML = '';
        for (const [label, prob] of Object.entries(probs)) {
            const bar = document.createElement('div');
            bar.className = 'progress my-2';

            const inner = document.createElement('div');
            inner.className = 'progress-bar';
            inner.style.width = `${prob * 100}%`;
            inner.style.color = 'black';
            inner.textContent = `${label}: ${(prob * 100).toFixed(1)}%`;

            bar.appendChild(inner);
            probsBars.appendChild(bar);
        }
    }

    btn.addEventListener('click', async () => {
        const text = input.value.trim();
        if (!text) {
            resultBox.textContent = "Введите текст";
            resultBox.className = "alert alert-warning mt-4";
            return;
        }

        resultBox.textContent = "⏳ Анализ...";
        resultBox.className = "alert alert-info mt-4";
        // A new request invalidates feedback controls for the previous result.
        lastPrediction = null;
        correctionForm.classList.add("d-none");

        try {
            const data = await postJson('/predict', { text });
            lastPrediction = { text, label: data.label };

            resultBox.textContent = `🏷 Класс: ${data.label}, уверенность: ${data.confidence}`;
            resultBox.className = "alert alert-success mt-4";

            feedbackSection.classList.remove("d-none");
            renderProbabilities(data.probs);
        } catch (err) {
            console.error(err);
            feedbackSection.classList.add("d-none");
            resultBox.textContent = "Ошибка при анализе текста";
            resultBox.className = "alert alert-danger mt-4";
        }
    });

    document.getElementById("btn-correct").addEventListener("click", async () => {
        try {
            await sendFeedback(true);
        } catch (err) {
            // Leave the buttons visible so the user can retry.
            console.error(err);
            return;
        }
        feedbackSection.classList.add("d-none");
        // Append nodes instead of `innerHTML +=` so existing text is not re-parsed as HTML.
        resultBox.append(document.createElement('br'), "Спасибо за подтверждение!");
    });

    document.getElementById("btn-incorrect").addEventListener("click", () => {
        // Fire-and-forget: the correction form is shown without waiting for the request.
        sendFeedback(false).catch(err => console.error(err));
        correctionForm.classList.remove("d-none");
    });

    document.getElementById("send-correction").addEventListener("click", async () => {
        const correctLabel = document.getElementById("correct-label").value;
        if (!correctLabel) {
            alert("Выберите правильный класс");
            return;
        }
        if (!lastPrediction) {
            return;
        }

        try {
            await postJson('/correction', {
                text: lastPrediction.text,
                predicted_label: lastPrediction.label,
                true_label: correctLabel
            });
        } catch (err) {
            console.error(err);
            alert("Не удалось отправить исправление");
            return;
        }
        alert("Спасибо, исправление отправлено");
        correctionForm.classList.add("d-none");
    });

    const themeToggle = document.getElementById("theme-toggle");
    themeToggle.addEventListener("change", () => {
        const container = document.querySelector(".container");
        document.body.classList.toggle("bg-dark");
        document.body.classList.toggle("text-white");
        container.classList.toggle("bg-dark");
        container.classList.toggle("text-white");
        document.body.classList.toggle("dark-theme");
        document.querySelectorAll('h1, .form-label').forEach(el => {
            el.classList.toggle('text-dark');
        });
    });
});
templates/index.html ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
<!doctype html>
<html lang="ru">
<head>
    <meta charset="utf-8">
    <title>Классификация кибербуллинга</title>
    <meta name="viewport" content="width=device-width, initial-scale=1">

    <link href="https://cdn.jsdelivr.net/npm/bootstrap@5.3.3/dist/css/bootstrap.min.css" rel="stylesheet">
    <script defer src="/static/script.js"></script>

    <style>
        body {
            background-color: #f8f9fa;
        }
        .container {
            max-width: 700px;
            margin-top: 5%;
        }
        .result-box {
            margin-top: 1rem;
            font-weight: bold;
        }
    </style>
</head>
<body>
    <nav class="navbar navbar-expand-lg navbar-dark bg-dark">
        <div class="container-fluid">
            <a class="navbar-brand" href="#">Cyberbullying AI</a>
            <div class="form-check form-switch ms-auto me-3 text-white">
                <input class="form-check-input" type="checkbox" id="theme-toggle">
                <label class="form-check-label" for="theme-toggle">🌙 Тёмная тема</label>
            </div>
        </div>
    </nav>

    <!-- Everything the script shows or hides lives inside this card so it
         shares the card's width, padding and theme classes. -->
    <div class="container bg-white p-5 rounded shadow-sm">
        <h1 class="mb-4 text-center">Классификация кибербуллинга</h1>

        <div class="mb-3">
            <label for="text-input" class="form-label">Введите текст</label>
            <textarea class="form-control" id="text-input" rows="4" placeholder="Напишите сообщение..."></textarea>
        </div>

        <div class="d-grid">
            <button class="btn btn-primary" id="submit-btn">Анализировать</button>
        </div>

        <div class="alert alert-info mt-4 d-none" id="result-box"></div>

        <!-- Shown by script.js after a successful prediction. -->
        <div id="feedback-section" class="mt-3 d-none">
            <p>Был ли результат точным?</p>
            <button class="btn btn-sm btn-success me-2" id="btn-correct">✅ Да</button>
            <button class="btn btn-sm btn-danger" id="btn-incorrect">❌ Нет</button>
        </div>

        <!-- Shown by script.js when the user marks the prediction as wrong. -->
        <div id="correction-form" class="mt-3 d-none">
            <label for="correct-label" class="form-label">Выберите правильный класс:</label>
            <select id="correct-label" class="form-select mb-2">
                <option value="" disabled selected>-- Выберите класс --</option>
                <option value="age">age</option>
                <option value="ethnicity">ethnicity</option>
                <option value="gender">gender</option>
                <option value="not_cyberbullying">not_cyberbullying</option>
                <option value="other_cyberbullying">other_cyberbullying</option>
                <option value="religion">religion</option>
            </select>
            <button class="btn btn-primary btn-sm" id="send-correction">Отправить</button>
        </div>

        <div class="mt-4" id="probs-container" style="display: none;">
            <h5>Вероятности по классам:</h5>
            <div id="probs-bars"></div>
        </div>
    </div>

    <script src="https://cdn.jsdelivr.net/npm/bootstrap@5.3.3/dist/js/bootstrap.bundle.min.js"></script>
</body>
</html>