testing / app.py
OnlyCheeini's picture
Update app.py
f806ce7 verified
import gradio as gr
import torch
import json
import os
from pathlib import Path
from model import (
GreesyGPT,
generate_moderation,
GreesyTrainer,
get_dataset,
get_sample_dataset,
ReasoningMode,
OutputFormat,
DEVICE,
describe_reasoning_modes,
DATASET_JSON_PATH
)
# 1. Initialize Model Global Instance
# A single module-level GreesyGPT instance is shared by the inference callback
# (moderate) and the training callback (start_training) defined below.
model = GreesyGPT()
# Checkpoint file, resolved relative to the current working directory. It is
# read by load_weights() and overwritten by start_training() after a run.
weights_path = Path("greesy_gpt.pt")
def load_weights():
    """Load saved parameters into the global ``model`` if a checkpoint exists.

    Returns:
        A human-readable status string: either the path the weights were
        loaded from, or a notice that no checkpoint file was found.
    """
    if not weights_path.exists():
        return "No weights found. Model initialized with random parameters."

    # weights_only=True restricts unpickling to tensors and plain containers,
    # so a tampered checkpoint file cannot run arbitrary code when loaded.
    state_dict = torch.load(weights_path, map_location=DEVICE, weights_only=True)
    model.load_state_dict(state_dict)
    return f"Loaded weights from {weights_path}"
# Restore any saved checkpoint at import time. The returned status string is
# discarded here; the System Info tab calls load_weights() again to display it.
load_weights()
# Move the model to the DEVICE exported by the model module.
model.to(DEVICE)
# --- Inference Logic ---
def moderate(text, mode_str, format_str):
    """Run the moderation model on ``text`` and shape the result for the UI.

    Args:
        text: Message to review. ``None``, empty, or whitespace-only input is
            rejected without invoking the model.
        mode_str: A ``ReasoningMode`` value; lower-cased before lookup.
        format_str: An ``OutputFormat`` value; lower-cased before lookup.

    Returns:
        A ``(verdict, thinking)`` tuple. ``verdict`` is the ``"verdict_fmt"``
        entry of the generation result, wrapped in a fenced ``json`` code
        block when the JSON output format is selected. ``thinking`` is the
        model's reasoning text, or a placeholder when none was produced.
    """
    # Guard against None as well as blank strings so the handler never raises
    # AttributeError on a cleared input.
    if text is None or not text.strip():
        return "Please enter text.", ""

    model.eval()
    mode = ReasoningMode(mode_str.lower())
    fmt = OutputFormat(format_str.lower())

    result = generate_moderation(model, prompt=text, mode=mode, output_format=fmt)
    # Echo the raw result to stdout so it shows up in the server log.
    print(result)

    verdict_output = result["verdict_fmt"]
    if fmt == OutputFormat.JSON:
        # Fence the serialized verdict so the Markdown component renders it
        # as a code block.
        verdict_output = f"```json\n{json.dumps(verdict_output, indent=2)}\n```"

    # `or` (rather than a .get default) also covers a present-but-empty or
    # None "thinking" entry.
    thinking_process = result.get("thinking") or "No reasoning generated."
    return verdict_output, thinking_process
# --- Training Logic ---
def start_training(epochs, batch_size, grad_accum):
    """Fine-tune the global ``model`` while streaming a growing log.

    This is a generator: every yielded value is the complete log so far,
    joined with newlines, so the caller can overwrite its display each time.

    Args:
        epochs: Number of epochs to run; coerced with ``int()``.
        batch_size: Forwarded to ``GreesyTrainer`` after ``int()`` coercion.
        grad_accum: Forwarded to ``GreesyTrainer`` after ``int()`` coercion.

    Yields:
        The accumulated log text once training starts, after each epoch,
        and after the weights are saved. If any step raises, the error
        message is appended to the log collected so far and yielded as the
        final value.
    """
    # Created before the try block so the error path can always append to it.
    log_history = []
    try:
        # Prefer the on-disk dataset; fall back to the bundled sample data.
        if DATASET_JSON_PATH.exists():
            dataset = get_dataset()
            data_source = "dataset.json"
        else:
            dataset = get_sample_dataset()
            data_source = "Hardcoded Sample Data"

        trainer = GreesyTrainer(
            model=model,
            train_dataset=dataset,
            batch_size=int(batch_size),
            grad_accum=int(grad_accum),
        )

        # Surface the start message right away instead of only after the
        # first epoch has completed.
        log_history.append(f"Starting training on {DEVICE} using {data_source}...")
        yield "\n".join(log_history)

        for epoch in range(1, int(epochs) + 1):
            avg_loss = trainer.train_epoch(epoch)
            log_history.append(f"Epoch {epoch}: Loss = {avg_loss:.4f}")
            yield "\n".join(log_history)

        # Save weights
        torch.save(model.state_dict(), weights_path)
        log_history.append(f"Success: Weights saved to {weights_path}")
        yield "\n".join(log_history)
    except Exception as e:
        # UI boundary: report the failure in the log rather than crashing the
        # handler, and keep the epochs already logged visible.
        log_history.append(f"Error during training: {e}")
        yield "\n".join(log_history)
