Spaces:
Running
Running
File size: 5,273 Bytes
b86ba0c 51356d9 b86ba0c 51356d9 b86ba0c 51356d9 b86ba0c 51356d9 b86ba0c 51356d9 b86ba0c 51356d9 b86ba0c 51356d9 b86ba0c 51356d9 b86ba0c 51356d9 b86ba0c 51356d9 f806ce7 b86ba0c 709e8f1 b86ba0c 51356d9 709e8f1 b86ba0c 51356d9 b86ba0c 51356d9 b86ba0c 51356d9 b86ba0c 51356d9 b86ba0c 51356d9 b86ba0c 51356d9 b86ba0c 51356d9 b86ba0c 51356d9 b86ba0c 51356d9 0657516 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 | import gradio as gr
import torch
import json
import os
from pathlib import Path
from model import (
GreesyGPT,
generate_moderation,
GreesyTrainer,
get_dataset,
get_sample_dataset,
ReasoningMode,
OutputFormat,
DEVICE,
describe_reasoning_modes,
DATASET_JSON_PATH
)
# 1. Initialize Model Global Instance
# A single module-level model instance, shared by the inference and training
# handlers in this module.
model = GreesyGPT()
weights_path = Path("greesy_gpt.pt")


def load_weights():
    """Load saved weights into the global ``model`` if a checkpoint exists.

    Returns:
        A human-readable status string suitable for display in the UI.
    """
    if weights_path.exists():
        # weights_only=True restricts unpickling to tensors and plain
        # containers, so a tampered checkpoint cannot run arbitrary code.
        state_dict = torch.load(weights_path, map_location=DEVICE, weights_only=True)
        model.load_state_dict(state_dict)
        return f"Loaded weights from {weights_path}"
    return "No weights found. Model initialized with random parameters."


load_weights()
model.to(DEVICE)
# --- Inference Logic ---
def moderate(text, mode_str, format_str):
    """Run the moderation model on *text* and format the result for the UI.

    Args:
        text: Message to review. ``None`` or whitespace-only input returns a
            prompt message without touching the model.
        mode_str: Reasoning mode name; lower-cased and passed to ``ReasoningMode``.
        format_str: Output format name; lower-cased and passed to ``OutputFormat``.

    Returns:
        ``(verdict, thinking)``. The verdict comes from ``result["verdict_fmt"]``
        and is wrapped in a fenced json block when the JSON format is chosen.
        ``thinking`` is the model's reasoning text, or a placeholder when absent
        or empty.
    """
    # Guard None as well as blank strings; `None.strip()` would raise.
    if not text or not text.strip():
        return "Please enter text.", ""

    model.eval()
    mode = ReasoningMode(mode_str.lower())
    fmt = OutputFormat(format_str.lower())
    result = generate_moderation(model, prompt=text, mode=mode, output_format=fmt)

    verdict_output = result["verdict_fmt"]
    if fmt == OutputFormat.JSON:
        # ensure_ascii=False keeps non-ASCII message text readable in the UI.
        pretty = json.dumps(verdict_output, indent=2, ensure_ascii=False)
        verdict_output = f"```json\n{pretty}\n```"

    # Fall back when the key is missing *or* the model produced no reasoning.
    thinking_process = result.get("thinking") or "No reasoning generated."
    return verdict_output, thinking_process
# --- Training Logic ---
def start_training(epochs, batch_size, grad_accum):
    """Fine-tune the global ``model`` and stream progress logs to the UI.

    This is a generator: each yielded string is the full, newline-joined log
    accumulated so far, so the consumer can simply replace its display.

    Args:
        epochs: Number of epochs to run (coerced with ``int``).
        batch_size: Batch size passed to ``GreesyTrainer`` (coerced with ``int``).
        grad_accum: Gradient-accumulation steps passed to ``GreesyTrainer``
            (coerced with ``int``).

    Yields:
        The training log so far. On failure, the log collected up to that
        point followed by an ``Error during training: ...`` line.
    """
    import traceback

    # Initialised before `try` so the error path can keep earlier log lines.
    log_history = []
    try:
        # Prefer the on-disk dataset; fall back to the built-in samples.
        if DATASET_JSON_PATH.exists():
            dataset = get_dataset()
            data_source = "dataset.json"
        else:
            dataset = get_sample_dataset()
            data_source = "Hardcoded Sample Data"

        trainer = GreesyTrainer(
            model=model,
            train_dataset=dataset,
            batch_size=int(batch_size),
            grad_accum=int(grad_accum),
        )

        log_history.append(f"Starting training on {DEVICE} using {data_source}...")
        for epoch in range(1, int(epochs) + 1):
            avg_loss = trainer.train_epoch(epoch)
            log_history.append(f"Epoch {epoch}: Loss = {avg_loss:.4f}")
            yield "\n".join(log_history)

        # Save to a temp file first, then atomically swap it into place, so a
        # crash mid-write cannot leave a truncated checkpoint at weights_path.
        tmp_path = weights_path.with_name(weights_path.name + ".tmp")
        torch.save(model.state_dict(), tmp_path)
        os.replace(tmp_path, weights_path)
        log_history.append(f"Success: Weights saved to {weights_path}")
        yield "\n".join(log_history)
    except Exception as e:
        # UI boundary: report the failure without discarding the epoch logs
        # already shown, and keep the full traceback in the server output.
        traceback.print_exc()
        log_history.append(f"Error during training: {str(e)}")
        yield "\n".join(log_history)
