Spaces:
Paused
Paused
mayank1365
Fix CUDA device-side assert by adjusting max_prompt_length and disabling use_cache
fe123ff | """ | |
| Suspect X - RL Training Script | |
| Runs on Hugging Face Spaces with GPU | |
| Trains Qwen model via GRPO using remote environment | |
| """ | |
| import os | |
| import json | |
| import time | |
| import random | |
| import re | |
| import requests | |
| import torch | |
| # Workaround for Hugging Face Spaces GPU detection issues with Unsloth/Bitsandbytes | |
| os.environ["CUDA_VISIBLE_DEVICES"] = "0" | |
| import gradio as gr | |
| from datetime import datetime | |
| from typing import Dict, List, Optional | |
| from pathlib import Path | |
| from unsloth import FastLanguageModel | |
| from trl import GRPOConfig, GRPOTrainer | |
| from transformers import TrainingArguments | |
| # ============================================================================ | |
| # CONFIGURATION | |
| # ============================================================================ | |
class Config:
    """Hyper-parameters and filesystem paths for one Suspect X training run.

    Everything is a plain class attribute, so values can be read from the
    class itself or overridden on an instance (e.g. the environment URL).
    """

    # --- Remote environment ---------------------------------------------
    # Base URL of the Suspect X environment server. EDIT THIS if needed.
    HF_ENV_URL: str = "https://ayaan-ai-meta.hf.space"

    # --- Base model -------------------------------------------------------
    MODEL_NAME: str = "unsloth/Qwen2.5-7B-Instruct-bnb-4bit"
    MAX_SEQ_LENGTH: int = 4096  # total token budget (prompt + generation)
    LOAD_IN_4BIT: bool = True

    # --- LoRA adapter -----------------------------------------------------
    LORA_R: int = 16
    LORA_ALPHA: int = 32
    LORA_DROPOUT: float = 0.0  # Unsloth works best with 0 dropout

    # --- Episodes / optimisation -------------------------------------------
    NUM_EPISODES: int = 50  # kept small for stability; also used as max_steps
    MAX_TURNS: int = 8  # question/answer rounds before a forced accusation
    MIN_FACTS: int = 2  # range for the number of hidden facts per episode
    MAX_FACTS: int = 4
    # NOTE(review): the GRPO config in the trainer sets its own learning rate
    # and does not read this value.
    LEARNING_RATE: float = 5e-6
    BATCH_SIZE: int = 1
    GRADIENT_ACCUM: int = 4
    NUM_GENERATIONS: int = 2  # reduced from 4 for stability

    # --- Output locations ---------------------------------------------------
    OUTPUT_DIR: str = "./suspect_x_output"
    LOGS_DIR: str = "./conversation_logs"
    CHECKPOINT_DIR: str = "./checkpoints"
| # ============================================================================ | |
| # ENVIRONMENT CLIENT | |
| # ============================================================================ | |
class EnvironmentClient:
    """Thin HTTP client for the Suspect X environment server.

    Uses ``GET /`` as a health check, ``POST /reset`` to start an episode
    and ``POST /step`` to act in it; responses are decoded as JSON.
    """

    def __init__(self, base_url: str, request_timeout: float = 30.0):
        """
        Args:
            base_url: Root URL of the environment; a trailing "/" is removed.
            request_timeout: Seconds to wait on ``/reset`` and ``/step``.
        """
        self.base_url = base_url.rstrip("/")
        self.request_timeout = request_timeout
        # One Session so the many per-turn requests reuse a connection.
        self.session = requests.Session()

    def test_connection(self) -> bool:
        """Return True if ``GET /`` succeeds and returns JSON, else False.

        Never raises: any failure is printed and reported as False.
        """
        try:
            resp = self.session.get(f"{self.base_url}/", timeout=10)
            resp.raise_for_status()
            print(f"✅ Connected to environment: {resp.json()}")
            return True
        except Exception as e:
            print(f"❌ Connection failed: {e}")
            return False

    def reset(self, n_facts: int = 2, difficulty: Optional[str] = None) -> Dict:
        """Start a new episode and return the server's JSON observation.

        Raises:
            requests.HTTPError: on a non-2xx response (not caught here).
        """
        resp = self.session.post(
            f"{self.base_url}/reset",
            json={"n_facts": n_facts, "difficulty": difficulty},
            timeout=self.request_timeout,
        )
        resp.raise_for_status()
        return resp.json()

    def step(
        self,
        session_id: str,
        action_type: str,
        content: Optional[str] = None,
        accusation_json: Optional[Dict] = None,
    ) -> Dict:
        """Send one action to the episode identified by ``session_id``.

        Returns:
            The server's JSON response. On ANY failure (network, HTTP error,
            bad JSON) a terminal observation
            ``{"done": True, "reward": 0.0, "metadata": {"error": ...}}`` is
            returned instead, so a broken episode ends rather than crashing
            the training loop.
        """
        payload = {
            "session_id": session_id,
            "action_type": action_type,
            # Always send a string: a None content was associated with HTTP 422s.
            "content": "" if content is None else str(content),
        }
        if accusation_json is not None:
            payload["accusation_json"] = accusation_json
        try:
            resp = self.session.post(
                f"{self.base_url}/step",
                json=payload,
                timeout=self.request_timeout,
            )
            if resp.status_code == 422:
                # Surface the validation detail before raise_for_status hides it.
                print(f"⚠️ 422 Error Detail: {resp.text}")
            resp.raise_for_status()
            return resp.json()
        except Exception as e:
            # Best effort by design, but no longer silent: report what failed.
            print(f"❌ Environment step failed ({action_type}): {e}")
            return {"done": True, "reward": 0.0, "metadata": {"error": str(e)}}
