Spaces:
Sleeping
Sleeping
Upload 5 files
Browse files- Dockerfile +12 -0
- app.py +101 -0
- model/transformer_classifier.pt +3 -0
- static/script.js +104 -0
- templates/index.html +79 -0
Dockerfile
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.10-slim
|
| 2 |
+
|
| 3 |
+
WORKDIR /app
|
| 4 |
+
|
| 5 |
+
COPY . .
|
| 6 |
+
|
| 7 |
+
RUN pip install --upgrade pip
|
| 8 |
+
RUN pip install flask torch
|
| 9 |
+
|
| 10 |
+
EXPOSE 7860
|
| 11 |
+
|
| 12 |
+
CMD ["python", "app.py"]
|
app.py
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from flask import Flask, render_template, request, jsonify
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import re
|
| 5 |
+
import csv
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class TransformerClassifier(torch.nn.Module):
|
| 9 |
+
def __init__(self, vocab_size, num_classes, d_model=64, nhead=4, num_layers=2, dim_feedforward=128):
|
| 10 |
+
super().__init__()
|
| 11 |
+
self.embedding = torch.nn.Embedding(vocab_size, d_model)
|
| 12 |
+
encoder_layer = torch.nn.TransformerEncoderLayer(
|
| 13 |
+
d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward
|
| 14 |
+
)
|
| 15 |
+
self.encoder = torch.nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
|
| 16 |
+
self.classifier = torch.nn.Linear(d_model, num_classes)
|
| 17 |
+
|
| 18 |
+
def forward(self, x):
|
| 19 |
+
emb = self.embedding(x)
|
| 20 |
+
emb = emb.permute(1, 0, 2)
|
| 21 |
+
encoded = self.encoder(emb)
|
| 22 |
+
cls = encoded[0]
|
| 23 |
+
return self.classifier(cls)
|
| 24 |
+
|
| 25 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 26 |
+
checkpoint = torch.load("model/transformer_classifier.pt", map_location=device)
|
| 27 |
+
vocab = checkpoint['vocab']
|
| 28 |
+
class_to_idx = checkpoint['class_to_idx']
|
| 29 |
+
idx_to_class = {i: name for name, i in class_to_idx.items()}
|
| 30 |
+
|
| 31 |
+
model = TransformerClassifier(len(vocab), len(class_to_idx))
|
| 32 |
+
model.load_state_dict(checkpoint['model_state_dict'])
|
| 33 |
+
model.to(device)
|
| 34 |
+
model.eval()
|
| 35 |
+
|
| 36 |
+
def tokenize(text):
|
| 37 |
+
text = text.lower()
|
| 38 |
+
text = re.sub(r"[^a-z0-9\s]", "", text)
|
| 39 |
+
return text.split()
|
| 40 |
+
|
| 41 |
+
def encode(text, max_len=32):
|
| 42 |
+
tokens = ['<CLS>'] + tokenize(text)
|
| 43 |
+
ids = [vocab.get(t, vocab['<UNK>']) for t in tokens]
|
| 44 |
+
if len(ids) < max_len:
|
| 45 |
+
ids += [vocab['<PAD>']] * (max_len - len(ids))
|
| 46 |
+
else:
|
| 47 |
+
ids = ids[:max_len]
|
| 48 |
+
return ids
|
| 49 |
+
|
| 50 |
+
app = Flask(__name__)
|
| 51 |
+
|
| 52 |
+
@app.route('/')
|
| 53 |
+
def index():
|
| 54 |
+
return render_template('index.html')
|
| 55 |
+
|
| 56 |
+
@app.route('/predict', methods=['POST'])
|
| 57 |
+
def predict():
|
| 58 |
+
data = request.get_json()
|
| 59 |
+
text = data['text']
|
| 60 |
+
encoded = encode(text)
|
| 61 |
+
x = torch.tensor(encoded).unsqueeze(0).to(device)
|
| 62 |
+
with torch.no_grad():
|
| 63 |
+
logits = model(x)
|
| 64 |
+
probs = F.softmax(logits, dim=1)
|
| 65 |
+
pred_idx = torch.argmax(probs, dim=1).item()
|
| 66 |
+
confidence = probs[0, pred_idx].item()
|
| 67 |
+
return jsonify({
|
| 68 |
+
'label': idx_to_class[pred_idx],
|
| 69 |
+
'confidence': round(confidence, 3),
|
| 70 |
+
'probs': {idx_to_class[i]: round(p.item(), 3) for i, p in enumerate(probs[0])}
|
| 71 |
+
})
|
| 72 |
+
@app.route('/feedback', methods=['POST'])
|
| 73 |
+
def feedback():
|
| 74 |
+
data = request.get_json()
|
| 75 |
+
text = data.get('text')
|
| 76 |
+
predicted_label = data.get('predicted_label')
|
| 77 |
+
correct = data.get('correct')
|
| 78 |
+
|
| 79 |
+
with open('feedback.txt', 'a') as f:
|
| 80 |
+
f.write(f'Text: {text}\nPredicted Label: {predicted_label}\nCorrect: {correct}\n\n')
|
| 81 |
+
|
| 82 |
+
return jsonify({'status': 'Feedback received'})
|
| 83 |
+
|
| 84 |
+
@app.route('/correction', methods=['POST'])
|
| 85 |
+
def correction():
|
| 86 |
+
data = request.get_json()
|
| 87 |
+
text = data.get('text')
|
| 88 |
+
predicted_label = data.get('predicted_label')
|
| 89 |
+
true_label = data.get('true_label')
|
| 90 |
+
|
| 91 |
+
with open('corrections.csv', mode='a', newline='', encoding='utf-8') as file:
|
| 92 |
+
writer = csv.writer(file)
|
| 93 |
+
if file.tell() == 0:
|
| 94 |
+
writer.writerow(['Text', 'Predicted Label', 'True Label'])
|
| 95 |
+
writer.writerow([text, predicted_label, true_label])
|
| 96 |
+
|
| 97 |
+
return jsonify({'status': 'received'})
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
if __name__ == '__main__':
|
| 101 |
+
app.run(host='0.0.0.0', port=7860)
|
model/transformer_classifier.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:60667a05c642292a96edf702830ea85be1fb49af56bf606dddeabe594020ff5c
|
| 3 |
+
size 3036310
|
static/script.js
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
document.addEventListener('DOMContentLoaded', () => {
|
| 2 |
+
const btn = document.getElementById('submit-btn');
|
| 3 |
+
const input = document.getElementById('text-input');
|
| 4 |
+
const resultBox = document.getElementById('result-box');
|
| 5 |
+
|
| 6 |
+
async function sendFeedback(isCorrect) {
|
| 7 |
+
const text = input.value.trim();
|
| 8 |
+
const predictedLabel = resultBox.textContent.split(":")[1].split(",")[0].trim();
|
| 9 |
+
await fetch('/feedback', {
|
| 10 |
+
method: 'POST',
|
| 11 |
+
headers: { 'Content-Type': 'application/json' },
|
| 12 |
+
body: JSON.stringify({ text, predicted_label: predictedLabel, correct: isCorrect })
|
| 13 |
+
});
|
| 14 |
+
}
|
| 15 |
+
|
| 16 |
+
btn.addEventListener('click', async () => {
|
| 17 |
+
const text = input.value.trim();
|
| 18 |
+
if (!text) {
|
| 19 |
+
resultBox.textContent = "Введите текст";
|
| 20 |
+
resultBox.className = "alert alert-warning mt-4";
|
| 21 |
+
return;
|
| 22 |
+
}
|
| 23 |
+
|
| 24 |
+
resultBox.textContent = "⏳ Анализ...";
|
| 25 |
+
resultBox.className = "alert alert-info mt-4";
|
| 26 |
+
|
| 27 |
+
try {
|
| 28 |
+
const response = await fetch('/predict', {
|
| 29 |
+
method: 'POST',
|
| 30 |
+
headers: { 'Content-Type': 'application/json' },
|
| 31 |
+
body: JSON.stringify({ text })
|
| 32 |
+
});
|
| 33 |
+
|
| 34 |
+
const data = await response.json();
|
| 35 |
+
resultBox.textContent = `🏷 Класс: ${data.label}, уверенность: ${data.confidence}`;
|
| 36 |
+
resultBox.className = "alert alert-success mt-4";
|
| 37 |
+
|
| 38 |
+
document.getElementById("feedback-section").classList.remove("d-none");
|
| 39 |
+
document.getElementById("probs-container").style.display = 'block';
|
| 40 |
+
const probsBars = document.getElementById('probs-bars');
|
| 41 |
+
probsBars.innerHTML = '';
|
| 42 |
+
for (const [label, prob] of Object.entries(data.probs)) {
|
| 43 |
+
const bar = document.createElement('div');
|
| 44 |
+
bar.className = 'progress my-2';
|
| 45 |
+
|
| 46 |
+
const inner = document.createElement('div');
|
| 47 |
+
inner.className = 'progress-bar';
|
| 48 |
+
inner.style.width = `${prob * 100}%`;
|
| 49 |
+
inner.style.color = 'black';
|
| 50 |
+
inner.innerText = `${label}: ${(prob * 100).toFixed(1)}%`;
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
bar.appendChild(inner);
|
| 54 |
+
probsBars.appendChild(bar);
|
| 55 |
+
}
|
| 56 |
+
} catch (err) {
|
| 57 |
+
resultBox.textContent = "Ошибка при анализе текста";
|
| 58 |
+
resultBox.className = "alert alert-danger mt-4";
|
| 59 |
+
}
|
| 60 |
+
});
|
| 61 |
+
|
| 62 |
+
document.getElementById("btn-correct").addEventListener("click", async () => {
|
| 63 |
+
await sendFeedback(true);
|
| 64 |
+
document.getElementById("feedback-section").classList.add("d-none");
|
| 65 |
+
resultBox.innerHTML += "<br>Спасибо за подтверждение!";
|
| 66 |
+
});
|
| 67 |
+
|
| 68 |
+
document.getElementById("btn-incorrect").addEventListener("click", () => {
|
| 69 |
+
sendFeedback(false);
|
| 70 |
+
document.getElementById("correction-form").classList.remove("d-none");
|
| 71 |
+
});
|
| 72 |
+
|
| 73 |
+
document.getElementById("send-correction").addEventListener("click", () => {
|
| 74 |
+
const correctLabel = document.getElementById("correct-label").value;
|
| 75 |
+
const text = document.getElementById("text-input").value.trim();
|
| 76 |
+
const predictedLabel = resultBox.textContent.split(":")[1].split(",")[0].trim();
|
| 77 |
+
|
| 78 |
+
if (!correctLabel) {
|
| 79 |
+
alert("Выберите правильный класс");
|
| 80 |
+
return;
|
| 81 |
+
}
|
| 82 |
+
|
| 83 |
+
fetch('/correction', {
|
| 84 |
+
method: 'POST',
|
| 85 |
+
headers: { 'Content-Type': 'application/json' },
|
| 86 |
+
body: JSON.stringify({ text, predicted_label: predictedLabel, true_label: correctLabel })
|
| 87 |
+
}).then(() => {
|
| 88 |
+
alert("Спасибо, исправление отправлено");
|
| 89 |
+
document.getElementById("correction-form").classList.add("d-none");
|
| 90 |
+
});
|
| 91 |
+
});
|
| 92 |
+
|
| 93 |
+
const themeToggle = document.getElementById("theme-toggle");
|
| 94 |
+
themeToggle.addEventListener("change", () => {
|
| 95 |
+
document.body.classList.toggle("bg-dark");
|
| 96 |
+
document.body.classList.toggle("text-white");
|
| 97 |
+
document.querySelector(".container").classList.toggle("bg-dark");
|
| 98 |
+
document.querySelector(".container").classList.toggle("text-white");
|
| 99 |
+
document.body.classList.toggle("dark-theme");
|
| 100 |
+
document.querySelectorAll('h1, .form-label').forEach(el => {
|
| 101 |
+
el.classList.toggle('text-dark');
|
| 102 |
+
})
|
| 103 |
+
});
|
| 104 |
+
});
|
templates/index.html
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<!doctype html>
|
| 2 |
+
<html lang="ru">
|
| 3 |
+
<head>
|
| 4 |
+
<meta charset="utf-8">
|
| 5 |
+
<title>Классификация кибербуллинга</title>
|
| 6 |
+
<meta name="viewport" content="width=device-width, initial-scale=1">
|
| 7 |
+
|
| 8 |
+
<link href="https://cdn.jsdelivr.net/npm/bootstrap@5.3.3/dist/css/bootstrap.min.css" rel="stylesheet">
|
| 9 |
+
<script defer src="/static/script.js"></script>
|
| 10 |
+
|
| 11 |
+
<style>
|
| 12 |
+
body {
|
| 13 |
+
background-color: #f8f9fa;
|
| 14 |
+
}
|
| 15 |
+
.container {
|
| 16 |
+
max-width: 700px;
|
| 17 |
+
margin-top: 5%;
|
| 18 |
+
}
|
| 19 |
+
.result-box {
|
| 20 |
+
margin-top: 1rem;
|
| 21 |
+
font-weight: bold;
|
| 22 |
+
}
|
| 23 |
+
</style>
|
| 24 |
+
</head>
|
| 25 |
+
<body>
|
| 26 |
+
<nav class="navbar navbar-expand-lg navbar-dark bg-dark">
|
| 27 |
+
<div class="container-fluid">
|
| 28 |
+
<a class="navbar-brand" href="#">Cyberbullying AI</a>
|
| 29 |
+
<div class="form-check form-switch ms-auto me-3 text-white">
|
| 30 |
+
<input class="form-check-input" type="checkbox" id="theme-toggle">
|
| 31 |
+
<label class="form-check-label" for="theme-toggle">🌙 Тёмная тема</label>
|
| 32 |
+
</div>
|
| 33 |
+
</div>
|
| 34 |
+
</nav>
|
| 35 |
+
|
| 36 |
+
<div class="container bg-white p-5 rounded shadow-sm">
|
| 37 |
+
<h1 class="mb-4 text-center">Классификация кибербуллинга</h1>
|
| 38 |
+
|
| 39 |
+
<div class="mb-3">
|
| 40 |
+
<label for="text-input" class="form-label">Введите текст</label>
|
| 41 |
+
<textarea class="form-control" id="text-input" rows="4" placeholder="Напишите сообщение..."></textarea>
|
| 42 |
+
</div>
|
| 43 |
+
|
| 44 |
+
<div class="d-grid">
|
| 45 |
+
<button class="btn btn-primary" id="submit-btn">Анализировать</button>
|
| 46 |
+
</div>
|
| 47 |
+
|
| 48 |
+
<div class="alert alert-info mt-4 d-none" id="result-box"></div>
|
| 49 |
+
</div>
|
| 50 |
+
|
| 51 |
+
<div id="feedback-section" class="mt-3 d-none">
|
| 52 |
+
<p>Был ли результат точным?</p>
|
| 53 |
+
<button class="btn btn-sm btn-success me-2" id="btn-correct">✅ Да</button>
|
| 54 |
+
<button class="btn btn-sm btn-danger" id="btn-incorrect">❌ Нет</button>
|
| 55 |
+
</div>
|
| 56 |
+
|
| 57 |
+
<div id="correction-form" class="mt-3 d-none">
|
| 58 |
+
<label for="correct-label" class="form-label">Выберите правильный класс:</label>
|
| 59 |
+
<select id="correct-label" class="form-select mb-2">
|
| 60 |
+
<option value="" disabled selected>-- Выберите класс --</option>
|
| 61 |
+
<option value="age">age</option>
|
| 62 |
+
<option value="ethnicity">ethnicity</option>
|
| 63 |
+
<option value="gender">gender</option>
|
| 64 |
+
<option value="not_cyberbullying">not_cyberbullying</option>
|
| 65 |
+
<option value="other_cyberbullying">other_cyberbullying</option>
|
| 66 |
+
<option value="religion">religion</option>
|
| 67 |
+
</select>
|
| 68 |
+
<button class="btn btn-primary btn-sm" id="send-correction">Отправить</button>
|
| 69 |
+
</div>
|
| 70 |
+
|
| 71 |
+
<div class="mt-4" id="probs-container" style="display: none;">
|
| 72 |
+
<h5>Вероятности по классам:</h5>
|
| 73 |
+
<div id="probs-bars"></div>
|
| 74 |
+
</div>
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
<script src="https://cdn.jsdelivr.net/npm/bootstrap@5.3.3/dist/js/bootstrap.bundle.min.js"></script>
|
| 78 |
+
</body>
|
| 79 |
+
</html>
|