File size: 611 Bytes
5c5b473
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import torch

from src.agent.dqn_agent import DQNAgent
from src.env.moderation_env import ModerationEnv
from src.training.train_rl import train

# 🧪 Small hand-labelled toy dataset of (text, label) pairs.
data = [
    ("I love this", "allow"),
    ("you are stupid", "flag"),
    ("I will kill you", "remove"),
    ("this is garbage", "flag"),
    ("great job!", "allow"),
]

# Moderation actions the agent can choose between. Every label used in
# `data` is expected to be one of these.
ACTIONS = ("allow", "flag", "remove")

# File the trained network's state_dict is written to.
MODEL_PATH = "dqn_model.pth"

# Fail fast, before any training time is spent, if the dataset uses a label
# that the agent has no corresponding action for.
_unknown_labels = {label for _, label in data} - set(ACTIONS)
if _unknown_labels:
    raise ValueError(f"labels not in action space: {sorted(_unknown_labels)}")

env = ModerationEnv(data)

agent = DQNAgent(
    action_space=list(ACTIONS),
    # NOTE(review): presumably must match the length of the state vector
    # produced by ModerationEnv — confirm against the env implementation.
    state_size=4,
)

# 🔥 Train for a fixed number of episodes.
train(env, agent, episodes=100)

# 💾 Persist the model's state_dict (its learned parameters), not the whole
# agent object.
torch.save(agent.model.state_dict(), MODEL_PATH)

print("✅ Training complete + model saved!")