LorenzoBioinfo committed on
Commit
a66d87f
·
1 Parent(s): 0ac2632

Train model also on youtube data and admin page

Browse files
app_templates/admin.html ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!DOCTYPE html>
2
+ <html lang="it">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <title>⚙️ Pannello Admin - Sentiment App</title>
6
+ <style>
7
+ body {
8
+ font-family: "Segoe UI", sans-serif;
9
+ background-color: #f4f6fa;
10
+ margin: 0;
11
+ padding: 2rem;
12
+ }
13
+ .container {
14
+ max-width: 800px;
15
+ margin: 0 auto;
16
+ background: #fff;
17
+ padding: 2rem;
18
+ border-radius: 10px;
19
+ box-shadow: 0 3px 10px rgba(0,0,0,0.1);
20
+ }
21
+ h1 {
22
+ color: #0052cc;
23
+ text-align: center;
24
+ }
25
+ .section {
26
+ margin-top: 1.5rem;
27
+ }
28
+ button, a.button {
29
+ background-color: #0052cc;
30
+ color: white;
31
+ border: none;
32
+ padding: 10px 18px;
33
+ border-radius: 8px;
34
+ cursor: pointer;
35
+ text-decoration: none;
36
+ font-weight: 500;
37
+ margin-top: 10px;
38
+ display: inline-block;
39
+ }
40
+ button:hover, a.button:hover {
41
+ background-color: #003d99;
42
+ }
43
+ .metrics-link {
44
+ display: block;
45
+ text-align: center;
46
+ margin-top: 1.5rem;
47
+ font-weight: bold;
48
+ }
49
+ .back {
50
+ display: inline-block;
51
+ margin-top: 2rem;
52
+ text-decoration: none;
53
+ color: #0052cc;
54
+ text-align: center;
55
+ width: 100%;
56
+ }
57
+ </style>
58
+ </head>
59
+ <body>
60
+ <div class="container">
61
+ <h1>⚙️ Pannello di Amministrazione</h1>
62
+ <p style="text-align:center;">Gestisci il modello di analisi del sentiment, il training e il monitoring.</p>
63
+
64
+ <div class="section">
65
+ <h3>🎓 Training del Modello</h3>
66
+ <p>Avvia un nuovo training usando i dati <strong>TweetEval</strong>.</p>
67
+ <form action="/admin/train" method="post">
68
+ <button type="submit">Esegui Training</button>
69
+ </form>
70
+ </div>
71
+
72
<div class="section">
  <h3>📊 Monitoring</h3>
  <p>Analizza le performance del modello sui dataset disponibili.</p>
  <!-- The action must match the POST route registered by the backend
       (/admin/monitor); posting to /admin/monitoring hits no handler. -->
  <form action="/admin/monitor" method="post">
    <button type="submit">Esegui Monitoring</button>
  </form>

  <!-- Read-only view of the last metrics produced by a monitoring run. -->
  <a href="/admin/metrics" class="metrics-link button">📈 Visualizza Metriche</a>
