ryandt committed on
Commit
248ed0c
·
1 Parent(s): 50948cd

Update to the zero gpu config

Browse files
Files changed (1) hide show
  1. app.py +21 -10
app.py CHANGED
@@ -88,11 +88,20 @@ def _run_beam_search_threaded(
88
  target_emb, encoder_name, prompt,
89
  beam_width, top_k, patience, max_steps, min_similarity, randomness,
90
  progress_queue,
 
91
  ):
92
- """Run beam search on GPU, pushing step updates to a queue."""
 
 
 
 
 
93
  llm, tokenizer = load_llm()
94
  encoder = load_encoder(encoder_name)
95
 
 
 
 
96
  step_count = 0
97
 
98
  def on_step(step, cand):
@@ -114,7 +123,7 @@ def _run_beam_search_threaded(
114
  )
115
  elapsed = time.time() - t0
116
  progress_queue.put(_SENTINEL)
117
- return result, elapsed, step_count
118
 
119
 
120
  def run_stage(
@@ -135,11 +144,6 @@ def run_stage(
135
 
136
  stage_num = len(stage_results_state) + 1
137
 
138
- # Encode target on first stage
139
- if stage_num == 1:
140
- encoder = load_encoder(encoder_name)
141
- target_emb_state = encode_text(text.strip(), encoder)
142
-
143
  # Build prompt
144
  if stage_num == 1:
145
  prompt = _STAGE1_PROMPT
@@ -147,21 +151,26 @@ def run_stage(
147
  prev_text = stage_results_state[-1]["text"]
148
  prompt = _STAGE2_PROMPT_TEMPLATE.format(seed=prev_text)
149
 
 
 
 
150
  # Run beam search in a thread so we can yield progress
151
  progress_q = queue.Queue()
152
 
153
  # Container for the thread's return value
154
- result_holder = [None, 0.0, 0]
155
 
156
  def _worker():
157
- r, elapsed, steps = _run_beam_search_threaded(
158
  target_emb_state, encoder_name, prompt,
159
  beam_width, top_k, patience, max_steps, min_similarity, randomness,
160
  progress_q,
 
161
  )
162
  result_holder[0] = r
163
  result_holder[1] = elapsed
164
  result_holder[2] = steps
 
165
 
166
  worker = threading.Thread(target=_worker)
167
  worker.start()
@@ -190,7 +199,9 @@ def run_stage(
190
 
191
  worker.join()
192
 
193
- result, elapsed, steps = result_holder
 
 
194
  stage_results_state = stage_results_state + [{
195
  "stage": stage_num,
196
  "text": result.seq_str,
 
88
  target_emb, encoder_name, prompt,
89
  beam_width, top_k, patience, max_steps, min_similarity, randomness,
90
  progress_queue,
91
+ encode_text_input=None,
92
  ):
93
+ """Run beam search on GPU, pushing step updates to a queue.
94
+
95
+ If encode_text_input is provided and target_emb is None, encodes
96
+ the text to produce the target embedding (Stage 1). This keeps
97
+ all CUDA operations inside the @spaces.GPU context.
98
+ """
99
  llm, tokenizer = load_llm()
100
  encoder = load_encoder(encoder_name)
101
 
102
+ if target_emb is None and encode_text_input is not None:
103
+ target_emb = encode_text(encode_text_input, encoder)
104
+
105
  step_count = 0
106
 
107
  def on_step(step, cand):
 
123
  )
124
  elapsed = time.time() - t0
125
  progress_queue.put(_SENTINEL)
126
+ return result, elapsed, step_count, target_emb
127
 
128
 
129
  def run_stage(
 
144
 
145
  stage_num = len(stage_results_state) + 1
146
 
 
 
 
 
 
147
  # Build prompt
148
  if stage_num == 1:
149
  prompt = _STAGE1_PROMPT
 
151
  prev_text = stage_results_state[-1]["text"]
152
  prompt = _STAGE2_PROMPT_TEMPLATE.format(seed=prev_text)
153
 
154
+ # On Stage 1, pass raw text so encoding happens inside GPU context
155
+ encode_input = text.strip() if stage_num == 1 else None
156
+
157
  # Run beam search in a thread so we can yield progress
158
  progress_q = queue.Queue()
159
 
160
  # Container for the thread's return value
161
+ result_holder = [None, 0.0, 0, None]
162
 
163
  def _worker():
164
+ r, elapsed, steps, emb = _run_beam_search_threaded(
165
  target_emb_state, encoder_name, prompt,
166
  beam_width, top_k, patience, max_steps, min_similarity, randomness,
167
  progress_q,
168
+ encode_text_input=encode_input,
169
  )
170
  result_holder[0] = r
171
  result_holder[1] = elapsed
172
  result_holder[2] = steps
173
+ result_holder[3] = emb
174
 
175
  worker = threading.Thread(target=_worker)
176
  worker.start()
 
199
 
200
  worker.join()
201
 
202
+ result, elapsed, steps, returned_emb = result_holder
203
+ if returned_emb is not None:
204
+ target_emb_state = returned_emb
205
  stage_results_state = stage_results_state + [{
206
  "stage": stage_num,
207
  "text": result.seq_str,