trainer / app.py
mayank1365
Fix CUDA device-side assert by adjusting max_prompt_length and disabling use_cache
fe123ff
"""
Suspect X - RL Training Script
Runs on Hugging Face Spaces with GPU
Trains Qwen model via GRPO using remote environment
"""
import os
import json
import time
import random
import re
import requests
import torch
# Workaround for Hugging Face Spaces GPU detection issues with Unsloth/Bitsandbytes:
# expose only the first GPU to this process. It is set after `import torch` but
# before unsloth is imported further down.
# NOTE(review): CUDA_VISIBLE_DEVICES only takes effect if CUDA has not been
# initialised yet — presumably `import torch` alone does not initialise it;
# confirm, or move this line above `import torch` to be safe.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import gradio as gr
from datetime import datetime
from typing import Dict, List, Optional
from pathlib import Path
from unsloth import FastLanguageModel
from trl import GRPOConfig, GRPOTrainer
from transformers import TrainingArguments
# ============================================================================
# CONFIGURATION
# ============================================================================
class Config:
    """Static hyper-parameters and paths for the Suspect X training run.

    Every setting is a plain class attribute, so it can be read from the class
    or from an instance; ``main`` overrides ``HF_ENV_URL`` on an instance when
    the environment variable of the same name is set.
    """

    # --- Remote environment -------------------------------------------------
    # Base URL of the Suspect X environment Space. EDIT THIS.
    HF_ENV_URL: str = "https://ayaan-ai-meta.hf.space"

    # --- Base model ---------------------------------------------------------
    MODEL_NAME: str = "unsloth/Qwen2.5-7B-Instruct-bnb-4bit"
    # Context budget; SuspectXModel.generate gives the prompt whatever is left
    # after max_new_tokens, and GRPO prompt/completion limits are fractions of it.
    MAX_SEQ_LENGTH: int = 4096
    LOAD_IN_4BIT: bool = True

    # --- LoRA adapter -------------------------------------------------------
    LORA_R: int = 16
    LORA_ALPHA: int = 32
    LORA_DROPOUT: float = 0.0  # Unsloth works best with 0 dropout

    # --- Episodes & optimisation --------------------------------------------
    NUM_EPISODES: int = 50   # kept small for stability; also used as GRPO max_steps
    MAX_TURNS: int = 8       # question/answer rounds per episode; small to avoid timeouts
    MIN_FACTS: int = 2       # lower bound of hidden facts sampled per scored episode
    MAX_FACTS: int = 4       # upper bound of hidden facts sampled per scored episode
    # NOTE: currently not read by SuspectXTrainer.train, which passes its own LR.
    LEARNING_RATE: float = 5e-6
    BATCH_SIZE: int = 1
    GRADIENT_ACCUM: int = 4
    NUM_GENERATIONS: int = 2  # completions sampled per prompt (reduced from 4 for stability)

    # --- Output locations ---------------------------------------------------
    OUTPUT_DIR: str = "./suspect_x_output"
    LOGS_DIR: str = "./conversation_logs"
    CHECKPOINT_DIR: str = "./checkpoints"