# --- UI Layout ---
# Soft theme with an orange accent. It is handed to demo.launch() at the
# bottom of the file rather than to gr.Blocks().
theme = gr.themes.Soft(primary_hue="orange", secondary_hue="gray")

with gr.Blocks() as demo:
    gr.Markdown("# 🛡️ GreesyGPT Control Center")

    with gr.Tabs():
        # TAB 1: MODERATION (INFERENCE)
        with gr.Tab("Moderation Interface"):
            with gr.Row():
                # Left column: the message and the two generation options.
                with gr.Column(scale=2):
                    input_text = gr.Textbox(label="Message to Review", lines=5)
                    with gr.Row():
                        mode_dropdown = gr.Dropdown(
                            label="Reasoning Mode",
                            choices=[m.value for m in ReasoningMode],
                            value="low",
                        )
                        format_dropdown = gr.Dropdown(
                            label="Output Format",
                            choices=[f.value for f in OutputFormat],
                            value="markdown",
                        )
                    submit_btn = gr.Button("Analyze Content", variant="primary")

                # Right column: the verdict, with the reasoning collapsed by
                # default.
                with gr.Column(scale=3):
                    output_verdict = gr.Markdown(label="Verdict")
                    with gr.Accordion("Internal Reasoning (Thinking)", open=False):
                        output_thinking = gr.Textbox(
                            label="",
                            lines=10,
                            interactive=False,
                        )

        # TAB 2: TRAINING
        with gr.Tab("Model Training"):
            gr.Markdown("### Fine-tune GreesyGPT")
            with gr.Row():
                epoch_slider = gr.Slider(
                    label="Epochs",
                    minimum=1,
                    maximum=10,
                    step=1,
                    value=1,
                )
                batch_slider = gr.Slider(
                    label="Batch Size",
                    minimum=1,
                    maximum=8,
                    step=1,
                    value=2,
                )
                accum_slider = gr.Slider(
                    label="Grad Accumulation",
                    minimum=1,
                    maximum=16,
                    step=1,
                    value=4,
                )
            train_btn = gr.Button("🚀 Start Training Session", variant="stop")
            train_logs = gr.Textbox(
                label="Training Logs",
                lines=10,
                interactive=False,
            )

        # TAB 3: SYSTEM INFO
        with gr.Tab("System Info"):
            gr.Markdown("### Reasoning Mode Definitions")
            gr.Code(describe_reasoning_modes(), language="markdown")
            # load_weights() is called again here to obtain the status text,
            # which also re-reads the checkpoint when one exists.
            status_msg = gr.Textbox(
                label="Model Status",
                value=load_weights(),
                interactive=False,
            )

    # --- Event Handlers ---
    # Inference: one click fills both the verdict and the reasoning panel.
    submit_btn.click(
        fn=moderate,
        inputs=[input_text, mode_dropdown, format_dropdown],
        outputs=[output_verdict, output_thinking],
    )

    # Training: start_training is a generator, so each yielded log replaces
    # the contents of the log box.
    train_btn.click(
        fn=start_training,
        inputs=[epoch_slider, batch_slider, accum_slider],
        outputs=[train_logs],
    )

    # Clickable sample rows that populate the three moderation inputs.
    gr.Examples(
        examples=[
            ["You're so stupid, nobody likes you.", "medium", "markdown"],
            ["CONGRATULATIONS! You won a $1000 prize!", "low", "json"],
        ],
        inputs=[input_text, mode_dropdown, format_dropdown],
    )
if __name__ == "__main__":
    # In Gradio 6.0+, theme and title are passed here
    # NOTE(review): only `theme` is actually passed below; no title is set
    # anywhere in this file -- confirm whether one was intended.
    demo.launch(theme=theme)