# TheSafetyGame / app.py
import spaces
import time
import torch
import gradio as gr
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
AutoModelForSequenceClassification,
)
# ============================================================================
# Environment Setup
# ============================================================================
print("\n=== Environment Setup ===")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available():
print(f"Using GPU: {torch.cuda.get_device_name(device)}")
else:
print("Using CPU")
# ============================================================================
# Model Configuration
# ============================================================================
CHAT_MODEL_NAME = "sapienzanlp/Minerva-7B-instruct-v1.0"
CLASSIFIER_MODEL_NAME = "saiteki-kai/QA-DeBERTa-v3-large-binary-3"
# Generation and classification parameters
MAX_NEW_TOKENS = 256
REPETITION_PENALTY = 1.1
MAX_INPUT_LENGTH = 512
MAX_CLASSIFIER_LENGTH = 512
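# MAX_NEW_TOKENS caps the length of each generated reply; REPETITION_PENALTY
# (>1.0) mildly discourages verbatim loops. MAX_INPUT_LENGTH truncates
# over-long prompts, and MAX_CLASSIFIER_LENGTH truncates prompt+response
# pairs at classification time.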
# ============================================================================
# Model Loading
# ============================================================================
print("\n=== Loading Models ===")
# Chat model setup
print(f"Loading chat model: {CHAT_MODEL_NAME}")
chat_tokenizer = AutoTokenizer.from_pretrained(CHAT_MODEL_NAME, padding_side="left")
if chat_tokenizer.pad_token is None:
chat_tokenizer.pad_token = chat_tokenizer.eos_token
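# Left padding keeps each prompt's final token adjacent to its generated
# continuation in a batch, so slicing off the input length after generate()
# yields exactly the new tokens. Reusing EOS as the pad token is a common
# fallback for causal LMs that ship without one.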
chat_model = AutoModelForCausalLM.from_pretrained(CHAT_MODEL_NAME, torch_dtype=torch.bfloat16)
chat_model.to(device) # type: ignore
chat_model.eval()
print("✓ Chat model loaded")
# Classifier setup
print(f"Loading classifier: {CLASSIFIER_MODEL_NAME}")
cls_tokenizer = AutoTokenizer.from_pretrained(CLASSIFIER_MODEL_NAME)
cls_model = AutoModelForSequenceClassification.from_pretrained(CLASSIFIER_MODEL_NAME, torch_dtype=torch.bfloat16)
cls_model.to(device)
cls_model.eval()
# Get the index of the "unsafe" label from the classifier config
UNSAFE_IDX = cls_model.config.label2id["unsafe"]
print("✓ Classifier loaded")
# ============================================================================
# Generation Function
# ============================================================================
@spaces.GPU(duration=90)
def generate(submission: list[dict[str, str]], team_id: str) -> list[dict[str, str | float]]:
"""
Generate responses for prompts and classify their safety.
Args:
submission: List of dicts with 'id' and 'prompt' keys
team_id: Team identifier
Returns:
List of dicts with id, prompt, response, score, model, and team_id
"""
print(f"\n=== Processing batch of {len(submission)} prompts ===")
# Extract data from submission
ids = [s["id"] for s in submission]
prompts = [s["prompt"] for s in submission]
# ------------------------------------------------------------------------
# Step 1: Generate Responses
# ------------------------------------------------------------------------
start_time = time.perf_counter()
# Format prompts using chat template
messages_list = [[{"role": "user", "content": prompt}] for prompt in prompts]
formatted_prompts = [
chat_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
for messages in messages_list
]
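    # apply_chat_template wraps each user turn in the model's own chat markup
    # (defined by the tokenizer's chat_template). As an illustrative sketch
    # only, not Minerva's actual template, the result resembles:
    #   <s><|user|> {prompt} <|assistant|>
    # add_generation_prompt=True appends the assistant header so the model
    # continues with a reply rather than another user turn.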
# Tokenize all prompts in batch
inputs = chat_tokenizer(
formatted_prompts,
return_tensors="pt",
padding=True,
truncation=True,
max_length=MAX_INPUT_LENGTH,
).to(device)
# Generate responses
with torch.no_grad():
output_ids = chat_model.generate(
**inputs,
max_new_tokens=MAX_NEW_TOKENS,
do_sample=False,
repetition_penalty=REPETITION_PENALTY,
pad_token_id=chat_tokenizer.pad_token_id,
eos_token_id=chat_tokenizer.eos_token_id,
)
# Decode only newly generated tokens (exclude input)
generated_ids = output_ids[:, inputs.input_ids.shape[1] :]
responses = chat_tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
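    # Sequences that finish early are padded out with pad/eos tokens during
    # batched generation; skip_special_tokens=True strips those markers from
    # the decoded text.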
generation_time = time.perf_counter() - start_time
print(f"✓ Generation completed: {generation_time:.3f}s ({len(prompts) / generation_time:.1f} prompts/s)")
# ------------------------------------------------------------------------
# Step 2: Classify Safety
# ------------------------------------------------------------------------
start_time = time.perf_counter()
# Tokenize prompt-response pairs
cls_inputs = cls_tokenizer(
prompts,
responses,
return_tensors="pt",
padding=True,
truncation=True,
max_length=MAX_CLASSIFIER_LENGTH,
).to(device)
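    # Passing (prompts, responses) as sentence pairs makes the tokenizer build
    # "[CLS] prompt [SEP] response [SEP]"-style inputs, so the classifier
    # judges each response in the context of its prompt.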
# Run classifier
with torch.no_grad():
cls_outputs = cls_model(**cls_inputs)
probs = torch.nn.functional.softmax(cls_outputs.logits, dim=-1)
unsafe_scores = probs[:, UNSAFE_IDX].cpu().tolist()
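    # unsafe_scores[i] is the softmax probability assigned to the "unsafe"
    # class for the i-th prompt-response pair: a float in [0, 1], where
    # higher means less safe.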
classification_time = time.perf_counter() - start_time
print(
f"✓ Classification completed: {classification_time:.3f}s ({len(prompts) / classification_time:.1f} prompts/s)"
)
# ------------------------------------------------------------------------
# Step 3: Format Output
# ------------------------------------------------------------------------
results = [
{
"id": id_,
"prompt": prompt,
"response": response,
"score": score,
"model": CHAT_MODEL_NAME,
"team_id": team_id,
}
for id_, prompt, response, score in zip(ids, prompts, responses, unsafe_scores)
]
total_time = generation_time + classification_time
print(f"✓ Total processing time: {total_time:.3f}s")
print(f"✓ Average time per prompt: {total_time / len(prompts):.3f}s")
return results
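# Hypothetical input/output shapes for generate() (values illustrative only):
#   submission = [{"id": "q1", "prompt": "How do I make tea?"}]
#   generate(submission, team_id="team-01")
#   -> [{"id": "q1", "prompt": "How do I make tea?", "response": "...",
#        "score": 0.03, "model": CHAT_MODEL_NAME, "team_id": "team-01"}]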
# ============================================================================
# Gradio Interface
# ============================================================================
print("\n=== Setting up Gradio Interface ===")
with gr.Blocks() as demo:
gr.api(generate, api_name="scores", concurrency_limit=None, batch=False)
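# gr.api exposes generate() as a named API endpoint ("/scores") without a UI.
# A sketch of a client-side call via gradio_client, assuming the Space id
# below (illustrative, not confirmed by this file):
#
#   from gradio_client import Client
#   client = Client("saiteki-kai/TheSafetyGame")
#   results = client.predict(
#       [{"id": "q1", "prompt": "Hello"}],
#       "team-01",
#       api_name="/scores",
#   )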
# ============================================================================
# Launch
# ============================================================================
if __name__ == "__main__":
    print("\n=== Launching Application ===")
    demo.queue(default_concurrency_limit=None, api_open=True)
    # launch() blocks until the server shuts down, so anything printed after
    # it would only appear on exit.
    demo.launch()