# ============================================================================
# ENVIRONMENT CLIENT
# ============================================================================
class EnvironmentClient:
    """Thin HTTP client for the remote Suspect X environment.

    Wraps the environment's ``/`` (health check), ``/reset`` and ``/step``
    endpoints using one shared ``requests.Session``.
    """

    def __init__(self, base_url: str):
        # Strip a trailing slash so f"{base_url}/reset" never contains "//".
        self.base_url = base_url.rstrip("/")
        self.session = requests.Session()

    def test_connection(self) -> bool:
        """Test if environment is reachable.

        Returns:
            True when ``GET /`` succeeds and returns JSON, False otherwise
            (the failure reason is printed).
        """
        try:
            resp = self.session.get(f"{self.base_url}/", timeout=10)
            resp.raise_for_status()
            print(f"✅ Connected to environment: {resp.json()}")
            return True
        except Exception as e:
            print(f"❌ Connection failed: {e}")
            return False

    def reset(self, n_facts: int = 2, difficulty: Optional[str] = None) -> Dict:
        """Start a new episode.

        Args:
            n_facts: number of hidden facts requested from the environment.
            difficulty: optional difficulty label forwarded as-is.

        Returns:
            The environment's JSON response.

        Raises:
            requests.HTTPError (or other ``requests`` errors) — unlike ``step``,
            failures here are not swallowed.
        """
        resp = self.session.post(
            f"{self.base_url}/reset",
            json={"n_facts": n_facts, "difficulty": difficulty},
            timeout=30,
        )
        resp.raise_for_status()
        return resp.json()

    def step(
        self,
        session_id: str,
        action_type: str,
        content: Optional[str] = None,
        accusation_json: Optional[Dict] = None,
    ) -> Dict:
        """Take one action in an episode.

        Args:
            session_id: id returned by ``reset``.
            action_type: e.g. "question", "suspect_answer", "submit_accusation".
            content: message text; always sent as a string ("" when None).
            accusation_json: accusation payload, only sent when not None.

        Returns:
            The environment's JSON response. On any failure (network error,
            non-2xx status, undecodable body) the error is printed and the
            terminal placeholder
            ``{"done": True, "reward": 0.0, "metadata": {"error": ...}}`` is
            returned so the training loop can continue.
        """
        payload = {
            "session_id": session_id,
            "action_type": action_type,
            # Never send null content; the original code notes this caused HTTP 422.
            "content": "" if content is None else str(content),
        }
        if accusation_json is not None:
            payload["accusation_json"] = accusation_json
        try:
            resp = self.session.post(
                f"{self.base_url}/step",
                json=payload,
                timeout=30,
            )
            if resp.status_code == 422:
                print(f"❌ 422 Error Detail: {resp.text}")
            resp.raise_for_status()
            return resp.json()
        except Exception as e:
            # Deliberate best-effort boundary, but no longer silent: a failed
            # step would otherwise look like a genuine 0-reward episode end.
            print(f"❌ Step '{action_type}' failed for session {session_id}: {e}")
            return {"done": True, "reward": 0.0, "metadata": {"error": str(e)}}
# ============================================================================
# CONVERSATION LOGGER
# ============================================================================
class ConversationLogger:
    """Accumulates one episode's turns in memory and writes them as JSON.

    Each finished episode becomes one file in ``log_dir`` named
    ``episode_<num:04d>_<YYYYmmdd_HHMMSS>.json``.
    """

    def __init__(self, log_dir: str):
        self.log_dir = Path(log_dir)
        self.log_dir.mkdir(parents=True, exist_ok=True)
        # Record of the episode in progress; None between episodes.
        self.current_episode = None

    def start_episode(self, episode_num: int, metadata: Dict):
        """Open a fresh in-memory record, replacing any unfinished one."""
        self.current_episode = dict(
            episode_num=episode_num,
            timestamp=datetime.now().strftime("%Y%m%d_%H%M%S"),
            metadata=metadata,
            turns=[],
            result=None,
        )

    def log_turn(self, question: str, answer: str, turn_num: int):
        """Append one question/answer exchange; a no-op when no episode is open."""
        episode = self.current_episode
        if not episode:
            return
        episode["turns"].append(dict(turn=turn_num, question=question, answer=answer))

    def end_episode(self, result: Dict):
        """Attach ``result``, write the episode file and close the record.

        Returns:
            The written file path as a string, or None (nothing written) when
            no episode is open.
        """
        episode = self.current_episode
        if not episode:
            return None
        episode["result"] = result
        target = self.log_dir / "episode_{:04d}_{}.json".format(
            episode["episode_num"], episode["timestamp"]
        )
        with open(target, "w") as fh:
            json.dump(episode, fh, indent=2)
        self.current_episode = None
        return str(target)
