multimodalart HF Staff commited on
Commit
4c7e7bd
·
verified ·
1 Parent(s): e61c197

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +21 -0
app.py CHANGED
@@ -94,6 +94,7 @@ async def set_collider(request: Request):
94
  "onsetmode": bool(body.get("onsetmode", False)),
95
  "reset": int(body.get("reset", 0)),
96
  "seed": int(body.get("seed", 0)),
 
97
  })
98
  return {"ok": True}
99
 
@@ -118,6 +119,12 @@ async def set_audio(request: Request):
118
  return {"ok": True}
119
 
120
 
 
 
 
 
 
 
121
  @app.api(name="stream")
122
  @spaces.GPU(duration=90)
123
  def stream(session_id: str) -> str:
@@ -141,6 +148,7 @@ def stream(session_id: str) -> str:
141
  cur_reset = cur_seed = 0
142
  force_reenc = False
143
  txt_cache, aud_cache = {}, {}
 
144
  while time.time() - t0 < 55.0:
145
  c = read_slot(session_id)
146
  if c is None:
@@ -157,6 +165,19 @@ def stream(session_id: str) -> str:
157
  if seed != cur_seed:
158
  cur_seed = seed
159
  gen = torch.Generator(device=dev).manual_seed(seed)
 
 
 
 
 
 
 
 
 
 
 
 
 
160
  toks = []
161
  for _ in range(10): # per-frame slot read -> steering applies ~instantly
162
  c = read_slot(session_id) or c
 
94
  "onsetmode": bool(body.get("onsetmode", False)),
95
  "reset": int(body.get("reset", 0)),
96
  "seed": int(body.get("seed", 0)),
97
+ "bank_op": body.get("bank_op", prev.get("bank_op")),
98
  })
99
  return {"ok": True}
100
 
 
119
  return {"ok": True}
120
 
121
 
122
+ @app.get("/banks")
123
+ async def banks(session_id: str = ""):
124
+ sid = os.path.basename(session_id)
125
+ return {"bankStatus": [os.path.exists(os.path.join(SESSION_DIR, f"{sid}_bank{i}.pt")) for i in range(3)]}
126
+
127
+
128
  @app.api(name="stream")
129
  @spaces.GPU(duration=90)
130
  def stream(session_id: str) -> str:
 
148
  cur_reset = cur_seed = 0
149
  force_reenc = False
150
  txt_cache, aud_cache = {}, {}
151
+ cur_bank_ver = 0
152
  while time.time() - t0 < 55.0:
153
  c = read_slot(session_id)
154
  if c is None:
 
165
  if seed != cur_seed:
166
  cur_seed = seed
167
  gen = torch.Generator(device=dev).manual_seed(seed)
168
+ bop = c.get("bank_op")
169
+ if bop and int(bop.get("ver", 0)) != cur_bank_ver: # save/recall generation state
170
+ cur_bank_ver = int(bop.get("ver", 0))
171
+ bpath = os.path.join(SESSION_DIR, f"{os.path.basename(session_id)}_bank{int(bop.get('idx', 0))}.pt")
172
+ try:
173
+ if bop.get("action") == "save":
174
+ torch.save({"dstate": dstate, "history": history, "emitted": emitted}, bpath)
175
+ elif bop.get("action") == "load" and os.path.exists(bpath):
176
+ d = torch.load(bpath, map_location=dev)
177
+ dstate, history, emitted = d["dstate"], d["history"].to(dev), int(d["emitted"])
178
+ source, cur_sig = None, None
179
+ except Exception as e:
180
+ print("[bank] error:", repr(e), flush=True)
181
  toks = []
182
  for _ in range(10): # per-frame slot read -> steering applies ~instantly
183
  c = read_slot(session_id) or c