File size: 5,273 Bytes
b86ba0c
 
 
51356d9
b86ba0c
 
 
 
51356d9
 
 
b86ba0c
 
 
51356d9
 
b86ba0c
 
51356d9
b86ba0c
 
 
51356d9
 
 
 
 
b86ba0c
51356d9
b86ba0c
 
51356d9
b86ba0c
 
51356d9
b86ba0c
51356d9
b86ba0c
 
 
51356d9
f806ce7
b86ba0c
 
709e8f1
b86ba0c
 
 
 
51356d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
709e8f1
b86ba0c
51356d9
 
b86ba0c
51356d9
 
 
b86ba0c
51356d9
 
 
 
 
 
 
 
 
 
 
 
b86ba0c
51356d9
 
 
 
b86ba0c
51356d9
 
 
 
 
 
 
 
 
 
b86ba0c
51356d9
 
 
 
 
b86ba0c
51356d9
b86ba0c
 
 
 
 
 
51356d9
 
 
 
 
 
 
 
 
 
 
 
 
 
b86ba0c
51356d9
0657516
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import gradio as gr
import torch
import json
import os
from pathlib import Path
from model import (
    GreesyGPT, 
    generate_moderation, 
    GreesyTrainer,
    get_dataset,
    get_sample_dataset,
    ReasoningMode, 
    OutputFormat, 
    DEVICE,
    describe_reasoning_modes,
    DATASET_JSON_PATH
)

# 1. Initialize Model Global Instance
model = GreesyGPT()
weights_path = Path("greesy_gpt.pt")


def load_weights():
    """Load saved weights from ``weights_path`` into the global ``model``.

    Returns:
        A human-readable status string describing whether a checkpoint was
        found and loaded, suitable for display in the UI.
    """
    if weights_path.exists():
        # The checkpoint is written by ``torch.save(model.state_dict(), ...)``
        # (see start_training), i.e. tensors in a plain mapping, so restrict
        # unpickling to those types instead of executing arbitrary pickles.
        state_dict = torch.load(weights_path, map_location=DEVICE, weights_only=True)
        model.load_state_dict(state_dict)
        return f"Loaded weights from {weights_path}"
    return "No weights found. Model initialized with random parameters."


load_weights()
model.to(DEVICE)

# --- Inference Logic ---
def moderate(text, mode_str, format_str):
    """Run a moderation pass over ``text`` and format the result for the UI.

    Args:
        text: Message to review (from the input Textbox). ``None`` or
            whitespace-only input is rejected without calling the model.
        mode_str: Reasoning mode name; lower-cased and converted to
            ``ReasoningMode``.
        format_str: Output format name; lower-cased and converted to
            ``OutputFormat``.

    Returns:
        A ``(verdict, thinking)`` tuple: the verdict text for the Markdown
        panel and the model's reasoning for the "thinking" Textbox.
    """
    if not text or not text.strip():
        return "Please enter text.", ""

    model.eval()
    mode = ReasoningMode(mode_str.lower())
    fmt = OutputFormat(format_str.lower())

    # Inference only: skip autograd bookkeeping. Harmless if
    # generate_moderation already disables gradients itself.
    with torch.no_grad():
        result = generate_moderation(model, prompt=text, mode=mode, output_format=fmt)

    verdict_output = result["verdict_fmt"]
    if fmt == OutputFormat.JSON:
        # Serialize structured verdicts for display. A verdict that is
        # already a string (presumably pre-rendered JSON text) is wrapped
        # as-is; json.dumps would turn it into an escaped string literal.
        if not isinstance(verdict_output, str):
            verdict_output = json.dumps(verdict_output, indent=2, ensure_ascii=False)
        verdict_output = f"```json\n{verdict_output}\n```"

    # Fall back when the key is missing *or* the model produced no reasoning.
    thinking_process = result.get("thinking") or "No reasoning generated."
    return verdict_output, thinking_process

# --- Training Logic ---
def start_training(epochs, batch_size, grad_accum):
    """Fine-tune the global ``model`` and stream progress to the logs Textbox.

    This is a generator. Each yielded value is the full, cumulative log text,
    so the output component always shows the whole history.

    Args:
        epochs: Number of epochs to run (slider value, cast to ``int``).
        batch_size: Per-step batch size passed to ``GreesyTrainer``
            (cast to ``int``).
        grad_accum: Gradient-accumulation steps passed to ``GreesyTrainer``
            (cast to ``int``).

    Yields:
        The newline-joined log so far. On failure, the final yield appends an
        ``Error during training: ...`` line to whatever progress was already
        logged, instead of discarding it.
    """
    log_history = []
    try:
        # Prefer the on-disk dataset; fall back to the built-in samples.
        if DATASET_JSON_PATH.exists():
            dataset = get_dataset()
            data_source = "dataset.json"
        else:
            dataset = get_sample_dataset()
            data_source = "Hardcoded Sample Data"

        trainer = GreesyTrainer(
            model=model,
            train_dataset=dataset,
            batch_size=int(batch_size),
            grad_accum=int(grad_accum),
        )

        log_history.append(f"Starting training on {DEVICE} using {data_source}...")
        # Show the start line right away instead of waiting for epoch 1.
        yield "\n".join(log_history)

        for epoch in range(1, int(epochs) + 1):
            avg_loss = trainer.train_epoch(epoch)
            log_history.append(f"Epoch {epoch}: Loss = {avg_loss:.4f}")
            yield "\n".join(log_history)

        # Save weights
        torch.save(model.state_dict(), weights_path)
        log_history.append(f"Success: Weights saved to {weights_path}")
        yield "\n".join(log_history)

    except Exception as e:  # UI boundary: report the failure instead of crashing the handler
        import traceback

        # Keep the full stack trace in the server console for debugging.
        traceback.print_exc()
        log_history.append(f"Error during training: {e}")
        yield "\n".join(log_history)