# ============================================================================
# MODEL WRAPPER
# ============================================================================
class SuspectXModel:
    """Qwen base model plus LoRA adapter, used for both interrogator and suspect."""

    def __init__(self, config: Config):
        self.config = config
        # Populated by load_model().
        self.model = None
        self.tokenizer = None
        self.peft_model = None

    def load_model(self):
        """Load the base model/tokenizer and attach the LoRA adapter.

        Sets ``self.model``, ``self.tokenizer`` and ``self.peft_model``.
        """
        print("Loading model...")
        self.model, self.tokenizer = FastLanguageModel.from_pretrained(
            self.config.MODEL_NAME,
            max_seq_length=self.config.MAX_SEQ_LENGTH,
            load_in_4bit=self.config.LOAD_IN_4BIT,
        )
        # Fix tokenizer pad token: reuse EOS when the tokenizer defines none.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
        # Set padding side to left for generation.
        self.tokenizer.padding_side = "left"
        FastLanguageModel.for_training(self.model)
        # Setup LoRA for interrogator (the same adapter also plays the suspect).
        self.peft_model = FastLanguageModel.get_peft_model(
            self.model,
            r=self.config.LORA_R,
            lora_alpha=self.config.LORA_ALPHA,
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
            lora_dropout=self.config.LORA_DROPOUT,
            bias="none",
            use_gradient_checkpointing="unsloth",
            random_state=42,
        )
        self.peft_model.config.use_cache = False  # Critical for training stability
        print(f"✅ Model loaded: {self.config.MODEL_NAME}")
        print(f" Trainable params: {sum(p.numel() for p in self.peft_model.parameters() if p.requires_grad):,}")

    def _fit_prompt(self, messages: List[Dict], max_prompt_tokens: int) -> str:
        """Render ``messages`` with the chat template, dropping the oldest
        non-system turns until the prompt fits in ``max_prompt_tokens``.

        The system message (if first) and the newest message are always kept,
        so the result may still exceed the budget; ``generate`` caps it then.
        """
        kept = list(messages)  # never mutate the caller's history
        while True:
            text = self.tokenizer.apply_chat_template(
                kept,
                tokenize=False,
                add_generation_prompt=True,
            )
            n_tokens = len(self.tokenizer(text)["input_ids"])
            has_system = bool(kept) and kept[0].get("role") == "system"
            min_keep = 2 if has_system else 1
            if n_tokens <= max_prompt_tokens or len(kept) <= min_keep:
                return text
            # Oldest turn after the (optional) system prompt.
            del kept[min_keep - 1]

    def generate(
        self,
        messages: List[Dict],
        temp: float = 0.9,
        max_new_tokens: int = 200,
    ) -> str:
        """Sample a reply to a chat history.

        Args:
            messages: chat messages (``{"role": ..., "content": ...}`` dicts).
            temp: sampling temperature, clamped to [0.1, 2.0].
            max_new_tokens: generation budget; the prompt gets what remains of
                ``MAX_SEQ_LENGTH``.

        Returns:
            The decoded new text; ``"[No response]"`` if it is empty, or
            ``"[Generation failed]"`` if anything raises (the error is printed).
        """
        try:
            # Leave room for the completion inside the context window.
            max_prompt_tokens = max(1, self.config.MAX_SEQ_LENGTH - max_new_tokens)
            text = self._fit_prompt(messages, max_prompt_tokens)
            inputs = self.tokenizer(
                text,
                return_tensors="pt",
                return_attention_mask=True,
            )
            # Last-resort cap that keeps the END of the prompt. The previous
            # truncation=True used the tokenizer's default right-side
            # truncation, which cut off the assistant header added by
            # add_generation_prompt together with the newest turns.
            input_ids = inputs["input_ids"][:, -max_prompt_tokens:].to("cuda")
            attention_mask = inputs["attention_mask"][:, -max_prompt_tokens:].to("cuda")

            # Clamp temperature to avoid NaN.
            safe_temp = max(0.1, min(temp, 2.0))
            with torch.no_grad():
                outputs = self.peft_model.generate(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    max_new_tokens=max_new_tokens,
                    temperature=safe_temp,
                    do_sample=True,  # safe_temp >= 0.1, so sampling is always on
                    top_p=0.95,
                    top_k=50,
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                    repetition_penalty=1.1,
                )
            # Decode only new tokens.
            new_tokens = outputs[0][input_ids.shape[-1]:]
            response = self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
            return response if response else "[No response]"
        except Exception as e:
            print(f"Generation error: {e}")
            return "[Generation failed]"

    def save(self, path: str):
        """Save the LoRA adapter weights and tokenizer under ``path`` (created if missing)."""
        save_path = Path(path)
        save_path.mkdir(parents=True, exist_ok=True)
        self.peft_model.save_pretrained(str(save_path))
        self.tokenizer.save_pretrained(str(save_path))
        print(f"✅ Model saved to {save_path}")
