saiteki-kai commited on
Commit
723e9ef
·
verified ·
1 Parent(s): cdc66db

feat: setup space

Browse files
Files changed (5) hide show
  1. .python-version +1 -0
  2. app.py +197 -0
  3. pyproject.toml +18 -0
  4. requirements.txt +6 -0
  5. uv.lock +0 -0
.python-version ADDED
@@ -0,0 +1 @@
 
 
1
+ 3.13
app.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import spaces
2
+ import time
3
+ import torch
4
+ import gradio as gr
5
+ from transformers import (
6
+ AutoModelForCausalLM,
7
+ AutoTokenizer,
8
+ AutoModelForSequenceClassification,
9
+ )
10
+
11
+ # ============================================================================
12
+ # Environment Setup
13
+ # ============================================================================
14
+
15
+ print("\n=== Environment Setup ===")
16
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
+ if torch.cuda.is_available():
18
+ print(f"Using GPU: {torch.cuda.get_device_name(device)}")
19
+ else:
20
+ print("Using CPU")
21
+
22
+ # ============================================================================
23
+ # Model Configuration
24
+ # ============================================================================
25
+
26
+ CHAT_MODEL_NAME = "sapienzanlp/Minerva-7B-instruct-v1.0"
27
+ CLASSIFIER_MODEL_NAME = "saiteki-kai/QA-DeBERTa-v3-large-binary-3"
28
+
29
+ # Generation parameters
30
+ MAX_NEW_TOKENS = 256
31
+ REPETITION_PENALTY = 1.1
32
+ MAX_INPUT_LENGTH = 512
33
+ MAX_CLASSIFIER_LENGTH = 512
34
+
35
+ # ============================================================================
36
+ # Model Loading
37
+ # ============================================================================
38
+
39
+ print("\n=== Loading Models ===")
40
+
41
+ # Chat model setup
42
+ print(f"Loading chat model: {CHAT_MODEL_NAME}")
43
+
44
+ chat_tokenizer = AutoTokenizer.from_pretrained(CHAT_MODEL_NAME, padding_side="left")
45
+
46
+ if chat_tokenizer.pad_token is None:
47
+ chat_tokenizer.pad_token = chat_tokenizer.eos_token
48
+
49
+ chat_model = AutoModelForCausalLM.from_pretrained(CHAT_MODEL_NAME, torch_dtype=torch.bfloat16)
50
+
51
+ chat_model.to(device) # type: ignore
52
+ chat_model.eval()
53
+
54
+ print("✓ Chat model loaded")
55
+
56
+ # Classifier setup
57
+ print(f"Loading classifier: {CLASSIFIER_MODEL_NAME}")
58
+
59
+ cls_tokenizer = AutoTokenizer.from_pretrained(CLASSIFIER_MODEL_NAME)
60
+ cls_model = AutoModelForSequenceClassification.from_pretrained(CLASSIFIER_MODEL_NAME, torch_dtype=torch.bfloat16)
61
+
62
+ cls_model.to(device)
63
+ cls_model.eval()
64
+
65
+ # Get the index for "unsafe" label
66
+ UNSAFE_IDX = cls_model.config.label2id["unsafe"]
67
+
68
+ print("✓ Classifier loaded")
69
+
70
+ # ============================================================================
71
+ # Generation Function
72
+ # ============================================================================
73
+
74
+
75
+ @spaces.GPU(duration=90)
76
+ def generate(submission: list[dict[str, str]], team_id: str) -> list[dict[str, str | float]]:
77
+ """
78
+ Generate responses for prompts and classify their safety.
79
+
80
+ Args:
81
+ submission: List of dicts with 'id' and 'prompt' keys
82
+ team_id: Team identifier
83
+
84
+ Returns:
85
+ List of dicts with id, prompt, response, score, model, and team_id
86
+ """
87
+ print(f"\n=== Processing batch of {len(submission)} prompts ===")
88
+
89
+ # Extract data from submission
90
+ ids = [s["id"] for s in submission]
91
+ prompts = [s["prompt"] for s in submission]
92
+
93
+ # ------------------------------------------------------------------------
94
+ # Step 1: Generate Responses
95
+ # ------------------------------------------------------------------------
96
+ start_time = time.perf_counter()
97
+
98
+ # Format prompts using chat template
99
+ messages_list = [[{"role": "user", "content": prompt}] for prompt in prompts]
100
+ formatted_prompts = [
101
+ chat_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
102
+ for messages in messages_list
103
+ ]
104
+
105
+ # Tokenize all prompts in batch
106
+ inputs = chat_tokenizer(
107
+ formatted_prompts,
108
+ return_tensors="pt",
109
+ padding=True,
110
+ truncation=True,
111
+ max_length=MAX_INPUT_LENGTH,
112
+ ).to(device)
113
+
114
+ # Generate responses
115
+ with torch.no_grad():
116
+ output_ids = chat_model.generate(
117
+ **inputs,
118
+ max_new_tokens=MAX_NEW_TOKENS,
119
+ do_sample=False,
120
+ repetition_penalty=REPETITION_PENALTY,
121
+ pad_token_id=chat_tokenizer.pad_token_id,
122
+ eos_token_id=chat_tokenizer.eos_token_id,
123
+ )
124
+
125
+ # Decode only newly generated tokens (exclude input)
126
+ generated_ids = output_ids[:, inputs.input_ids.shape[1] :]
127
+ responses = chat_tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
128
+
129
+ generation_time = time.perf_counter() - start_time
130
+ print(f"✓ Generation completed: {generation_time:.3f}s ({len(prompts) / generation_time:.1f} prompts/s)")
131
+
132
+ # ------------------------------------------------------------------------
133
+ # Step 2: Classify Safety
134
+ # ------------------------------------------------------------------------
135
+ start_time = time.perf_counter()
136
+
137
+ # Tokenize prompt-response pairs
138
+ cls_inputs = cls_tokenizer(
139
+ prompts,
140
+ responses,
141
+ return_tensors="pt",
142
+ padding=True,
143
+ truncation=True,
144
+ max_length=MAX_CLASSIFIER_LENGTH,
145
+ ).to(device)
146
+
147
+ # Run classifier
148
+ with torch.no_grad():
149
+ cls_outputs = cls_model(**cls_inputs)
150
+ probs = torch.nn.functional.softmax(cls_outputs.logits, dim=-1)
151
+ unsafe_scores = probs[:, UNSAFE_IDX].cpu().tolist()
152
+
153
+ classification_time = time.perf_counter() - start_time
154
+ print(
155
+ f"✓ Classification completed: {classification_time:.3f}s ({len(prompts) / classification_time:.1f} prompts/s)"
156
+ )
157
+
158
+ # ------------------------------------------------------------------------
159
+ # Step 3: Format Output
160
+ # ------------------------------------------------------------------------
161
+ results = [
162
+ {
163
+ "id": id_,
164
+ "prompt": prompt,
165
+ "response": response,
166
+ "score": score,
167
+ "model": CHAT_MODEL_NAME,
168
+ "team_id": team_id,
169
+ }
170
+ for id_, prompt, response, score in zip(ids, prompts, responses, unsafe_scores)
171
+ ]
172
+
173
+ total_time = generation_time + classification_time
174
+ print(f"✓ Total processing time: {total_time:.3f}s")
175
+ print(f"✓ Average time per prompt: {total_time / len(prompts):.3f}s")
176
+
177
+ return results
178
+
179
+
180
+ # ============================================================================
181
+ # Gradio Interface
182
+ # ============================================================================
183
+
184
+ print("\n=== Setting up Gradio Interface ===")
185
+
186
+ with gr.Blocks() as demo:
187
+ gr.api(generate, api_name="scores", concurrency_limit=None, batch=False)
188
+
189
+ # ============================================================================
190
+ # Launch
191
+ # ============================================================================
192
+
193
+ if __name__ == "__main__":
194
+ print("\n=== Launching Application ===")
195
+ demo.queue(default_concurrency_limit=None, api_open=True)
196
+ demo.launch()
197
+ print("✓ Application running")
pyproject.toml ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "thesafetygame-zerogpu"
3
+ version = "0.1.0"
4
+ description = "Add your description here"
5
+ readme = "README.md"
6
+ requires-python = ">=3.13"
7
+ dependencies = [
8
+ "accelerate>=1.12.0",
9
+ "bitsandbytes>=0.48.2",
10
+ "gradio==6.0.2",
11
+ "optimum>=2.0.0",
12
+ "spaces>=0.44.0",
13
+ "torch>=2.9.1",
14
+ "transformers>=4.57.3",
15
+ ]
16
+
17
+ [tool.ruff]
18
+ line-length = 120
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ spaces
2
+ torch
3
+ transformers
4
+ optimum
5
+ accelerate
6
+ bitsandbytes
uv.lock ADDED
The diff for this file is too large to render. See raw diff