| # ============================================================================ | |
| # CONVERSATION LOGGER | |
| # ============================================================================ | |
class ConversationLogger:
    """Write each episode's dialogue to its own JSON file under ``log_dir``.

    Usage: ``start_episode`` -> any number of ``log_turn`` -> ``end_episode``.
    Calls to ``log_turn``/``end_episode`` with no open episode are no-ops.
    """

    def __init__(self, log_dir: str):
        self.log_dir = Path(log_dir)
        self.log_dir.mkdir(parents=True, exist_ok=True)
        # The record being built for the in-progress episode (None when idle).
        self.current_episode = None

    def start_episode(self, episode_num: int, metadata: Dict):
        """Open a fresh record, replacing any episode that was never ended."""
        self.current_episode = dict(
            episode_num=episode_num,
            timestamp=datetime.now().strftime("%Y%m%d_%H%M%S"),
            metadata=metadata,
            turns=[],
            result=None,
        )

    def log_turn(self, question: str, answer: str, turn_num: int):
        """Append one question/answer exchange to the open episode, if any."""
        record = self.current_episode
        if record is None:
            return
        record["turns"].append(dict(turn=turn_num, question=question, answer=answer))

    def end_episode(self, result: Dict):
        """Attach ``result``, write the record to disk and close the episode.

        Returns:
            The written file path as a string, or None if no episode was open.
        """
        record = self.current_episode
        if record is None:
            return None
        record["result"] = result
        file_name = "episode_{:04d}_{}.json".format(record["episode_num"], record["timestamp"])
        target = self.log_dir / file_name
        with open(target, "w") as handle:
            json.dump(record, handle, indent=2)
        self.current_episode = None
        return str(target)