# ============================================================================
# PROMPTS
# ============================================================================
def interrogator_system(crime: str, fact_keys: List[str], turns_left: int) -> str:
    """Build the interrogator's system prompt.

    Args:
        crime: crime description shown to the detective.
        fact_keys: names of the hidden facts the detective must uncover.
        turns_left: value shown on the "Turns remaining" line.

    Returns:
        A newline-separated prompt ending with an example JSON accusation that
        lists every key in ``fact_keys`` with a placeholder answer.
    """
    keys_csv = ", ".join(fact_keys)
    example_fields = ", ".join(f'"{key}": "your answer here"' for key in fact_keys)
    prompt_lines = [
        "You are Detective Sarah Chen investigating a crime.",
        f"Crime: {crime}",
        f"Hidden facts to uncover: {keys_csv}",
        f"Turns remaining: {turns_left}",
        "Strategy:",
        "- Ask specific, targeted questions about timeline, locations, and relationships",
        "- Return to the same topic twice to catch contradictions",
        "- Use false premises to bait confessions",
        "- Build a timeline and expose gaps",
        "When confident (or time is up), output ONLY a JSON object with the hidden fact keys.",
        "Example: {" + example_fields + "}",
    ]
    return "\n".join(prompt_lines)
def suspect_system(crime: str, name: str, fake_alibi: str) -> str:
    """Build the suspect's system prompt.

    Args:
        crime: crime description the suspect is questioned about.
        name: the suspect's name.
        fake_alibi: cover story the suspect must maintain.

    Returns:
        A newline-separated prompt with the suspect's identity, cover story
        and rules of play.
    """
    # The em dash in the "Stay consistent" rule replaces a mis-encoded byte
    # sequence that was previously sent to the model verbatim.
    prompt_lines = [
        f"You are {name}, being interrogated about: {crime}",
        f"Your cover story: {fake_alibi}",
        "Rules:",
        "- NEVER directly reveal the truth",
        "- Stay consistent \u2014 never contradict your previous answers",
        "- Answer every question (refusing too often loses points)",
        "- You may lie or mislead, but maintain your story",
    ]
    return "\n".join(prompt_lines)
