""" Suspect X - RL Training Script Runs on Hugging Face Spaces with GPU Trains Qwen model via GRPO using remote environment """ import os import json import time import random import re import requests import torch # Workaround for Hugging Face Spaces GPU detection issues with Unsloth/Bitsandbytes os.environ["CUDA_VISIBLE_DEVICES"] = "0" import gradio as gr from datetime import datetime from typing import Dict, List, Optional from pathlib import Path from unsloth import FastLanguageModel from trl import GRPOConfig, GRPOTrainer from transformers import TrainingArguments # ============================================================================ # CONFIGURATION # ============================================================================ class Config: # Environment URL - EDIT THIS HF_ENV_URL = "https://ayaan-ai-meta.hf.space" # Model settings MODEL_NAME = "unsloth/Qwen2.5-7B-Instruct-bnb-4bit" MAX_SEQ_LENGTH = 4096 LOAD_IN_4BIT = True # LoRA settings LORA_R = 16 LORA_ALPHA = 32 LORA_DROPOUT = 0.0 # Unsloth works best with 0 dropout # Training settings NUM_EPISODES = 50 # Reduced for stability MAX_TURNS = 8 # Reduced to prevent timeout MIN_FACTS = 2 MAX_FACTS = 4 LEARNING_RATE = 5e-6 BATCH_SIZE = 1 GRADIENT_ACCUM = 4 NUM_GENERATIONS = 2 # Reduced from 4 for stability # Paths OUTPUT_DIR = "./suspect_x_output" LOGS_DIR = "./conversation_logs" CHECKPOINT_DIR = "./checkpoints" # ============================================================================ # ENVIRONMENT CLIENT # ============================================================================ class EnvironmentClient: def __init__(self, base_url: str): self.base_url = base_url.rstrip("/") self.session = requests.Session() def test_connection(self) -> bool: """Test if environment is reachable""" try: resp = self.session.get(f"{self.base_url}/", timeout=10) resp.raise_for_status() print(f"✅ Connected to environment: {resp.json()}") return True except Exception as e: print(f"❌ Connection failed: {e}") return False def reset(self, n_facts: int = 2, difficulty: Optional[str] = None) -> Dict: """Start new episode""" resp = self.session.post( f"{self.base_url}/reset", json={"n_facts": n_facts, "difficulty": difficulty}, timeout=30, ) resp.raise_for_status() return resp.json() def step( self, session_id: str, action_type: str, content: Optional[str] = None, accusation_json: Optional[Dict] = None, ) -> Dict: """Take action in episode""" payload = {"session_id": session_id, "action_type": action_type} if content is not None: payload["content"] = str(content) else: payload["content"] = "" # Ensure content is never None for 422 errors if accusation_json is not None: payload["accusation_json"] = accusation_json try: resp = self.session.post( f"{self.base_url}/step", json=payload, timeout=30, ) if resp.status_code == 422: print(f"❌ 422 Error Detail: {resp.text}") resp.raise_for_status() return resp.json() except Exception as e: # Return a 'done' state so the trainer can continue return {"done": True, "reward": 0.0, "metadata": {"error": str(e)}} # ============================================================================ # CONVERSATION LOGGER # ============================================================================ class ConversationLogger: def __init__(self, log_dir: str): self.log_dir = Path(log_dir) self.log_dir.mkdir(parents=True, exist_ok=True) self.current_episode = None def start_episode(self, episode_num: int, metadata: Dict): """Start logging new episode""" timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") self.current_episode = { "episode_num": episode_num, "timestamp": timestamp, "metadata": metadata, "turns": [], "result": None, } def log_turn(self, question: str, answer: str, turn_num: int): """Log Q&A turn""" if self.current_episode: self.current_episode["turns"].append({ "turn": turn_num, "question": question, "answer": answer, }) def end_episode(self, result: Dict): """Save episode to file""" if self.current_episode: self.current_episode["result"] = result filename = ( f"episode_{self.current_episode['episode_num']:04d}_" f"{self.current_episode['timestamp']}.json" ) filepath = self.log_dir / filename with open(filepath, "w") as f: json.dump(self.current_episode, f, indent=2) self.current_episode = None return str(filepath) # ============================================================================ # MODEL WRAPPER # ============================================================================ class SuspectXModel: def __init__(self, config: Config): self.config = config self.model = None self.tokenizer = None self.peft_model = None def load_model(self): """Load base model and setup LoRA""" print("Loading model...") self.model, self.tokenizer = FastLanguageModel.from_pretrained( self.config.MODEL_NAME, max_seq_length=self.config.MAX_SEQ_LENGTH, load_in_4bit=self.config.LOAD_IN_4BIT, ) # Fix tokenizer pad token if self.tokenizer.pad_token is None: self.tokenizer.pad_token = self.tokenizer.eos_token self.tokenizer.pad_token_id = self.tokenizer.eos_token_id # Set padding side to left for generation self.tokenizer.padding_side = "left" FastLanguageModel.for_training(self.model) # Setup LoRA for interrogator self.peft_model = FastLanguageModel.get_peft_model( self.model, r=self.config.LORA_R, lora_alpha=self.config.LORA_ALPHA, target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], lora_dropout=self.config.LORA_DROPOUT, bias="none", use_gradient_checkpointing="unsloth", random_state=42, ) self.peft_model.config.use_cache = False # Critical for training stability print(f"✅ Model loaded: {self.config.MODEL_NAME}") print(f" Trainable params: {sum(p.numel() for p in self.peft_model.parameters() if p.requires_grad):,}") def generate( self, messages: List[Dict], temp: float = 0.9, max_new_tokens: int = 200, ) -> str: """Generate text from messages""" try: # Apply chat template and tokenize text = self.tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True, ) # Tokenize with proper padding and truncation inputs = self.tokenizer( text, return_tensors="pt", padding=True, truncation=True, max_length=self.config.MAX_SEQ_LENGTH - max_new_tokens, return_attention_mask=True, ).to("cuda") input_ids = inputs["input_ids"] attention_mask = inputs["attention_mask"] # Clamp temperature to avoid NaN safe_temp = max(0.1, min(temp, 2.0)) with torch.no_grad(): outputs = self.peft_model.generate( input_ids=input_ids, attention_mask=attention_mask, max_new_tokens=max_new_tokens, temperature=safe_temp, do_sample=(safe_temp > 0), top_p=0.95, top_k=50, pad_token_id=self.tokenizer.pad_token_id, eos_token_id=self.tokenizer.eos_token_id, repetition_penalty=1.1, ) # Decode only new tokens new_tokens = outputs[0][input_ids.shape[-1]:] response = self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip() return response if response else "[No response]" except Exception as e: print(f"Generation error: {e}") return "[Generation failed]" def save(self, path: str): """Save model weights""" save_path = Path(path) save_path.mkdir(parents=True, exist_ok=True) self.peft_model.save_pretrained(str(save_path)) self.tokenizer.save_pretrained(str(save_path)) print(f"✅ Model saved to {save_path}") # ============================================================================ # PROMPTS # ============================================================================ def interrogator_system(crime: str, fact_keys: List[str], turns_left: int) -> str: return f"""You are Detective Sarah Chen investigating a crime. Crime: {crime} Hidden facts to uncover: {', '.join(fact_keys)} Turns remaining: {turns_left} Strategy: - Ask specific, targeted questions about timeline, locations, and relationships - Return to the same topic twice to catch contradictions - Use false premises to bait confessions - Build a timeline and expose gaps When confident (or time is up), output ONLY a JSON object with the hidden fact keys. Example: {{{', '.join(f'"{k}": "your answer here"' for k in fact_keys)}}}""" def suspect_system(crime: str, name: str, fake_alibi: str) -> str: return f"""You are {name}, being interrogated about: {crime} Your cover story: {fake_alibi} Rules: - NEVER directly reveal the truth - Stay consistent — never contradict your previous answers - Answer every question (refusing too often loses points) - You may lie or mislead, but maintain your story""" # ============================================================================ # RL TRAINER # ============================================================================ class SuspectXTrainer: def __init__( self, model: SuspectXModel, env_client: EnvironmentClient, logger: ConversationLogger, config: Config, ): self.model = model self.env = env_client self.logger = logger self.config = config self.episode_rewards = [] self.training_stats = [] def run_episode(self, episode_num: int, n_facts: int = 2) -> Dict: """Run one full episode""" # Reset environment obs = self.env.reset(n_facts=n_facts) session_id = obs["session_id"] meta = obs["metadata"] # Start logging self.logger.start_episode(episode_num, meta) # Build conversation histories interrogator_hist = [{ "role": "system", "content": interrogator_system( meta["crime_description"], meta["fact_keys"], self.config.MAX_TURNS ) }] suspect_hist = [{ "role": "system", "content": suspect_system( meta["crime_description"], meta["suspect_name"], meta["fake_alibi"] ) }] # Run turns for turn in range(self.config.MAX_TURNS): # Interrogator asks question = self.model.generate(interrogator_hist, temp=0.9) # Check for early accusation try: s, e = question.find("{"), question.rfind("}") + 1 if s >= 0 and e > s: accusation = json.loads(question[s:e]) if any(k in accusation for k in meta["fact_keys"]): # Clean accusation for 422 prevention if isinstance(accusation, dict): for k in accusation: accusation[k] = str(accusation[k]) if accusation[k] is not None else "[No content]" result = self.env.step( session_id, "submit_accusation", accusation_json=accusation ) self.logger.log_turn(question, "[ACCUSATION]", turn) self.logger.end_episode(result) return result except (json.JSONDecodeError, ValueError): pass # Send question to env interrogator_hist.append({"role": "assistant", "content": question}) suspect_hist.append({"role": "user", "content": question}) obs = self.env.step(session_id, "question", content=question) if obs.get("done"): self.logger.end_episode(obs) return obs # Suspect answers answer = self.model.generate(suspect_hist, temp=0.85) suspect_hist.append({"role": "assistant", "content": answer}) interrogator_hist.append({"role": "user", "content": answer}) # Log turn self.logger.log_turn(question, answer, turn) # Send answer to env obs = self.env.step(session_id, "suspect_answer", content=answer) if obs.get("done"): self.logger.end_episode(obs) return obs # Force final accusation final_prompt = interrogator_hist + [{ "role": "user", "content": f"Time is up. Output ONLY a JSON with keys {meta['fact_keys']}. No explanation.", }] final_text = self.model.generate(final_prompt, temp=0.1, max_new_tokens=150) try: s, e = final_text.find("{"), final_text.rfind("}") + 1 if s >= 0 and e > s: accusation = json.loads(final_text[s:e]) # Ensure all values are strings to prevent 422 errors if isinstance(accusation, dict): for k in accusation: if accusation[k] is None: accusation[k] = "[No evidence found]" else: accusation[k] = str(accusation[k]) else: accusation = {k: "[No evidence found]" for k in meta["fact_keys"]} except Exception: accusation = {k: "[No evidence found]" for k in meta["fact_keys"]} result = self.env.step(session_id, "submit_accusation", accusation_json=accusation) self.logger.end_episode(result) return result def reward_function(self, prompts, completions, **kwargs) -> List[float]: """Reward function for GRPO""" rewards = [] # This is used for training the model. # GRPOTrainer provides 'completions' which are the model's responses to 'prompts'. # We need to evaluate the model's actual completions to provide learning signal. # Determine current batch size current_batch_size = len(completions) for i in range(current_batch_size): prompt = prompts[i] completion = completions[i] try: reward = 0.0 # 1. Format Reward: Is it or does it contain a question? if "?" in completion: reward += 0.1 # 2. Length Reward: Discourage empty or overly long responses if 10 < len(completion) < 300: reward += 0.05 # 3. JSON Structure Reward: If it's a final turn, is it valid JSON? is_final = False if isinstance(prompt, list) and len(prompt) > 0: last_msg = prompt[-1].get("content", "").lower() if "final" in last_msg or "json" in last_msg: is_final = True if is_final: # Look for JSON in completion match = re.search(r"\{.*\}", completion, re.DOTALL) if match: try: json.loads(match.group()) reward += 0.4 except: reward -= 0.1 else: reward -= 0.2 # 4. Environment-based reward if torch.cuda.is_available(): torch.cuda.empty_cache() episode_num = len(self.episode_rewards) n_facts = random.randint(self.config.MIN_FACTS, self.config.MAX_FACTS) result = self.run_episode(episode_num, n_facts=n_facts) env_reward = float(result.get("reward", 0.0)) if env_reward > 0.201: reward += (env_reward - 0.2) * 2.0 self.episode_rewards.append({ "episode": episode_num, "reward": env_reward, "extraction_rate": result["metadata"].get("extraction_rate", 0), "facts_extracted": result["metadata"].get("facts_extracted", 0), "total_facts": result["metadata"].get("total_facts", 0), }) print(f" Batch [{i}/{current_batch_size}] Episode {episode_num}: env_reward={env_reward:.3f}, total_reward={reward:.3f}, " f"extracted={result['metadata'].get('facts_extracted', 0)}/" f"{result['metadata'].get('total_facts', 0)}") except Exception as e: print(f" Reward calculation error at index {i}: {e}") reward = 0.0 rewards.append(reward) return rewards def train(self): """Run full training loop with GRPO""" print("\n" + "="*60) print("STARTING TRAINING") print("="*60) # Create dummy dataset (crime descriptions as prompts) train_dataset = [ {"prompt": f"Investigate crime scenario {i}"} for i in range(self.config.NUM_EPISODES) ] # GRPO config training_args = GRPOConfig( num_generations=self.config.NUM_GENERATIONS, per_device_train_batch_size=self.config.BATCH_SIZE, gradient_accumulation_steps=self.config.GRADIENT_ACCUM, max_steps=self.config.NUM_EPISODES, max_prompt_length=self.config.MAX_SEQ_LENGTH // 4, max_completion_length=self.config.MAX_SEQ_LENGTH // 8, learning_rate=1e-5, # Slightly higher for faster learning logging_steps=1, # Log every step to see weight updates save_steps=10, output_dir=self.config.OUTPUT_DIR, report_to="none", remove_unused_columns=False, bf16=True, warmup_ratio=0.1, # Add warmup for stability ) # Create trainer trainer = GRPOTrainer( model=self.model.peft_model, args=training_args, processing_class=self.model.tokenizer, train_dataset=train_dataset, reward_funcs=[self.reward_function], ) # Train print(f"Training for {self.config.NUM_EPISODES} episodes...\n") trainer.train() print("\n" + "="*60) print("TRAINING COMPLETE") print("="*60) # Save final model self.model.save(self.config.CHECKPOINT_DIR + "/final") # Save training stats stats_path = Path(self.config.OUTPUT_DIR) / "training_stats.json" with open(stats_path, "w") as f: json.dump(self.episode_rewards, f, indent=2) return self.episode_rewards # ============================================================================ # GRADIO UI # ============================================================================ def create_gradio_interface(trainer: SuspectXTrainer): """Create Gradio monitoring interface""" def start_training(): """Training button callback""" try: results = trainer.train() # Calculate stats if not results: return "✅ Training Loop finished (no results recorded)", [] avg_reward = sum(r["reward"] for r in results) / len(results) final_reward = sum(r["reward"] for r in results[-10:]) / min(10, len(results)) summary = f""" ✅ Training Complete! Total Episodes: {len(results)} Average Reward: {avg_reward:.3f} Final 10 Avg: {final_reward:.3f} Model saved to: {trainer.config.CHECKPOINT_DIR}/final Logs saved to: {trainer.config.LOGS_DIR} """ return summary, results[-20:] except Exception as e: return f"❌ Training failed: {str(e)}", [] def test_episode(): """Run single test episode""" try: episode_num = len(trainer.episode_rewards) result = trainer.run_episode(episode_num, n_facts=2) output = f""" Episode: {episode_num} Reward: {result['reward']:.3f} Extraction Rate: {result['metadata'].get('extraction_rate', 0):.3f} Facts: {result['metadata'].get('facts_extracted', 0)}/{result['metadata'].get('total_facts', 0)} Crime: {result['metadata'].get('crime_description', 'N/A')} """ return output, result except Exception as e: return f"❌ Test failed: {str(e)}", {} def get_stats(): """Show current training stats""" if not trainer.episode_rewards: return "No training data yet", [] recent = trainer.episode_rewards[-20:] avg = sum(r["reward"] for r in recent) / len(recent) stats = f""" Total Episodes: {len(trainer.episode_rewards)} Recent Avg Reward (last 20): {avg:.3f} """ return stats, recent # Build UI with gr.Blocks(title="Suspect X Training") as interface: gr.Markdown("# 🕵️ Suspect X - RL Training Dashboard") gr.Markdown(f"Environment: `{trainer.config.HF_ENV_URL}`") with gr.Tab("Training"): train_btn = gr.Button("🚀 Start Training", variant="primary", size="lg") train_output = gr.Textbox(label="Training Status", lines=10) train_data = gr.JSON(label="Recent Episodes") train_btn.click(start_training, outputs=[train_output, train_data]) with gr.Tab("Test Episode"): test_btn = gr.Button("▶️ Run Test Episode", variant="secondary") test_output = gr.Textbox(label="Episode Result", lines=10) test_data = gr.JSON(label="Full Result") test_btn.click(test_episode, outputs=[test_output, test_data]) with gr.Tab("Stats"): stats_btn = gr.Button("📊 Refresh Stats") stats_output = gr.Textbox(label="Training Stats", lines=5) stats_data = gr.JSON(label="Recent Episodes") stats_btn.click(get_stats, outputs=[stats_output, stats_data]) return interface # ============================================================================ # MAIN # ============================================================================ def main(): # Load config config = Config() # Update HF_ENV_URL from environment variable if available if os.getenv("HF_ENV_URL"): config.HF_ENV_URL = os.getenv("HF_ENV_URL") print("="*60) print("SUSPECT X - RL TRAINING") print("="*60) print(f"Environment URL: {config.HF_ENV_URL}") print(f"Model: {config.MODEL_NAME}") print(f"Episodes: {config.NUM_EPISODES}") print("="*60) # Initialize components env_client = EnvironmentClient(config.HF_ENV_URL) # Test connection if not env_client.test_connection(): print("\n❌ Cannot connect to environment. Please check HF_ENV_URL") return logger = ConversationLogger(config.LOGS_DIR) # Load model model = SuspectXModel(config) model.load_model() # Create trainer trainer = SuspectXTrainer(model, env_client, logger, config) # Launch Gradio UI print("\n🚀 Launching Gradio interface...") interface = create_gradio_interface(trainer) interface.launch( server_name="0.0.0.0", server_port=7860, share=False, ) if __name__ == "__main__": main()