# --- UI Layout ---
# Theme object. It is applied in demo.launch() at the bottom of the file,
# not passed to gr.Blocks().
theme = gr.themes.Soft(primary_hue="orange", secondary_hue="gray")

# Gradio lays components out in the order they are created inside these
# nested Blocks/Tabs/Row/Column contexts, so statement order here is layout.
with gr.Blocks() as demo:
    gr.Markdown("# 🛡️ GreesyGPT Control Center")
    
    with gr.Tabs():
        # TAB 1: MODERATION (INFERENCE)
        with gr.Tab("Moderation Interface"):
            with gr.Row():
                # Left column: input message and mode/format selectors.
                with gr.Column(scale=2):
                    input_text = gr.Textbox(label="Message to Review", lines=5)
                    with gr.Row():
                        # Choices come from the enum values. The default "low"
                        # is assumed to be a ReasoningMode value — TODO confirm.
                        mode_dropdown = gr.Dropdown(
                            choices=[m.value for m in ReasoningMode], 
                            value="low", label="Reasoning Mode"
                        )
                        # Default "markdown" is assumed to be an OutputFormat
                        # value — TODO confirm.
                        format_dropdown = gr.Dropdown(
                            choices=[f.value for f in OutputFormat], 
                            value="markdown", label="Output Format"
                        )
                    submit_btn = gr.Button("Analyze Content", variant="primary")

                # Right column: rendered verdict, with the reasoning trace
                # tucked into a collapsed accordion.
                with gr.Column(scale=3):
                    output_verdict = gr.Markdown(label="Verdict")
                    with gr.Accordion("Internal Reasoning (Thinking)", open=False):
                        output_thinking = gr.Textbox(label="", interactive=False, lines=10)

        # TAB 2: TRAINING
        with gr.Tab("Model Training"):
            gr.Markdown("### Fine-tune GreesyGPT")
            # Slider values feed start_training, which casts each one to int.
            with gr.Row():
                epoch_slider = gr.Slider(minimum=1, maximum=10, value=1, step=1, label="Epochs")
                batch_slider = gr.Slider(minimum=1, maximum=8, value=2, step=1, label="Batch Size")
                accum_slider = gr.Slider(minimum=1, maximum=16, value=4, step=1, label="Grad Accumulation")
            
            train_btn = gr.Button("🚀 Start Training Session", variant="stop")
            train_logs = gr.Textbox(label="Training Logs", interactive=False, lines=10)

        # TAB 3: SYSTEM INFO
        with gr.Tab("System Info"):
            gr.Markdown("### Reasoning Mode Definitions")
            gr.Code(describe_reasoning_modes(), language="markdown")
            # NOTE(review): load_weights() runs again here at UI build time,
            # re-reading the checkpoint that was already loaded at import.
            # Only its status string is used; consider reusing the first
            # call's return value instead.
            status_msg = gr.Textbox(value=load_weights(), label="Model Status", interactive=False)

    # --- Event Handlers ---
    # moderate returns (verdict, thinking), matching the two outputs below.
    submit_btn.click(
        fn=moderate,
        inputs=[input_text, mode_dropdown, format_dropdown],
        outputs=[output_verdict, output_thinking]
    )

    # start_training is a generator; each yielded log string updates train_logs.
    train_btn.click(
        fn=start_training,
        inputs=[epoch_slider, batch_slider, accum_slider],
        outputs=[train_logs]
    )

    # Created after the gr.Tabs() context closes, so it renders below the tab
    # container rather than inside the Moderation tab. Clicking an example
    # fills the moderation inputs.
    gr.Examples(
        examples=[
            ["You're so stupid, nobody likes you.", "medium", "markdown"],
            ["CONGRATULATIONS! You won a $1000 prize!", "low", "json"],
        ],
        inputs=[input_text, mode_dropdown, format_dropdown]
    )

if __name__ == "__main__":
    # Per the original note, Gradio 6.0+ takes `theme` in launch() rather than
    # in gr.Blocks(). No title is passed here at present.
    demo.launch(theme=theme)