# ============================================================================
# RL TRAINER
# ============================================================================
class SuspectXTrainer:
    """Runs self-play interrogation episodes and GRPO training.

    One ``SuspectXModel`` plays both the interrogator and the suspect against
    the remote environment; episode scores feed into the GRPO reward.
    """

    # Placeholder for accusation fields left empty/omitted in the final accusation.
    _NO_EVIDENCE = "[No evidence found]"

    def __init__(
        self,
        model: SuspectXModel,
        env_client: EnvironmentClient,
        logger: ConversationLogger,
        config: Config,
    ):
        self.model = model
        self.env = env_client
        self.logger = logger
        self.config = config
        # One summary dict per environment-scored episode (see reward_function).
        self.episode_rewards = []
        self.training_stats = []

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _extract_json_object(text: str) -> Optional[Dict]:
        """Parse the span from the first '{' to the last '}' of ``text``.

        Returns the resulting dict, or None when there is no such span, it is
        not valid JSON, or it does not decode to a dict.
        """
        start, end = text.find("{"), text.rfind("}") + 1
        if start < 0 or end <= start:
            return None
        try:
            parsed = json.loads(text[start:end])
        except json.JSONDecodeError:
            return None
        return parsed if isinstance(parsed, dict) else None

    @staticmethod
    def _sanitize_accusation(accusation: Dict, fact_keys: List[str], placeholder: str) -> Dict:
        """Return a copy of ``accusation`` safe to submit.

        Every value becomes a string (None -> ``placeholder``), as the original
        code did to prevent HTTP 422 responses, and every key in ``fact_keys``
        is present (missing ones get ``placeholder``).
        """
        clean = {k: placeholder if v is None else str(v) for k, v in accusation.items()}
        for key in fact_keys:
            clean.setdefault(key, placeholder)
        return clean

    @staticmethod
    def _completion_text(completion) -> str:
        """Normalise a GRPO completion to plain text.

        Plain strings pass through; a list of chat messages (the shape TRL uses
        for conversational prompts) is joined by its ``content`` fields.
        """
        if isinstance(completion, str):
            return completion
        if isinstance(completion, list):
            return "".join(
                str(m.get("content", "")) if isinstance(m, dict) else str(m)
                for m in completion
            )
        return str(completion)

    @staticmethod
    def _format_reward(prompt, completion: str) -> float:
        """Shaping reward computed from the completion text alone.

        +0.1 if it contains "?", +0.05 if 10 < len < 300 characters and, when
        the prompt's last chat message mentions "final" or "json", +0.4 for a
        valid JSON object, -0.1 for an invalid one, -0.2 for none at all.
        """
        reward = 0.0
        # 1. Format reward: does it contain a question?
        if "?" in completion:
            reward += 0.1
        # 2. Length reward: discourage empty or overly long responses.
        if 10 < len(completion) < 300:
            reward += 0.05
        # 3. JSON structure reward for final-turn prompts (chat-format prompts only).
        is_final = False
        if isinstance(prompt, list) and prompt:
            last = prompt[-1]
            last_msg = str(last.get("content", "")).lower() if isinstance(last, dict) else ""
            is_final = "final" in last_msg or "json" in last_msg
        if is_final:
            match = re.search(r"\{.*\}", completion, re.DOTALL)
            if match is None:
                reward -= 0.2
            else:
                try:
                    json.loads(match.group())
                    reward += 0.4
                except json.JSONDecodeError:
                    reward -= 0.1
        return reward

    # ----------------------------------------------------------------- episodes

    def run_episode(self, episode_num: int, n_facts: int = 2) -> Dict:
        """Play one full interrogation episode against the environment.

        Each turn the interrogator either submits an early accusation (its
        output contains a JSON object with any hidden-fact key) or asks a
        question that the suspect answers; both are forwarded to the env. If
        the env has not ended the episode after ``MAX_TURNS`` turns, a final
        accusation is forced.

        Returns:
            The environment's terminal response. ``EnvironmentClient.step``
            substitutes ``{"done": True, "reward": 0.0, "metadata": {...}}`` on
            failures.

        Raises:
            Errors from ``EnvironmentClient.reset``, and KeyError if the reset
            payload lacks the metadata fields read below.
        """
        obs = self.env.reset(n_facts=n_facts)
        session_id = obs["session_id"]
        meta = obs["metadata"]
        # Read every required field before logging starts, so a malformed
        # payload cannot leave a half-open episode in the logger.
        fact_keys = meta["fact_keys"]
        crime = meta["crime_description"]
        suspect_prompt = suspect_system(crime, meta["suspect_name"], meta["fake_alibi"])

        self.logger.start_episode(episode_num, meta)

        interrogator_hist = [{
            "role": "system",
            "content": interrogator_system(crime, fact_keys, self.config.MAX_TURNS),
        }]
        suspect_hist = [{"role": "system", "content": suspect_prompt}]

        for turn in range(self.config.MAX_TURNS):
            # Keep "Turns remaining" accurate (it previously stayed at MAX_TURNS).
            interrogator_hist[0]["content"] = interrogator_system(
                crime, fact_keys, self.config.MAX_TURNS - turn
            )
            question = self.model.generate(interrogator_hist, temp=0.9)

            # Early accusation: a JSON object naming any hidden-fact key.
            accusation = self._extract_json_object(question)
            if accusation is not None and any(k in accusation for k in fact_keys):
                result = self.env.step(
                    session_id,
                    "submit_accusation",
                    accusation_json=self._sanitize_accusation(accusation, fact_keys, "[No content]"),
                )
                self.logger.log_turn(question, "[ACCUSATION]", turn)
                self.logger.end_episode(result)
                return result

            # Send question to env.
            interrogator_hist.append({"role": "assistant", "content": question})
            suspect_hist.append({"role": "user", "content": question})
            obs = self.env.step(session_id, "question", content=question)
            if obs.get("done"):
                self.logger.end_episode(obs)
                return obs

            # Suspect answers.
            answer = self.model.generate(suspect_hist, temp=0.85)
            suspect_hist.append({"role": "assistant", "content": answer})
            interrogator_hist.append({"role": "user", "content": answer})
            self.logger.log_turn(question, answer, turn)

            # Send answer to env.
            obs = self.env.step(session_id, "suspect_answer", content=answer)
            if obs.get("done"):
                self.logger.end_episode(obs)
                return obs

        # Out of turns: force a final accusation.
        interrogator_hist[0]["content"] = interrogator_system(crime, fact_keys, 0)
        final_prompt = interrogator_hist + [{
            "role": "user",
            "content": f"Time is up. Output ONLY a JSON with keys {fact_keys}. No explanation.",
        }]
        final_text = self.model.generate(final_prompt, temp=0.1, max_new_tokens=150)
        parsed = self._extract_json_object(final_text)
        accusation = self._sanitize_accusation(parsed or {}, fact_keys, self._NO_EVIDENCE)
        result = self.env.step(session_id, "submit_accusation", accusation_json=accusation)
        self.logger.end_episode(result)
        return result

    # ------------------------------------------------------------------- reward

    def reward_function(self, prompts, completions, **kwargs) -> List[float]:
        """Reward function for GRPO: one float per completion.

        Each reward is ``_format_reward`` of the completion plus, when a freshly
        played episode scores above 0.201, ``(env_reward - 0.2) * 2``. If the
        episode fails, the error is printed and the format reward is kept
        (previously the whole sample was zeroed).

        NOTE(review): ``run_episode`` plays a brand-new episode and never sees
        ``completion``, so the environment term is not attributable to the
        completion being scored — it adds per-sample noise to GRPO advantages.
        """
        rewards = []
        current_batch_size = len(completions)
        for i, completion in enumerate(completions):
            prompt = prompts[i]
            reward = self._format_reward(prompt, self._completion_text(completion))

            # 4. Environment-based reward.
            try:
                # Release cached GPU memory before the generation-heavy episode.
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                episode_num = len(self.episode_rewards)
                n_facts = random.randint(self.config.MIN_FACTS, self.config.MAX_FACTS)
                result = self.run_episode(episode_num, n_facts=n_facts)
                env_reward = float(result.get("reward") or 0.0)
                # Only env rewards above 0.2 add a bonus, scaled by 2.
                if env_reward > 0.201:
                    reward += (env_reward - 0.2) * 2.0
                meta = result.get("metadata") or {}
                self.episode_rewards.append({
                    "episode": episode_num,
                    "reward": env_reward,
                    "extraction_rate": meta.get("extraction_rate", 0),
                    "facts_extracted": meta.get("facts_extracted", 0),
                    "total_facts": meta.get("total_facts", 0),
                })
                print(f" Batch [{i}/{current_batch_size}] Episode {episode_num}: env_reward={env_reward:.3f}, total_reward={reward:.3f}, "
                      f"extracted={meta.get('facts_extracted', 0)}/"
                      f"{meta.get('total_facts', 0)}")
            except Exception as e:
                print(f" Reward calculation error at index {i}: {e}")
            rewards.append(reward)
        return rewards

    # ----------------------------------------------------------------- training

    def train(self):
        """Run GRPO training, then save the adapter and per-episode stats.

        Returns:
            ``self.episode_rewards`` — the summaries collected by
            ``reward_function`` during training.
        """
        print("\n" + "=" * 60)
        print("STARTING TRAINING")
        print("=" * 60)

        # Placeholder prompts (one per step); reward_function plays its own episodes.
        train_dataset = [
            {"prompt": f"Investigate crime scenario {i}"}
            for i in range(self.config.NUM_EPISODES)
        ]

        # Prefer bf16, falling back to fp16 on GPUs without bf16 support
        # (TrainingArguments rejects bf16=True there).
        cuda_ok = torch.cuda.is_available()
        use_bf16 = cuda_ok and torch.cuda.is_bf16_supported()

        training_args = GRPOConfig(
            num_generations=self.config.NUM_GENERATIONS,
            per_device_train_batch_size=self.config.BATCH_SIZE,
            gradient_accumulation_steps=self.config.GRADIENT_ACCUM,
            max_steps=self.config.NUM_EPISODES,
            max_prompt_length=self.config.MAX_SEQ_LENGTH // 4,
            max_completion_length=self.config.MAX_SEQ_LENGTH // 8,
            # NOTE: deliberately higher for faster learning; Config.LEARNING_RATE
            # is not read here — wire it through if the config should control it.
            learning_rate=1e-5,
            logging_steps=1,  # Log every step to see weight updates
            save_steps=10,
            output_dir=self.config.OUTPUT_DIR,
            report_to="none",
            remove_unused_columns=False,
            bf16=use_bf16,
            fp16=cuda_ok and not use_bf16,
            warmup_ratio=0.1,  # Warmup for stability
        )

        trainer = GRPOTrainer(
            model=self.model.peft_model,
            args=training_args,
            processing_class=self.model.tokenizer,
            train_dataset=train_dataset,
            reward_funcs=[self.reward_function],
        )

        print(f"Training for {self.config.NUM_EPISODES} episodes...\n")
        trainer.train()
        print("\n" + "=" * 60)
        print("TRAINING COMPLETE")
        print("=" * 60)

        self.model.save(str(Path(self.config.CHECKPOINT_DIR) / "final"))

        # OUTPUT_DIR may not exist yet (e.g. if no checkpoint was written).
        stats_path = Path(self.config.OUTPUT_DIR) / "training_stats.json"
        stats_path.parent.mkdir(parents=True, exist_ok=True)
        with open(stats_path, "w") as f:
            json.dump(self.episode_rewards, f, indent=2)
        return self.episode_rewards
