Spaces:
Running
Running
| import gradio as gr | |
| import torch | |
| import json | |
| import os | |
| from pathlib import Path | |
| from model import ( | |
| GreesyGPT, | |
| generate_moderation, | |
| GreesyTrainer, | |
| get_dataset, | |
| get_sample_dataset, | |
| ReasoningMode, | |
| OutputFormat, | |
| DEVICE, | |
| describe_reasoning_modes, | |
| DATASET_JSON_PATH | |
| ) | |
# 1. Global model state shared by the inference and training handlers
weights_path = Path("greesy_gpt.pt")  # checkpoint file read at startup and written after training
model = GreesyGPT()
def load_weights():
    """Load saved parameters into the global ``model`` if a checkpoint exists.

    Reads ``weights_path`` with tensors mapped onto ``DEVICE`` and applies the
    resulting state dict to ``model``.

    Returns:
        A status string naming the file that was loaded, or a note that no
        checkpoint was found.
    """
    if not weights_path.exists():
        return "No weights found. Model initialized with random parameters."

    # weights_only=True restricts unpickling to tensors and plain containers,
    # so a tampered checkpoint file cannot run arbitrary code while loading.
    state_dict = torch.load(weights_path, map_location=DEVICE, weights_only=True)
    model.load_state_dict(state_dict)
    return f"Loaded weights from {weights_path}"
# Restore a saved checkpoint when one exists (the returned status string is
# not used here), then move the model onto the target device.
load_weights()
model.to(DEVICE)
# --- Inference Logic ---
def moderate(text, mode_str, format_str):
    """Run the moderation model on ``text`` and format the result for the UI.

    Args:
        text: Message to review. ``None`` or whitespace-only input is rejected
            with a prompt instead of being sent to the model.
        mode_str: Value of a ``ReasoningMode`` member (case-insensitive).
        format_str: Value of an ``OutputFormat`` member (case-insensitive).

    Returns:
        A ``(verdict, thinking)`` tuple of strings for the verdict panel and
        the reasoning textbox.

    Raises:
        Whatever ``ReasoningMode(...)`` / ``OutputFormat(...)`` raise for an
        unknown value; the dropdowns only offer valid members.
    """
    # A cleared textbox or a direct API call may deliver None rather than "".
    if text is None or not text.strip():
        return "Please enter text.", ""

    model.eval()
    mode = ReasoningMode(mode_str.lower())
    fmt = OutputFormat(format_str.lower())
    result = generate_moderation(model, prompt=text, mode=mode, output_format=fmt)
    print(result)  # raw result goes to the server log

    verdict_output = result["verdict_fmt"]
    if fmt == OutputFormat.JSON:
        # Only serialise structured verdicts: dumping a verdict that is already
        # a string would wrap it in quotes and escape its contents.
        if not isinstance(verdict_output, str):
            verdict_output = json.dumps(verdict_output, indent=2)
        verdict_output = f"```json\n{verdict_output}\n```"

    # Fall back to the placeholder when the key is missing *or* empty.
    thinking_process = result.get("thinking") or "No reasoning generated."
    return verdict_output, thinking_process
# --- Training Logic ---
def start_training(epochs, batch_size, grad_accum):
    """Fine-tune the global ``model`` and stream a growing log to the UI.

    This is a generator: each yielded string is the full log so far, so the
    bound textbox is refreshed as training progresses.

    Args:
        epochs: Number of epochs to run (converted with ``int``).
        batch_size: Batch size handed to ``GreesyTrainer`` (converted with ``int``).
        grad_accum: Gradient accumulation setting handed to ``GreesyTrainer``
            (converted with ``int``).

    Yields:
        The accumulated log text: a start line, one line per finished epoch,
        and finally either a success line or an error line.
    """
    # Created before the try block so the error path can keep earlier lines.
    log_history = []
    try:
        # Load data: prefer the JSON dataset, else the built-in samples.
        if DATASET_JSON_PATH.exists():
            dataset = get_dataset()
            data_source = "dataset.json"
        else:
            dataset = get_sample_dataset()
            data_source = "Hardcoded Sample Data"

        trainer = GreesyTrainer(
            model=model,
            train_dataset=dataset,
            batch_size=int(batch_size),
            grad_accum=int(grad_accum),
        )

        log_history.append(f"Starting training on {DEVICE} using {data_source}...")
        # Show the start line right away instead of after the first epoch.
        yield "\n".join(log_history)

        for epoch in range(1, int(epochs) + 1):
            avg_loss = trainer.train_epoch(epoch)
            log_history.append(f"Epoch {epoch}: Loss = {avg_loss:.4f}")
            yield "\n".join(log_history)

        # Save weights
        torch.save(model.state_dict(), weights_path)
        log_history.append(f"Success: Weights saved to {weights_path}")
        yield "\n".join(log_history)
    except Exception as e:
        # Handler boundary: report the failure in the log box, appended to the
        # lines already shown rather than replacing them.
        log_history.append(f"Error during training: {e}")
        yield "\n".join(log_history)