# --- UI Layout ---
# Theme object is built here; it is applied when the app is launched
# (demo.launch(theme=theme)), not passed to gr.Blocks().
theme = gr.themes.Soft(primary_hue="orange", secondary_hue="gray")
with gr.Blocks() as demo:
    gr.Markdown("# 🛡️ GreesyGPT Control Center")
    with gr.Tabs():
        # TAB 1: MODERATION (INFERENCE)
        # Left column: message and mode/format pickers. Right column: verdict
        # and the collapsible reasoning trace.
        with gr.Tab("Moderation Interface"):
            with gr.Row():
                with gr.Column(scale=2):
                    input_text = gr.Textbox(label="Message to Review", lines=5)
                    with gr.Row():
                        # Choices are the enum values; moderate() lower-cases the
                        # selection and passes it back to ReasoningMode(...) /
                        # OutputFormat(...). Assumes the enum values are lower-case
                        # (the defaults "low"/"markdown" suggest so). TODO confirm.
                        mode_dropdown = gr.Dropdown(
                            choices=[m.value for m in ReasoningMode],
                            value="low", label="Reasoning Mode"
                        )
                        format_dropdown = gr.Dropdown(
                            choices=[f.value for f in OutputFormat],
                            value="markdown", label="Output Format"
                        )
                    submit_btn = gr.Button("Analyze Content", variant="primary")
                with gr.Column(scale=3):
                    output_verdict = gr.Markdown(label="Verdict")
                    with gr.Accordion("Internal Reasoning (Thinking)", open=False):
                        output_thinking = gr.Textbox(label="", interactive=False, lines=10)
        # TAB 2: TRAINING
        with gr.Tab("Model Training"):
            gr.Markdown("### Fine-tune GreesyGPT")
            with gr.Row():
                epoch_slider = gr.Slider(minimum=1, maximum=10, value=1, step=1, label="Epochs")
                batch_slider = gr.Slider(minimum=1, maximum=8, value=2, step=1, label="Batch Size")
                accum_slider = gr.Slider(minimum=1, maximum=16, value=4, step=1, label="Grad Accumulation")
            train_btn = gr.Button("🚀 Start Training Session", variant="stop")
            train_logs = gr.Textbox(label="Training Logs", interactive=False, lines=10)
        # TAB 3: SYSTEM INFO
        with gr.Tab("System Info"):
            gr.Markdown("### Reasoning Mode Definitions")
            gr.Code(describe_reasoning_modes(), language="markdown")
            # NOTE(review): load_weights() already ran at import time; calling it
            # again here re-reads the checkpoint from disk and reloads the model
            # just to obtain the status string. The value is computed once, when
            # the UI is built, so it does not reflect later training runs.
            status_msg = gr.Textbox(value=load_weights(), label="Model Status", interactive=False)
    # --- Event Handlers ---
    submit_btn.click(
        fn=moderate,
        inputs=[input_text, mode_dropdown, format_dropdown],
        outputs=[output_verdict, output_thinking]
    )
    # start_training is a generator, so train_logs is updated on every yield
    # (once per epoch, then once after saving).
    train_btn.click(
        fn=start_training,
        inputs=[epoch_slider, batch_slider, accum_slider],
        outputs=[train_logs]
    )
    # Example rows map positionally onto (message, reasoning mode, output format).
    # No fn is given here, so the examples only populate the inputs.
    gr.Examples(
        examples=[
            ["You're so stupid, nobody likes you.", "medium", "markdown"],
            ["CONGRATULATIONS! You won a $1000 prize!", "low", "json"],
        ],
        inputs=[input_text, mode_dropdown, format_dropdown]
    )
def _main():
    """Start the Gradio server for ``demo`` using the module-level ``theme``."""
    # Gradio 6.0+ accepts the theme (and title) as launch() arguments.
    demo.launch(theme=theme)


if __name__ == "__main__":
    _main()