"""Gradio control center for GreesyGPT.

Exposes three tabs: a moderation (inference) interface, a fine-tuning
panel that streams training logs, and a system-info page.
"""

from collections.abc import Iterator
import json
import os
from pathlib import Path

import gradio as gr
import torch

from model import (
    GreesyGPT,
    generate_moderation,
    GreesyTrainer,
    get_dataset,
    get_sample_dataset,
    ReasoningMode,
    OutputFormat,
    DEVICE,
    describe_reasoning_modes,
    DATASET_JSON_PATH,
)

# 1. Initialize Model Global Instance
model = GreesyGPT()
weights_path = Path("greesy_gpt.pt")


def load_weights() -> str:
    """Load saved weights into the global model if a checkpoint exists.

    Returns:
        A human-readable status message saying whether weights were loaded
        from ``weights_path`` or the model kept its initial parameters.
    """
    if weights_path.exists():
        # The checkpoint is a plain state_dict (see torch.save in
        # start_training), so restrict unpickling to tensors/containers
        # instead of allowing arbitrary objects.
        state_dict = torch.load(
            weights_path, map_location=DEVICE, weights_only=True
        )
        model.load_state_dict(state_dict)
        return f"Loaded weights from {weights_path}"
    return "No weights found. Model initialized with random parameters."


# Load once at import time and keep the message so the UI can display it
# without reading the checkpoint from disk a second time.
weights_status = load_weights()
model.to(DEVICE)


# --- Inference Logic ---
def moderate(text: str, mode_str: str, format_str: str) -> tuple[str, str]:
    """Run the moderation model on *text* and format the result for the UI.

    Args:
        text: Message to review. Empty, whitespace-only or missing input
            short-circuits with a prompt to enter text.
        mode_str: Reasoning mode value; lower-cased and looked up in
            ``ReasoningMode``.
        format_str: Output format value; lower-cased and looked up in
            ``OutputFormat``.

    Returns:
        A ``(verdict, thinking)`` pair. For the JSON output format the
        verdict is serialized and wrapped in a fenced ``json`` code block so
        the Markdown component renders it as code.
    """
    if not text or not text.strip():
        return "Please enter text.", ""

    model.eval()
    mode = ReasoningMode(mode_str.lower())
    fmt = OutputFormat(format_str.lower())

    result = generate_moderation(
        model, prompt=text, mode=mode, output_format=fmt
    )
    print(result)  # console trace of the raw result

    verdict_output = result["verdict_fmt"]
    if fmt == OutputFormat.JSON:
        verdict_output = (
            f"```json\n{json.dumps(verdict_output, indent=2)}\n```"
        )

    thinking_process = result.get("thinking", "No reasoning generated.")
    return verdict_output, thinking_process


# --- Training Logic ---
def start_training(epochs, batch_size, grad_accum) -> Iterator[str]:
    """Fine-tune the global model, yielding the cumulative log after each step.

    Args:
        epochs: Number of epochs to run (coerced with ``int``).
        batch_size: Batch size passed to ``GreesyTrainer`` (coerced with
            ``int``).
        grad_accum: Gradient accumulation value passed to ``GreesyTrainer``
            (coerced with ``int``).

    Yields:
        The full training log so far, one entry per line. On failure the
        error message is appended to whatever was already logged rather than
        replacing it.
    """
    log_history: list[str] = []
    try:
        # Load data
        if DATASET_JSON_PATH.exists():
            dataset = get_dataset()
            data_source = "dataset.json"
        else:
            dataset = get_sample_dataset()
            data_source = "Hardcoded Sample Data"

        trainer = GreesyTrainer(
            model=model,
            train_dataset=dataset,
            batch_size=int(batch_size),
            grad_accum=int(grad_accum),
        )

        log_history.append(
            f"Starting training on {DEVICE} using {data_source}..."
        )
        # Surface the start message right away instead of waiting for the
        # first epoch to finish.
        yield "\n".join(log_history)

        for epoch in range(1, int(epochs) + 1):
            avg_loss = trainer.train_epoch(epoch)
            log_history.append(f"Epoch {epoch}: Loss = {avg_loss:.4f}")
            yield "\n".join(log_history)

        # Save weights
        torch.save(model.state_dict(), weights_path)
        log_history.append(f"Success: Weights saved to {weights_path}")
        yield "\n".join(log_history)
    except Exception as e:
        # UI boundary: report the failure in the log box, keeping the
        # entries produced before the error.
        log_history.append(f"Error during training: \n{e}")
        yield "\n".join(log_history)


# --- UI Layout ---
theme = gr.themes.Soft(primary_hue="orange", secondary_hue="gray")

with gr.Blocks() as demo:
    gr.Markdown("# 🛡️ GreesyGPT Control Center")

    with gr.Tabs():
        # TAB 1: MODERATION (INFERENCE)
        with gr.Tab("Moderation Interface"):
            with gr.Row():
                with gr.Column(scale=2):
                    input_text = gr.Textbox(
                        label="Message to Review", lines=5
                    )
                    with gr.Row():
                        mode_dropdown = gr.Dropdown(
                            choices=[m.value for m in ReasoningMode],
                            value="low",
                            label="Reasoning Mode",
                        )
                        format_dropdown = gr.Dropdown(
                            choices=[f.value for f in OutputFormat],
                            value="markdown",
                            label="Output Format",
                        )
                    submit_btn = gr.Button(
                        "Analyze Content", variant="primary"
                    )
                with gr.Column(scale=3):
                    output_verdict = gr.Markdown(label="Verdict")
                    with gr.Accordion(
                        "Internal Reasoning (Thinking)", open=False
                    ):
                        output_thinking = gr.Textbox(
                            label="", interactive=False, lines=10
                        )

        # TAB 2: TRAINING
        with gr.Tab("Model Training"):
            gr.Markdown("### Fine-tune GreesyGPT")
            with gr.Row():
                epoch_slider = gr.Slider(
                    minimum=1, maximum=10, value=1, step=1, label="Epochs"
                )
                batch_slider = gr.Slider(
                    minimum=1, maximum=8, value=2, step=1, label="Batch Size"
                )
                accum_slider = gr.Slider(
                    minimum=1,
                    maximum=16,
                    value=4,
                    step=1,
                    label="Grad Accumulation",
                )
            train_btn = gr.Button(
                "🚀 Start Training Session", variant="stop"
            )
            train_logs = gr.Textbox(
                label="Training Logs", interactive=False, lines=10
            )

        # TAB 3: SYSTEM INFO
        with gr.Tab("System Info"):
            gr.Markdown("### Reasoning Mode Definitions")
            gr.Code(describe_reasoning_modes(), language="markdown")
            # Reuse the status captured at startup; calling load_weights()
            # here would load the checkpoint a second time.
            status_msg = gr.Textbox(
                value=weights_status, label="Model Status", interactive=False
            )

    # --- Event Handlers ---
    submit_btn.click(
        fn=moderate,
        inputs=[input_text, mode_dropdown, format_dropdown],
        outputs=[output_verdict, output_thinking],
    )

    train_btn.click(
        fn=start_training,
        inputs=[epoch_slider, batch_slider, accum_slider],
        outputs=[train_logs],
    )

    gr.Examples(
        examples=[
            ["You're so stupid, nobody likes you.", "medium", "markdown"],
            ["CONGRATULATIONS! \nYou won a $1000 prize!", "low", "json"],
        ],
        inputs=[input_text, mode_dropdown, format_dropdown],
    )

if __name__ == "__main__":
    # In Gradio 6.0+, theme and title are passed here
    demo.launch(theme=theme)