</div>
81
+
82
+ <a class="back" href="/">← Torna alla Home</a>
83
+ </div>
84
+ </body>
85
+ </html>
app_templates/metrics.html ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!DOCTYPE html>
2
+ <html lang="it">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <title>📈 Metriche del Modello</title>
6
+ <script src="https://cdn.jsdelivr.net/npm/chart.js"></script>
7
+ <style>
8
+ body {
9
+ font-family: "Segoe UI", sans-serif;
10
+ background-color: #f5f7fa;
11
+ margin: 0;
12
+ padding: 2rem;
13
+ }
14
+ .container {
15
+ max-width: 800px;
16
+ margin: 0 auto;
17
+ background: white;
18
+ padding: 2rem;
19
+ border-radius: 10px;
20
+ box-shadow: 0 3px 10px rgba(0,0,0,0.1);
21
+ }
22
+ h1 {
23
+ color: #0052cc;
24
+ text-align: center;
25
+ }
26
+ table {
27
+ width: 100%;
28
+ border-collapse: collapse;
29
+ margin-top: 1.5rem;
30
+ }
31
+ th, td {
32
+ text-align: left;
33
+ padding: 10px;
34
+ border-bottom: 1px solid #ddd;
35
+ }
36
+ th {
37
+ background-color: #0052cc;
38
+ color: white;
39
+ }
40
+ canvas {
41
+ margin-top: 30px;
42
+ width: 100%;
43
+ height: 300px;
44
+ }
45
+ .button {
46
+ background-color: #0052cc;
47
+ color: white;
48
+ border: none;
49
+ padding: 10px 18px;
50
+ border-radius: 8px;
51
+ cursor: pointer;
52
+ text-decoration: none;
53
+ font-weight: 500;
54
+ display: inline-block;
55
+ margin-top: 1rem;
56
+ }
57
+ .button:hover {
58
+ background-color: #003d99;
59
+ }
60
+ .back {
61
+ display: block;
62
+ text-align: center;
63
+ margin-top: 2rem;
64
+ text-decoration: none;
65
+ color: #0052cc;
66
+ }
67
+ </style>
68
+ </head>
69
+ <body>
70
+ <div class="container">
71
+ <h1>📊 Metriche del Modello</h1>
72
+
73
{# Metrics table and bar chart. Rendered only when the view passed a non-empty
   `metrics` mapping; otherwise a hint to run the monitoring job is shown. #}
{% if metrics %}
<table>
  <thead>
    <tr>
      <th>Metrica</th>
      <th>Valore</th>
    </tr>
  </thead>
  <tbody>
    {% for key, value in metrics.items() %}
    <tr>
      <td>{{ key }}</td>
      {# "%.3f" raises on non-numeric values (e.g. a dataset name or a
         confusion matrix), so only numbers are formatted; anything else is
         printed as-is. #}
      <td>{% if value is number %}{{ "%.3f"|format(value) }}{% else %}{{ value }}{% endif %}</td>
    </tr>
    {% endfor %}
  </tbody>
</table>

<canvas id="metricsChart"></canvas>

<script>
  const ctx = document.getElementById('metricsChart').getContext('2d');
  // Keys and values are serialised separately so insertion order is kept.
  const rawLabels = {{ metrics.keys() | list | tojson }};
  const rawData = {{ metrics.values() | list | tojson }};
  // A bar chart can only plot numbers: keep the positions holding numeric values.
  const numericIdx = rawData
    .map((value, idx) => idx)
    .filter((idx) => typeof rawData[idx] === 'number');
  const labels = numericIdx.map((idx) => rawLabels[idx]);
  const data = numericIdx.map((idx) => rawData[idx]);
  new Chart(ctx, {
    type: 'bar',
    data: {
      labels: labels,
      datasets: [{
        label: 'Valori delle metriche',
        data: data,
        backgroundColor: 'rgba(0, 82, 204, 0.6)',
        borderRadius: 6
      }]
    },
    options: {
      scales: {
        y: { beginAtZero: true }
      }
    }
  });
</script>

{% else %}
<p style="text-align:center;">Nessun dato disponibile. Esegui il monitoring per visualizzare le metriche.</p>
{% endif %}
119
+
120
+ <div style="text-align:center;">
121
+ <a class="button" href="/admin">← Torna all’Area Admin</a>
122
+ </div>
123
+ </div>
124
+ </body>
125
+ </html>
src/app.py CHANGED
@@ -8,6 +8,8 @@ from datasets import load_dataset, load_from_disk
8
  import torch
9
  import random
10
  import subprocess
 
 
11
 
12
  # Caricamento del modello e dei dati se già scaricati
13
  MODEL= "cardiffnlp/twitter-roberta-base-sentiment-latest"
@@ -123,6 +125,47 @@ def random_youtube_comment(request: Request):
123
  )
124
 
125
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
  if __name__=="__main__":
127
  import uvicorn
128
  uvicorn.run(app,host="0.0.0.0",port=8000)
 
8
  import torch
9
  import random
10
  import subprocess
11
+ import json
12
+ import os
13
 
14
  # Caricamento del modello e dei dati se già scaricati
15
  MODEL= "cardiffnlp/twitter-roberta-base-sentiment-latest"
 
125
  )
126
 
127
 
128
+
129
@app.get("/admin", response_class=HTMLResponse)
async def admin_dashboard(request: Request):
    """Render the admin landing page.

    Loads ``reports/metrics.json`` when it is available and passes its content
    to the ``admin.html`` template as ``metrics`` (``None`` otherwise).

    Args:
        request: Incoming request, required by the template response.

    Returns:
        The rendered ``admin.html`` template response.
    """
    metrics_path = "reports/metrics.json"
    try:
        # Open directly instead of checking for existence first: the file can
        # be replaced by a monitoring run between the check and the read.
        with open(metrics_path, "r", encoding="utf-8") as metrics_file:
            metrics = json.load(metrics_file)
    except (FileNotFoundError, json.JSONDecodeError):
        # Missing file, or a file caught half-written: show the page without
        # metrics rather than failing the whole request.
        metrics = None
    return templates.TemplateResponse(
        "admin.html",
        {"request": request, "metrics": metrics},
    )