| # ============================================================================ | |
| # MODEL WRAPPER | |
| # ============================================================================ | |
class SuspectXModel:
    """Unsloth 4-bit causal LM plus a LoRA adapter.

    ``peft_model`` (the adapter-wrapped model) is used both for text
    generation and as the model handed to the GRPO trainer.
    """

    def __init__(self, config: "Config"):
        self.config = config
        self.model = None  # base model from FastLanguageModel.from_pretrained
        self.tokenizer = None
        self.peft_model = None  # LoRA-wrapped model; set by load_model()

    def load_model(self):
        """Load the base model and tokenizer, then attach the LoRA adapter.

        Side effects: sets ``model``, ``tokenizer`` and ``peft_model``; forces
        left padding; sets ``peft_model.config.use_cache = False``.
        """
        print("Loading model...")
        self.model, self.tokenizer = FastLanguageModel.from_pretrained(
            self.config.MODEL_NAME,
            max_seq_length=self.config.MAX_SEQ_LENGTH,
            load_in_4bit=self.config.LOAD_IN_4BIT,
        )
        # Some checkpoints ship without a pad token; reuse EOS so padding works.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
        # Decoder-only generation expects left padding.
        self.tokenizer.padding_side = "left"
        FastLanguageModel.for_training(self.model)
        # LoRA on every attention and MLP projection.
        self.peft_model = FastLanguageModel.get_peft_model(
            self.model,
            r=self.config.LORA_R,
            lora_alpha=self.config.LORA_ALPHA,
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
            lora_dropout=self.config.LORA_DROPOUT,
            bias="none",
            use_gradient_checkpointing="unsloth",
            random_state=42,
        )
        self.peft_model.config.use_cache = False  # Critical for training stability
        print(f"✅ Model loaded: {self.config.MODEL_NAME}")
        trainable = sum(p.numel() for p in self.peft_model.parameters() if p.requires_grad)
        print(f" Trainable params: {trainable:,}")

    @staticmethod
    def _left_truncate(input_ids, attention_mask, max_length: int):
        """Keep only the last ``max_length`` positions (minimum 1) of both tensors.

        Truncating from the LEFT drops the oldest context while keeping the
        end of the prompt, i.e. the latest turn and the assistant header that
        ``add_generation_prompt=True`` appends.
        """
        max_length = max(1, int(max_length))
        if input_ids.shape[-1] <= max_length:
            return input_ids, attention_mask
        return input_ids[..., -max_length:], attention_mask[..., -max_length:]

    def generate(
        self,
        messages: List[Dict],
        temp: float = 0.9,
        max_new_tokens: int = 200,
    ) -> str:
        """Generate one assistant reply for a chat-format message list.

        Args:
            messages: Chat messages (``{"role": ..., "content": ...}``).
            temp: Sampling temperature, clamped to [0.1, 2.0]; a value <= 0
                selects greedy decoding instead.
            max_new_tokens: Generation budget; the prompt is cut to
                ``MAX_SEQ_LENGTH - max_new_tokens`` tokens (oldest dropped).

        Returns:
            The decoded new text, "[No response]" if it is empty, or
            "[Generation failed]" if any step raises (the error is printed).
        """
        try:
            text = self.tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
            )
            # The chat template already contains the special tokens.
            inputs = self.tokenizer(
                text,
                return_tensors="pt",
                add_special_tokens=False,
                return_attention_mask=True,
            )
            input_ids, attention_mask = self._left_truncate(
                inputs["input_ids"],
                inputs["attention_mask"],
                self.config.MAX_SEQ_LENGTH - max_new_tokens,
            )
            device = next(self.peft_model.parameters()).device
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)

            do_sample = temp > 0
            gen_kwargs = dict(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=max_new_tokens,
                do_sample=do_sample,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
                repetition_penalty=1.1,
            )
            if do_sample:
                # Clamp temperature to avoid NaN/degenerate distributions.
                gen_kwargs.update(
                    temperature=max(0.1, min(temp, 2.0)),
                    top_p=0.95,
                    top_k=50,
                )
            with torch.no_grad():
                outputs = self.peft_model.generate(**gen_kwargs)
            # Decode only the newly generated tokens.
            new_tokens = outputs[0][input_ids.shape[-1]:]
            response = self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
            return response if response else "[No response]"
        except Exception as e:
            print(f"Generation error: {e}")
            return "[Generation failed]"

    def save(self, path: str):
        """Save the adapter weights and tokenizer into ``path`` (created if missing)."""
        save_path = Path(path)
        save_path.mkdir(parents=True, exist_ok=True)
        self.peft_model.save_pretrained(str(save_path))
        self.tokenizer.save_pretrained(str(save_path))
        print(f"✅ Model saved to {save_path}")
| # ============================================================================ | |
| # PROMPTS | |
| # ============================================================================ | |
def interrogator_system(crime: str, fact_keys: List[str], turns_left: int) -> str:
    """Build the interrogator's system prompt.

    The prompt names the crime, lists the hidden fact keys, states the turn
    budget and ends with an example JSON accusation containing every key.
    """
    keys_listing = ", ".join(fact_keys)
    example_pairs = ", ".join(f'"{key}": "your answer here"' for key in fact_keys)
    prompt_lines = [
        "You are Detective Sarah Chen investigating a crime.",
        f"Crime: {crime}",
        f"Hidden facts to uncover: {keys_listing}",
        f"Turns remaining: {turns_left}",
        "Strategy:",
        "- Ask specific, targeted questions about timeline, locations, and relationships",
        "- Return to the same topic twice to catch contradictions",
        "- Use false premises to bait confessions",
        "- Build a timeline and expose gaps",
        "When confident (or time is up), output ONLY a JSON object with the hidden fact keys.",
        "Example: {" + example_pairs + "}",
    ]
    return "\n".join(prompt_lines)