# --- UI Layout ---
theme = gr.themes.Soft(primary_hue="orange", secondary_hue="gray")

# Dropdown options are taken from the enum members' values.
reasoning_mode_choices = [member.value for member in ReasoningMode]
output_format_choices = [member.value for member in OutputFormat]

# Sample rows for the examples widget: (message, reasoning mode, output format).
example_rows = [
    ["You're so stupid, nobody likes you.", "medium", "markdown"],
    ["CONGRATULATIONS! You won a $1000 prize!", "low", "json"],
]

with gr.Blocks() as demo:
    gr.Markdown("# 🛡️ GreesyGPT Control Center")

    with gr.Tabs():
        # TAB 1: MODERATION (INFERENCE)
        with gr.Tab("Moderation Interface"):
            with gr.Row():
                # Left column: message input and controls.
                with gr.Column(scale=2):
                    input_text = gr.Textbox(label="Message to Review", lines=5)
                    with gr.Row():
                        mode_dropdown = gr.Dropdown(
                            choices=reasoning_mode_choices,
                            value="low",
                            label="Reasoning Mode",
                        )
                        format_dropdown = gr.Dropdown(
                            choices=output_format_choices,
                            value="markdown",
                            label="Output Format",
                        )
                    submit_btn = gr.Button("Analyze Content", variant="primary")
                # Right column: verdict plus collapsible reasoning.
                with gr.Column(scale=3):
                    output_verdict = gr.Markdown(label="Verdict")
                    with gr.Accordion("Internal Reasoning (Thinking)", open=False):
                        output_thinking = gr.Textbox(
                            label="",
                            interactive=False,
                            lines=10,
                        )

        # TAB 2: TRAINING
        with gr.Tab("Model Training"):
            gr.Markdown("### Fine-tune GreesyGPT")
            with gr.Row():
                epoch_slider = gr.Slider(
                    minimum=1, maximum=10, value=1, step=1, label="Epochs"
                )
                batch_slider = gr.Slider(
                    minimum=1, maximum=8, value=2, step=1, label="Batch Size"
                )
                accum_slider = gr.Slider(
                    minimum=1, maximum=16, value=4, step=1, label="Grad Accumulation"
                )
            train_btn = gr.Button("🚀 Start Training Session", variant="stop")
            train_logs = gr.Textbox(
                label="Training Logs",
                interactive=False,
                lines=10,
            )

        # TAB 3: SYSTEM INFO
        with gr.Tab("System Info"):
            gr.Markdown("### Reasoning Mode Definitions")
            gr.Code(describe_reasoning_modes(), language="markdown")
            status_msg = gr.Textbox(
                value=load_weights(),
                label="Model Status",
                interactive=False,
            )

    # --- Event Handlers ---
    moderation_inputs = [input_text, mode_dropdown, format_dropdown]

    submit_btn.click(
        fn=moderate,
        inputs=moderation_inputs,
        outputs=[output_verdict, output_thinking],
    )
    train_btn.click(
        fn=start_training,
        inputs=[epoch_slider, batch_slider, accum_slider],
        outputs=[train_logs],
    )
    gr.Examples(
        examples=example_rows,
        inputs=moderation_inputs,
    )
if __name__ == "__main__":
    # Gradio 6.0+ takes app-level options such as the theme at launch time.
    launch_options = {"theme": theme}
    demo.launch(**launch_options)