"""OpenAI-style Chat Completions wrapper around the GreesyGPT moderation model.

Loads the model once at import time, then exposes `chat_completions`, which
accepts OpenAI-format message lists and returns an OpenAI-format response dict.
"""

import json
import time
from pathlib import Path

import torch

from model import (
    GreesyGPT,
    generate_moderation,
    ReasoningMode,
    OutputFormat,
    DEVICE,
)

# ─────────────────────────────────────────────
# Model Initialization
# ─────────────────────────────────────────────
model = GreesyGPT()

weights_path = Path(__file__).parent / "greesy_gpt.pt"
if weights_path.exists():
    # NOTE(review): torch.load is pickle-based — only load checkpoints from a
    # trusted source (on torch >= 1.13, weights_only=True would be safer for a
    # pure state_dict; left unchanged to avoid breaking older torch versions).
    model.load_state_dict(torch.load(weights_path, map_location=DEVICE))
    print(f"Loaded weights from {weights_path}")
else:
    print("No trained weights found, using fresh initialization.")

model.to(DEVICE)
model.eval()


# ─────────────────────────────────────────────
# OpenAI‑style Chat Completion Wrapper
# ─────────────────────────────────────────────
def chat_completions(
    model: GreesyGPT,
    messages,
    reasoning_mode: ReasoningMode = ReasoningMode.LOW,
):
    """
    Emulates the OpenAI Chat Completions API format.

    Input:
        messages = [
            {"role": "system", "content": "..."},
            {"role": "user", "content": "..."}
        ]

    Output:
        {
            "id": "...",
            "object": "chat.completion",
            "choices": [...]
        }
    """
    # Moderate the content of the *last* user message. System messages are
    # accepted in the input format but are currently ignored —
    # generate_moderation exposes no system-prompt parameter (the original
    # code collected system content into an unused local). Entries missing
    # "role"/"content" are skipped rather than raising KeyError.
    user_message = ""
    for m in messages:
        if m.get("role") == "user":
            user_message = m.get("content", "")

    result = generate_moderation(
        model,
        prompt=user_message,
        mode=reasoning_mode,
        output_format=OutputFormat.MARKDOWN,
    )

    return {
        # Millisecond timestamp keeps ids unique enough for a demo wrapper.
        "id": f"greesy-{int(time.time()*1000)}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": "latest",
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": result["verdict"],
                },
                "finish_reason": "stop",
            }
        ],
        "metadata": {
            # Model's reasoning trace, surfaced for debugging/inspection.
            "reasoning": result["thinking"],
        },
    }


# ─────────────────────────────────────────────
# Example Usage
# ─────────────────────────────────────────────
if __name__ == "__main__":
    messages = [
        {"role": "system", "content": "You are a moderation system."},
        {"role": "user", "content": "You're so stupid, nobody likes you."},
    ]

    response = chat_completions(
        model,
        messages,
        reasoning_mode=ReasoningMode.MEDIUM,
    )

    print(json.dumps(response, indent=2))