def suspect_system(crime: str, name: str, fake_alibi: str) -> str:
    """Build the suspect's system prompt: identity, cover story and rules.

    Fix: the "Stay consistent" rule contained a mis-encoded character
    (U+0E42, a UTF-8 em dash decoded as TIS-620) that was sent verbatim to
    the model; it is restored to an em dash.
    """
    return f"""You are {name}, being interrogated about: {crime}
Your cover story: {fake_alibi}
Rules:
- NEVER directly reveal the truth
- Stay consistent — never contradict your previous answers
- Answer every question (refusing too often loses points)
- You may lie or mislead, but maintain your story"""
| # ============================================================================ | |
| # RL TRAINER | |
| # ============================================================================ | |
class SuspectXTrainer:
    """Self-play episodes against the remote environment plus GRPO training.

    One ``SuspectXModel`` plays both the interrogator and the suspect;
    ``reward_function`` is passed to ``GRPOTrainer`` as its reward callback.
    """

    # Env rewards at or below the threshold earn no bonus; above it the bonus
    # is (reward - baseline) * scale. The 0.2 baseline presumably matches the
    # environment's floor score -- TODO confirm against the env scoring.
    _ENV_REWARD_THRESHOLD = 0.201
    _ENV_REWARD_BASELINE = 0.2
    _ENV_REWARD_SCALE = 2.0

    def __init__(
        self,
        model: "SuspectXModel",
        env_client: "EnvironmentClient",
        logger: "ConversationLogger",
        config: "Config",
    ):
        self.model = model
        self.env = env_client
        self.logger = logger
        self.config = config
        # One summary dict per environment episode run from reward_function.
        self.episode_rewards = []
        self.training_stats = []  # NOTE(review): not written anywhere in this file

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _extract_json_object(text: str) -> Optional[Dict]:
        """Parse the span from the first "{" to the last "}" in ``text``.

        Returns the resulting dict, or None if there is no such span or it
        is not valid JSON.
        """
        start, end = text.find("{"), text.rfind("}") + 1
        if start < 0 or end <= start:
            return None
        try:
            parsed = json.loads(text[start:end])
        except (ValueError, RecursionError):  # JSONDecodeError subclasses ValueError
            return None
        return parsed if isinstance(parsed, dict) else None

    @staticmethod
    def _stringify_values(accusation: Dict, placeholder: str) -> Dict:
        """Return a copy with every value as a string; None becomes ``placeholder``.

        Non-string values were associated with HTTP 422 responses from /step.
        """
        return {
            key: placeholder if value is None else str(value)
            for key, value in accusation.items()
        }

    @staticmethod
    def _as_text(message) -> str:
        """Flatten a completion/prompt (plain string or list of chat messages) to text."""
        if isinstance(message, str):
            return message
        if isinstance(message, list):
            return "".join(
                str(m.get("content") or "") if isinstance(m, dict) else str(m)
                for m in message
            )
        return "" if message is None else str(message)

    # ---------------------------------------------------------------- episodes

    def run_episode(self, episode_num: int, n_facts: int = 2) -> Dict:
        """Play one interrogation episode and return the environment's final result.

        The interrogator may accuse early by emitting JSON that contains at
        least one hidden fact key. Otherwise, after ``MAX_TURNS`` rounds, a
        final JSON accusation is requested; if it cannot be parsed, every key
        is submitted as "[No evidence found]". The dialogue is logged through
        ``self.logger``.

        Raises:
            Whatever ``self.env.reset`` raises, and KeyError if the reset
            observation lacks ``session_id``/``metadata`` fields used here.
        """
        obs = self.env.reset(n_facts=n_facts)
        session_id = obs["session_id"]
        meta = obs["metadata"]
        self.logger.start_episode(episode_num, meta)
        fact_keys = meta["fact_keys"]

        interrogator_hist = [{
            "role": "system",
            "content": interrogator_system(
                meta["crime_description"],
                fact_keys,
                self.config.MAX_TURNS,
            ),
        }]
        suspect_hist = [{
            "role": "system",
            "content": suspect_system(
                meta["crime_description"],
                meta["suspect_name"],
                meta["fake_alibi"],
            ),
        }]

        for turn in range(self.config.MAX_TURNS):
            question = self.model.generate(interrogator_hist, temp=0.9)

            # Early accusation: JSON naming at least one hidden fact key.
            accusation = self._extract_json_object(question)
            if accusation is not None and any(k in accusation for k in fact_keys):
                accusation = self._stringify_values(accusation, "[No content]")
                result = self.env.step(session_id, "submit_accusation", accusation_json=accusation)
                self.logger.log_turn(question, "[ACCUSATION]", turn)
                self.logger.end_episode(result)
                return result

            # The question is the interrogator's turn and the suspect's input.
            interrogator_hist.append({"role": "assistant", "content": question})
            suspect_hist.append({"role": "user", "content": question})
            obs = self.env.step(session_id, "question", content=question)
            if obs.get("done"):
                self.logger.end_episode(obs)
                return obs

            answer = self.model.generate(suspect_hist, temp=0.85)
            suspect_hist.append({"role": "assistant", "content": answer})
            interrogator_hist.append({"role": "user", "content": answer})
            self.logger.log_turn(question, answer, turn)

            obs = self.env.step(session_id, "suspect_answer", content=answer)
            if obs.get("done"):
                self.logger.end_episode(obs)
                return obs

        # Out of turns: force a final accusation.
        final_prompt = interrogator_hist + [{
            "role": "user",
            "content": f"Time is up. Output ONLY a JSON with keys {meta['fact_keys']}. No explanation.",
        }]
        final_text = self.model.generate(final_prompt, temp=0.1, max_new_tokens=150)
        parsed = self._extract_json_object(final_text)
        if parsed is None:
            accusation = {k: "[No evidence found]" for k in fact_keys}
        else:
            accusation = self._stringify_values(parsed, "[No evidence found]")
        result = self.env.step(session_id, "submit_accusation", accusation_json=accusation)
        self.logger.end_episode(result)
        return result

    # ----------------------------------------------------------------- rewards

    def _format_reward(self, prompt, completion: str) -> float:
        """Heuristic shaping reward computed from the completion text alone.

        +0.1 if it contains "?", +0.05 if 10 < length < 300. If the prompt is
        chat-format and its last message mentions "final" or "json": +0.4 for
        valid JSON, -0.1 for an unparsable {...} span, -0.2 for none.
        """
        reward = 0.0
        if "?" in completion:
            reward += 0.1
        if 10 < len(completion) < 300:
            reward += 0.05

        is_final = False
        if isinstance(prompt, list) and prompt:
            last = prompt[-1]
            last_msg = str(last.get("content") or "").lower() if isinstance(last, dict) else ""
            is_final = "final" in last_msg or "json" in last_msg

        if is_final:
            match = re.search(r"\{.*\}", completion, re.DOTALL)
            if match is None:
                reward -= 0.2
            else:
                try:
                    json.loads(match.group())
                    reward += 0.4
                except (ValueError, RecursionError):
                    reward -= 0.1
        return reward

    def _environment_bonus(self, index: int, batch_size: int, base_reward: float) -> float:
        """Run one environment episode, record its stats, and return the bonus.

        Raises whatever ``run_episode`` raises (e.g. a failed reset).
        """
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        episode_num = len(self.episode_rewards)
        n_facts = random.randint(self.config.MIN_FACTS, self.config.MAX_FACTS)
        result = self.run_episode(episode_num, n_facts=n_facts)

        env_reward = float(result.get("reward") or 0.0)
        bonus = 0.0
        if env_reward > self._ENV_REWARD_THRESHOLD:
            bonus = (env_reward - self._ENV_REWARD_BASELINE) * self._ENV_REWARD_SCALE

        meta = result.get("metadata") or {}
        self.episode_rewards.append({
            "episode": episode_num,
            "reward": env_reward,
            "extraction_rate": meta.get("extraction_rate", 0),
            "facts_extracted": meta.get("facts_extracted", 0),
            "total_facts": meta.get("total_facts", 0),
        })
        print(f"  Batch [{index}/{batch_size}] Episode {episode_num}: env_reward={env_reward:.3f}, "
              f"total_reward={base_reward + bonus:.3f}, "
              f"extracted={meta.get('facts_extracted', 0)}/"
              f"{meta.get('total_facts', 0)}")
        return bonus

    def reward_function(self, prompts, completions, **kwargs) -> List[float]:
        """Reward callback for GRPOTrainer: one float per completion.

        reward = format reward (see ``_format_reward``) + environment bonus.
        A failure in either part is printed and contributes 0 for that part;
        an environment/network error no longer erases the format reward.

        NOTE(review): ``run_episode`` starts a fresh environment episode that
        is not conditioned on ``completion``, so the environment bonus does
        not measure the completion being scored and adds noise between
        completions of the same prompt. Consider a redesign.
        """
        rewards: List[float] = []
        batch_size = len(completions)
        for i in range(batch_size):
            prompt = prompts[i] if i < len(prompts) else None
            completion = self._as_text(completions[i])
            try:
                reward = self._format_reward(prompt, completion)
            except Exception as e:
                print(f"  Reward calculation error at index {i}: {e}")
                reward = 0.0
            try:
                reward += self._environment_bonus(i, batch_size, reward)
            except Exception as e:
                print(f"  Reward calculation error at index {i}: {e}")
            rewards.append(reward)
        return rewards

    # ---------------------------------------------------------------- training

    def train(self):
        """Run GRPO training, then save the adapter and per-episode stats.

        Returns:
            ``self.episode_rewards`` (one dict per environment episode).
        """
        print("\n" + "=" * 60)
        print("STARTING TRAINING")
        print("=" * 60)

        # Placeholder prompts: GRPO needs a dataset; the scenarios come from the env.
        train_dataset = [
            {"prompt": f"Investigate crime scenario {i}"}
            for i in range(self.config.NUM_EPISODES)
        ]

        # bf16 needs hardware support (absent on e.g. T4); fall back to fp16 there.
        use_cuda = torch.cuda.is_available()
        use_bf16 = use_cuda and torch.cuda.is_bf16_supported()

        training_args = GRPOConfig(
            num_generations=self.config.NUM_GENERATIONS,
            per_device_train_batch_size=self.config.BATCH_SIZE,
            gradient_accumulation_steps=self.config.GRADIENT_ACCUM,
            max_steps=self.config.NUM_EPISODES,
            max_prompt_length=self.config.MAX_SEQ_LENGTH // 4,
            max_completion_length=self.config.MAX_SEQ_LENGTH // 8,
            # Slightly higher for faster learning.
            # NOTE(review): overrides Config.LEARNING_RATE, which is not read here.
            learning_rate=1e-5,
            logging_steps=1,  # Log every step to see weight updates
            save_steps=10,
            output_dir=self.config.OUTPUT_DIR,
            report_to="none",
            remove_unused_columns=False,
            bf16=use_bf16,
            fp16=use_cuda and not use_bf16,
            warmup_ratio=0.1,  # Add warmup for stability
        )

        trainer = GRPOTrainer(
            model=self.model.peft_model,
            args=training_args,
            processing_class=self.model.tokenizer,
            train_dataset=train_dataset,
            reward_funcs=[self.reward_function],
        )

        print(f"Training for {self.config.NUM_EPISODES} episodes...\n")
        trainer.train()
        print("\n" + "=" * 60)
        print("TRAINING COMPLETE")
        print("=" * 60)

        self.model.save(str(Path(self.config.CHECKPOINT_DIR) / "final"))

        # The output dir may not exist yet if no checkpoint was written.
        stats_path = Path(self.config.OUTPUT_DIR) / "training_stats.json"
        stats_path.parent.mkdir(parents=True, exist_ok=True)
        with open(stats_path, "w") as f:
            json.dump(self.episode_rewards, f, indent=2)
        return self.episode_rewards