141
+
142
@app.post("/admin/train")
async def retrain_model():
    """Run the training script in a child process and report completion.

    Returns:
        ``{"status": "Training completato"}`` once the script exits with
        status 0.

    Raises:
        HTTPException: 500 when the training script exits with a non-zero
            status.
    """
    import sys

    from fastapi import HTTPException
    from fastapi.concurrency import run_in_threadpool

    # sys.executable guarantees the child uses the same interpreter (and
    # virtualenv) as the server; a bare "python" depends on PATH.
    # NOTE(review): the training code visible in this change lives in
    # src/train_model.py — confirm that src/train.py exists.
    command = [sys.executable, "src/train.py"]
    try:
        # subprocess.run blocks until training ends; running it in the
        # threadpool keeps the event loop free to serve other requests.
        await run_in_threadpool(subprocess.run, command, check=True)
    except subprocess.CalledProcessError as exc:
        raise HTTPException(
            status_code=500,
            detail=f"Training fallito (exit code {exc.returncode})",
        ) from exc
    return {"status": "Training completato"}
147
+
148
@app.post("/admin/monitor")
async def run_monitoring():
    """Run the monitoring script in a child process and report completion.

    Returns:
        ``{"status": "Monitoring completato"}`` once the script exits with
        status 0.

    Raises:
        HTTPException: 500 when the monitoring script exits with a non-zero
            status.
    """
    import sys

    from fastapi import HTTPException
    from fastapi.concurrency import run_in_threadpool

    # Running "src/monitoring.py" as a script puts src/ (not the project root)
    # on sys.path, so the script's `from src.train_model import ...` cannot
    # resolve. Prepending the working directory to PYTHONPATH fixes that
    # while keeping script-relative imports working.
    child_env = dict(os.environ)
    search_path = [os.getcwd()]
    if child_env.get("PYTHONPATH"):
        search_path.append(child_env["PYTHONPATH"])
    child_env["PYTHONPATH"] = os.pathsep.join(search_path)

    # sys.executable guarantees the child uses the same interpreter (and
    # virtualenv) as the server; a bare "python" depends on PATH.
    command = [sys.executable, "src/monitoring.py"]
    try:
        # subprocess.run blocks until the script ends; running it in the
        # threadpool keeps the event loop free to serve other requests.
        await run_in_threadpool(
            subprocess.run, command, check=True, env=child_env
        )
    except subprocess.CalledProcessError as exc:
        raise HTTPException(
            status_code=500,
            detail=f"Monitoring fallito (exit code {exc.returncode})",
        ) from exc
    return {"status": "Monitoring completato"}
153
+
154
@app.get("/admin/metrics", response_class=HTMLResponse)
def view_metrics(request: Request):
    """Render the monitoring results as a table and a chart.

    Loads ``reports/metrics.json`` when it is available and passes its content
    to the ``metrics.html`` template as ``metrics`` (``None`` otherwise, in
    which case the template shows its "no data" message).

    Args:
        request: Incoming request, required by the template response.

    Returns:
        The rendered ``metrics.html`` template response.
    """
    metrics_path = "reports/metrics.json"
    try:
        # Open directly instead of checking for existence first: the file can
        # be replaced by a monitoring run between the check and the read.
        with open(metrics_path, "r", encoding="utf-8") as metrics_file:
            metrics = json.load(metrics_file)
    except (FileNotFoundError, json.JSONDecodeError):
        # Missing file, or a file caught half-written: fall back to the
        # template's empty state rather than failing the request.
        metrics = None
    return templates.TemplateResponse(
        "metrics.html",
        {"request": request, "metrics": metrics},
    )
166
+
167
+
168
+
169
  if __name__=="__main__":
170
  import uvicorn
171
  uvicorn.run(app,host="0.0.0.0",port=8000)
src/monitoring.py CHANGED
@@ -5,7 +5,9 @@ import torch
5
  import numpy as np
6
  import json
7
  import os
 
8
 
 
9
  MODEL_PATH = "models/sentiment_model"
10
  TWEET_PATH = "data/processed/tweet_eval_tokenized"
11
  YT_PATH = "data/processed/youtube_comments"
@@ -31,6 +33,26 @@ def evaluate_model(model, tokenizer, dataset, dataset_name, sample_size=300):
31
  print(f"{dataset_name} — Accuracy: {acc:.3f}, F1: {f1:.3f}")
32
  return {"dataset": dataset_name, "accuracy": acc, "f1": f1, "confusion_matrix": cm}
33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  def main():
35
  print("Caricamento del modello")
36
  model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)