# ============================================================================
# GRADIO UI
# ============================================================================
def create_gradio_interface(trainer: SuspectXTrainer):
    """Build the Gradio monitoring dashboard.

    Tabs:
        Training      - runs ``trainer.train()`` and summarises the results.
        Test Episode  - plays one 2-fact episode via ``trainer.run_episode``.
        Stats         - summarises the last 20 entries of ``trainer.episode_rewards``.

    Each callback returns a ``(text, json)`` pair for a Textbox/JSON output;
    the training and test callbacks report exceptions as text.

    Returns:
        The (not yet launched) ``gr.Blocks`` interface.
    """

    def _mean_reward(rows):
        # Arithmetic mean of the "reward" field of non-empty episode summaries.
        return sum(r["reward"] for r in rows) / len(rows)

    def start_training():
        """Training button callback."""
        try:
            results = trainer.train()
            if not results:
                return "✅ Training Loop finished (no results recorded)", []
            avg_reward = _mean_reward(results)
            final_reward = _mean_reward(results[-10:])
            summary = "\n".join([
                "",
                "✅ Training Complete!",
                f"Total Episodes: {len(results)}",
                f"Average Reward: {avg_reward:.3f}",
                f"Final 10 Avg: {final_reward:.3f}",
                f"Model saved to: {trainer.config.CHECKPOINT_DIR}/final",
                f"Logs saved to: {trainer.config.LOGS_DIR}",
                "",
            ])
            return summary, results[-20:]
        except Exception as e:
            return f"❌ Training failed: {str(e)}", []

    def test_episode():
        """Run a single test episode (does not train)."""
        try:
            episode_num = len(trainer.episode_rewards)
            result = trainer.run_episode(episode_num, n_facts=2)
            meta = result['metadata']
            output = "\n".join([
                "",
                f"Episode: {episode_num}",
                f"Reward: {result['reward']:.3f}",
                f"Extraction Rate: {meta.get('extraction_rate', 0):.3f}",
                f"Facts: {meta.get('facts_extracted', 0)}/{meta.get('total_facts', 0)}",
                f"Crime: {meta.get('crime_description', 'N/A')}",
                "",
            ])
            return output, result
        except Exception as e:
            return f"❌ Test failed: {str(e)}", {}

    def get_stats():
        """Show current training stats."""
        if not trainer.episode_rewards:
            return "No training data yet", []
        recent = trainer.episode_rewards[-20:]
        stats = "\n".join([
            "",
            f"Total Episodes: {len(trainer.episode_rewards)}",
            f"Recent Avg Reward (last 20): {_mean_reward(recent):.3f}",
            "",
        ])
        return stats, recent

    with gr.Blocks(title="Suspect X Training") as interface:
        gr.Markdown("# 🕵️ Suspect X - RL Training Dashboard")
        gr.Markdown(f"Environment: `{trainer.config.HF_ENV_URL}`")

        with gr.Tab("Training"):
            train_btn = gr.Button("🚀 Start Training", variant="primary", size="lg")
            train_output = gr.Textbox(label="Training Status", lines=10)
            train_data = gr.JSON(label="Recent Episodes")
            train_btn.click(start_training, outputs=[train_output, train_data])

        with gr.Tab("Test Episode"):
            test_btn = gr.Button("▶️ Run Test Episode", variant="secondary")
            test_output = gr.Textbox(label="Episode Result", lines=10)
            test_data = gr.JSON(label="Full Result")
            test_btn.click(test_episode, outputs=[test_output, test_data])

        with gr.Tab("Stats"):
            stats_btn = gr.Button("📊 Refresh Stats")
            stats_output = gr.Textbox(label="Training Stats", lines=5)
            stats_data = gr.JSON(label="Recent Episodes")
            stats_btn.click(get_stats, outputs=[stats_output, stats_data])

    return interface
# ============================================================================
# MAIN
# ============================================================================
def main():
    """Build all components and serve the Gradio dashboard on 0.0.0.0:7860.

    Returns early, without loading the model or launching the UI, when the
    environment health check fails.
    """
    config = Config()
    # The HF_ENV_URL environment variable, when set, overrides the default URL.
    env_url = os.getenv("HF_ENV_URL")
    if env_url:
        config.HF_ENV_URL = env_url

    banner = "=" * 60
    print(banner)
    print("SUSPECT X - RL TRAINING")
    print(banner)
    print(f"Environment URL: {config.HF_ENV_URL}")
    print(f"Model: {config.MODEL_NAME}")
    print(f"Episodes: {config.NUM_EPISODES}")
    print(banner)

    env_client = EnvironmentClient(config.HF_ENV_URL)
    if not env_client.test_connection():
        print("\n❌ Cannot connect to environment. Please check HF_ENV_URL")
        return

    logger = ConversationLogger(config.LOGS_DIR)

    model = SuspectXModel(config)
    model.load_model()

    trainer = SuspectXTrainer(model, env_client, logger, config)

    print("\n🚀 Launching Gradio interface...")
    interface = create_gradio_interface(trainer)
    interface.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
    )


if __name__ == "__main__":
    main()