| # ============================================================================ | |
| # GRADIO UI | |
| # ============================================================================ | |
def create_gradio_interface(trainer: "SuspectXTrainer"):
    """Build the Gradio dashboard used to drive and monitor training.

    Tabs:
        Training     -- runs ``trainer.train()`` and shows a summary.
        Test Episode -- runs one ``trainer.run_episode`` with 2 facts.
        Stats        -- summarises the last 20 ``trainer.episode_rewards``.

    Returns:
        The ``gr.Blocks`` app; the caller is responsible for ``launch()``.
    """

    def start_training():
        """Run the full training loop; return (summary text, last 20 episodes)."""
        try:
            results = trainer.train()
            if not results:
                return "⚠️ Training Loop finished (no results recorded)", []
            avg_reward = sum(r["reward"] for r in results) / len(results)
            last_ten = results[-10:]
            final_reward = sum(r["reward"] for r in last_ten) / len(last_ten)
            summary = f"""
✅ Training Complete!
Total Episodes: {len(results)}
Average Reward: {avg_reward:.3f}
Final 10 Avg: {final_reward:.3f}
Model saved to: {trainer.config.CHECKPOINT_DIR}/final
Logs saved to: {trainer.config.LOGS_DIR}
"""
            return summary, results[-20:]
        except Exception as e:
            return f"❌ Training failed: {str(e)}", []

    def test_episode():
        """Run one episode outside training; return (summary text, raw result)."""
        try:
            episode_num = len(trainer.episode_rewards)
            result = trainer.run_episode(episode_num, n_facts=2)
            # Tolerate results without reward/metadata (e.g. env error dicts).
            meta = result.get("metadata") or {}
            reward = float(result.get("reward") or 0.0)
            output = f"""
Episode: {episode_num}
Reward: {reward:.3f}
Extraction Rate: {meta.get('extraction_rate', 0):.3f}
Facts: {meta.get('facts_extracted', 0)}/{meta.get('total_facts', 0)}
Crime: {meta.get('crime_description', 'N/A')}
"""
            return output, result
        except Exception as e:
            return f"❌ Test failed: {str(e)}", {}

    def get_stats():
        """Summarise the most recent (up to 20) recorded episodes."""
        if not trainer.episode_rewards:
            return "No training data yet", []
        recent = trainer.episode_rewards[-20:]
        avg = sum(r["reward"] for r in recent) / len(recent)
        stats = f"""
Total Episodes: {len(trainer.episode_rewards)}
Recent Avg Reward (last 20): {avg:.3f}
"""
        return stats, recent

    with gr.Blocks(title="Suspect X Training") as interface:
        gr.Markdown("# 🕵️ Suspect X - RL Training Dashboard")
        gr.Markdown(f"Environment: `{trainer.config.HF_ENV_URL}`")
        with gr.Tab("Training"):
            train_btn = gr.Button("🚀 Start Training", variant="primary", size="lg")
            train_output = gr.Textbox(label="Training Status", lines=10)
            train_data = gr.JSON(label="Recent Episodes")
            train_btn.click(start_training, outputs=[train_output, train_data])
        with gr.Tab("Test Episode"):
            test_btn = gr.Button("▶️ Run Test Episode", variant="secondary")
            test_output = gr.Textbox(label="Episode Result", lines=10)
            test_data = gr.JSON(label="Full Result")
            test_btn.click(test_episode, outputs=[test_output, test_data])
        with gr.Tab("Stats"):
            stats_btn = gr.Button("📊 Refresh Stats")
            stats_output = gr.Textbox(label="Training Stats", lines=5)
            stats_data = gr.JSON(label="Recent Episodes")
            stats_btn.click(get_stats, outputs=[stats_output, stats_data])
    return interface