@@ -52,5 +74,7 @@ def main():
52
 
53
  print(f"Risultati salvati in: {metrics_path}")
54
 
 
 
55
  if __name__ == "__main__":
56
  main()
 
5
  import numpy as np
6
  import json
7
  import os
8
+ from src.train_model import train_model
9
 
10
+ ACCURACY_THRESHOLD = 0.75
11
  MODEL_PATH = "models/sentiment_model"
12
  TWEET_PATH = "data/processed/tweet_eval_tokenized"
13
  YT_PATH = "data/processed/youtube_comments"
 
33
  print(f"{dataset_name} — Accuracy: {acc:.3f}, F1: {f1:.3f}")
34
  return {"dataset": dataset_name, "accuracy": acc, "f1": f1, "confusion_matrix": cm}
35
 
36
+
37
def retrain_on_youtube_sample(sample_size=500):
    """Retrain the model with a random sample of YouTube comments added.

    Args:
        sample_size: Maximum number of YouTube rows to add to the training
            data. Capped at the number of rows actually available.
    """
    from datasets import load_from_disk

    # YT_PATH is the YouTube dataset location declared with this module's
    # other path constants.
    # NOTE(review): assumes the saved dataset exposes a "train" split with the
    # same columns as the tweet training set — confirm.
    youtube_data = load_from_disk(YT_PATH)["train"]

    # select() raises on out-of-range indices, so never ask for more rows
    # than the dataset holds.
    n_rows = min(sample_size, len(youtube_data))
    youtube_sample = youtube_data.shuffle(seed=42).select(range(n_rows))

    # train_model() accepts no output_dir argument; passing one raises
    # TypeError. The model is written wherever train_model saves it.
    train_model(additional_data=youtube_sample)
43
+
44
+
45
+
46
def monitor_model():
    """Evaluate the model on YouTube data and retrain when accuracy is low.

    Returns:
        The metrics mapping produced by the evaluation; its "accuracy" entry
        is compared against ACCURACY_THRESHOLD.
    """
    # NOTE(review): evaluate_model_on_youtube is not defined in the visible
    # part of this module — confirm it exists before relying on this path.
    youtube_metrics = evaluate_model_on_youtube()

    print(f"Accuracy su YouTube: {youtube_metrics['accuracy']:.3f}")

    needs_retraining = youtube_metrics["accuracy"] < ACCURACY_THRESHOLD
    if needs_retraining:
        print("Performance sotto la soglia. Avvio retraining parziale...")
        retrain_on_youtube_sample()

    return youtube_metrics
55
+
56
  def main():
57
  print("Caricamento del modello")
58
  model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)
 
74
 
75
  print(f"Risultati salvati in: {metrics_path}")
76
 
77
+
78
+
79
  if __name__ == "__main__":
80
  main()
src/train_model.py CHANGED
@@ -5,7 +5,7 @@ from transformers import (
5
  TrainingArguments,
6
  AutoTokenizer
7
  )
8
- from datasets import load_from_disk
9
  import evaluate
10
  import numpy as np
11
  import os
@@ -24,9 +24,13 @@ def compute_metrics(eval_pred):
24
  f1 = metric_f1.compute(predictions=predictions, references=labels, average="weighted")
25
  return {"accuracy": acc["accuracy"], "f1": f1["f1"]}
26
 
27
- def train_model(sample_train_size=1000, sample_eval_size=300):
 
28
  print("Caricamento dataset Tweet eval preprocessato")
29
  dataset = load_from_disk(DATA_PATH)
 
 
 
30
 
31
  #
32
  print(f"Riduzione dataset: {sample_train_size} per il train, {sample_eval_size} per la validazione.")
 
5
  TrainingArguments,
6
  AutoTokenizer
7
  )
8
+ from datasets import load_from_disk,concatenate_datasets
9
  import evaluate
10
  import numpy as np
11
  import os
 
24
  f1 = metric_f1.compute(predictions=predictions, references=labels, average="weighted")
25
  return {"accuracy": acc["accuracy"], "f1": f1["f1"]}
26
 
27
+
28
+ def train_model(additional_data=None,sample_train_size=1000, sample_eval_size=300):
29
  print("Caricamento dataset Tweet eval preprocessato")
30
  dataset = load_from_disk(DATA_PATH)
31
+ if additional_data is not None:
32
+ print("Aggiungo dati YouTube al training set...")
33
+ dataset["train"] = concatenate_datasets([dataset["train"], additional_data])
34
 
35
  #
36
  print(f"Riduzione dataset: {sample_train_size} per il train, {sample_eval_size} per la validazione.")