| # ============================================================================ | |
| # MAIN | |
| # ============================================================================ | |
def main():
    """Build all components and launch the Gradio dashboard.

    An ``HF_ENV_URL`` environment variable, when set and non-empty,
    overrides the configured environment URL. If the environment is
    unreachable, a message is printed and the function returns before the
    model is loaded. Otherwise the UI is served on 0.0.0.0:7860.
    """
    config = Config()
    # Read the override once (the original called os.getenv twice).
    env_url = os.getenv("HF_ENV_URL")
    if env_url:
        config.HF_ENV_URL = env_url

    banner = "=" * 60
    print(banner)
    print("SUSPECT X - RL TRAINING")
    print(banner)
    print(f"Environment URL: {config.HF_ENV_URL}")
    print(f"Model: {config.MODEL_NAME}")
    print(f"Episodes: {config.NUM_EPISODES}")
    print(banner)

    env_client = EnvironmentClient(config.HF_ENV_URL)
    # Fail fast, before the expensive model load.
    if not env_client.test_connection():
        print("\n❌ Cannot connect to environment. Please check HF_ENV_URL")
        return

    logger = ConversationLogger(config.LOGS_DIR)
    model = SuspectXModel(config)
    model.load_model()
    trainer = SuspectXTrainer(model, env_client, logger, config)

    print("\n🚀 Launching Gradio interface...")
    interface = create_gradio_interface(trainer)
    interface.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
    )
if __name__ == "__main__":
    # Only set up and launch the dashboard when executed as a script,
    # not when this module is imported.
    main()