Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -45,7 +45,8 @@ import spaces
|
|
| 45 |
start_time = time.time()
|
| 46 |
|
| 47 |
####################### Setup Model
|
| 48 |
-
from diffusers import AnimateDiffPipeline, MotionAdapter, EulerDiscreteScheduler, LCMScheduler, AutoencoderTiny, UNet2DConditionModel, AutoencoderKL
|
|
|
|
| 49 |
from transformers import CLIPTextModel
|
| 50 |
from huggingface_hub import hf_hub_download
|
| 51 |
from safetensors.torch import load_file
|
|
@@ -55,7 +56,6 @@ import uuid
|
|
| 55 |
import av
|
| 56 |
|
| 57 |
def write_video(file_name, images, fps=17):
|
| 58 |
-
print('Saving')
|
| 59 |
container = av.open(file_name, mode="w")
|
| 60 |
|
| 61 |
stream = container.add_stream("h264", rate=fps)
|
|
@@ -76,7 +76,6 @@ def write_video(file_name, images, fps=17):
|
|
| 76 |
container.mux(packet)
|
| 77 |
# Close the file
|
| 78 |
container.close()
|
| 79 |
-
print('Saved')
|
| 80 |
|
| 81 |
def imio_write_video(file_name, images, fps=15):
|
| 82 |
writer = imageio.get_writer(file_name, fps=fps)
|
|
@@ -128,14 +127,11 @@ pipe.to(device=DEVICE)
|
|
| 128 |
|
| 129 |
@spaces.GPU()
|
| 130 |
def generate_gpu(in_im_embs):
|
| 131 |
-
print('start gen')
|
| 132 |
in_im_embs = in_im_embs.to('cuda').unsqueeze(0).unsqueeze(0)
|
| 133 |
output = pipe(prompt='', guidance_scale=0, added_cond_kwargs={}, ip_adapter_image_embeds=[in_im_embs], num_inference_steps=STEPS)
|
| 134 |
-
print('image is made')
|
| 135 |
im_emb, _ = pipe.encode_image(
|
| 136 |
output.frames[0][len(output.frames[0])//2], 'cuda', 1, output_hidden_state
|
| 137 |
)
|
| 138 |
-
print('im_emb is made')
|
| 139 |
im_emb = im_emb.detach().to('cpu').to(torch.float32)
|
| 140 |
return output, im_emb
|
| 141 |
|
|
@@ -168,7 +164,6 @@ def get_user_emb(embs, ys):
|
|
| 168 |
embs.append(.01*torch.randn(1280))
|
| 169 |
ys.append(0)
|
| 170 |
ys.append(1)
|
| 171 |
-
print('Fixing only one feedback class available.\n')
|
| 172 |
|
| 173 |
indices = list(range(len(embs)))
|
| 174 |
# sample only as many negatives as there are positives
|
|
@@ -177,14 +172,12 @@ def get_user_emb(embs, ys):
|
|
| 177 |
#lower = min(len(pos_indices), len(neg_indices))
|
| 178 |
#neg_indices = random.sample(neg_indices, lower)
|
| 179 |
#pos_indices = random.sample(pos_indices, lower)
|
| 180 |
-
print(len(neg_indices), len(pos_indices))
|
| 181 |
|
| 182 |
|
| 183 |
# we may have just encountered a rare multi-threading diffusers issue (https://github.com/huggingface/diffusers/issues/5749);
|
| 184 |
# this ends up adding a rating but losing an embedding, it seems.
|
| 185 |
# let's take off a rating if so to continue without indexing errors.
|
| 186 |
if len(ys) > len(embs):
|
| 187 |
-
print('ys are longer than embs; popping latest rating')
|
| 188 |
ys.pop(-1)
|
| 189 |
|
| 190 |
feature_embs = np.array(torch.stack([embs[i].squeeze().to('cpu') for i in indices]).to('cpu'))
|
|
@@ -192,12 +185,10 @@ def get_user_emb(embs, ys):
|
|
| 192 |
#feature_embs = scaler.transform(feature_embs)
|
| 193 |
chosen_y = np.array([ys[i] for i in indices])
|
| 194 |
|
| 195 |
-
print('Gathering coefficients')
|
| 196 |
#lin_class = Ridge(fit_intercept=False).fit(feature_embs, chosen_y)
|
| 197 |
lin_class = SVC(max_iter=20, kernel='linear', C=.1, class_weight='balanced').fit(feature_embs, chosen_y)
|
| 198 |
coef_ = torch.tensor(lin_class.coef_, dtype=torch.double).detach().to('cpu')
|
| 199 |
coef_ = coef_ / coef_.abs().max() * 3
|
| 200 |
-
print('Gathered')
|
| 201 |
|
| 202 |
w = 1# if len(embs) % 2 == 0 else 0
|
| 203 |
im_emb = w * coef_.to(dtype=dtype)
|
|
@@ -205,7 +196,6 @@ def get_user_emb(embs, ys):
|
|
| 205 |
|
| 206 |
|
| 207 |
def pluck_img(user_id, user_emb):
|
| 208 |
-
print(user_id, 'user_id')
|
| 209 |
not_rated_rows = prevs_df[[i[1]['user:rating'].get(user_id, 'gone') == 'gone' for i in prevs_df.iterrows()]]
|
| 210 |
while len(not_rated_rows) == 0:
|
| 211 |
not_rated_rows = prevs_df[[i[1]['user:rating'].get(user_id, 'gone') == 'gone' for i in prevs_df.iterrows()]]
|
|
@@ -231,7 +221,6 @@ def background_next_image():
|
|
| 231 |
# not_rated_rows = prevs_df[[i[1]['user:rating'] == {' ': ' '} for i in prevs_df.iterrows()]]
|
| 232 |
rated_rows = prevs_df[[i[1]['user:rating'] != {' ': ' '} for i in prevs_df.iterrows()]]
|
| 233 |
time.sleep(.01)
|
| 234 |
-
print('all users have 4 or less rows rated')
|
| 235 |
|
| 236 |
user_id_list = set(rated_rows['latest_user_to_rate'].to_list())
|
| 237 |
for uid in user_id_list:
|
|
@@ -253,15 +242,12 @@ def background_next_image():
|
|
| 253 |
continue
|
| 254 |
|
| 255 |
if len(rated_rows) < 4:
|
| 256 |
-
print(f'latest user {uid} has < 4 rows') # or > 7 unrated rows')
|
| 257 |
continue
|
| 258 |
|
| 259 |
-
print(uid)
|
| 260 |
embs, ys = pluck_embs_ys(uid)
|
| 261 |
|
| 262 |
user_emb = get_user_emb(embs, ys)
|
| 263 |
img, embs = generate(user_emb)
|
| 264 |
-
print(img)
|
| 265 |
if img:
|
| 266 |
tmp_df = pd.DataFrame(columns=['paths', 'embeddings', 'ips', 'user:rating', 'latest_user_to_rate'])
|
| 267 |
tmp_df['paths'] = [img]
|
|
@@ -276,16 +262,10 @@ def background_next_image():
|
|
| 276 |
cands['sum_bad_ratings'] = [sum([int(t==0) for t in i.values()]) for i in cands['user:rating']]
|
| 277 |
worst_row = cands.loc[cands['sum_bad_ratings']==cands['sum_bad_ratings'].max()].iloc[0]
|
| 278 |
worst_path = worst_row['paths']
|
| 279 |
-
print('Removing worst row:', worst_row, 'from prevs_df of len', len(prevs_df))
|
| 280 |
if os.path.isfile(worst_path):
|
| 281 |
os.remove(worst_path)
|
| 282 |
-
else:
|
| 283 |
-
# If it fails, inform the user.
|
| 284 |
-
print("Error: %s file not found" % worst_path)
|
| 285 |
-
|
| 286 |
# only keep x images & embeddings & ips, then remove the most often disliked besides calibrating
|
| 287 |
prevs_df = prevs_df[prevs_df['paths'] != worst_path]
|
| 288 |
-
print('prevs_df is now length:', len(prevs_df))
|
| 289 |
|
| 290 |
def pluck_embs_ys(user_id):
|
| 291 |
rated_rows = prevs_df[[i[1]['user:rating'].get(user_id, None) != None for i in prevs_df.iterrows()]]
|
|
@@ -298,21 +278,17 @@ def pluck_embs_ys(user_id):
|
|
| 298 |
|
| 299 |
embs = rated_rows['embeddings'].to_list()
|
| 300 |
ys = [i[user_id] for i in rated_rows['user:rating'].to_list()]
|
| 301 |
-
print('embs', 'ys', embs, ys)
|
| 302 |
return embs, ys
|
| 303 |
|
| 304 |
def next_image(calibrate_prompts, user_id):
|
| 305 |
-
print(prevs_df)
|
| 306 |
|
| 307 |
with torch.no_grad():
|
| 308 |
if len(calibrate_prompts) > 0:
|
| 309 |
-
print('######### Calibrating with sample media #########')
|
| 310 |
cal_video = calibrate_prompts.pop(0)
|
| 311 |
image = prevs_df[prevs_df['paths'] == cal_video]['paths'].to_list()[0]
|
| 312 |
|
| 313 |
return image, calibrate_prompts
|
| 314 |
else:
|
| 315 |
-
print('######### Roaming #########')
|
| 316 |
embs, ys = pluck_embs_ys(user_id)
|
| 317 |
user_emb = get_user_emb(embs, ys)
|
| 318 |
image = pluck_img(user_id, user_emb)
|
|
@@ -355,7 +331,6 @@ def choose(img, choice, calibrate_prompts, user_id, request: gr.Request):
|
|
| 355 |
# if we detected NSFW, leave that area of latent space regardless of how they rated chosen.
|
| 356 |
# TODO skip allowing rating & just continue
|
| 357 |
if img == None:
|
| 358 |
-
print('NSFW -- choice is disliked')
|
| 359 |
choice = 0
|
| 360 |
|
| 361 |
row_mask = [p.split('/')[-1] in img for p in prevs_df['paths'].to_list()]
|
|
@@ -425,7 +400,6 @@ with gr.Blocks(css=css, head=js_head) as demo:
|
|
| 425 |
Explore the latent space without text prompts based on your preferences. Learn more on [the write-up](https://rynmurdock.github.io/posts/2024/3/generative_recomenders/).
|
| 426 |
''', elem_id="description")
|
| 427 |
user_id = gr.State()
|
| 428 |
-
print('USER_ID: ',user_id)
|
| 429 |
# calibration videos -- this is a misnomer now :D
|
| 430 |
calibrate_prompts = gr.State([
|
| 431 |
'./first.mp4',
|
|
@@ -487,7 +461,7 @@ log = logging.getLogger('log_here')
|
|
| 487 |
log.setLevel(logging.ERROR)
|
| 488 |
|
| 489 |
scheduler = BackgroundScheduler()
|
| 490 |
-
scheduler.add_job(func=background_next_image, trigger="interval", seconds=.
|
| 491 |
scheduler.start()
|
| 492 |
|
| 493 |
#thread = threading.Thread(target=background_next_image,)
|
|
|
|
| 45 |
start_time = time.time()
|
| 46 |
|
| 47 |
####################### Setup Model
|
| 48 |
+
from diffusers import AnimateDiffPipeline, MotionAdapter, EulerDiscreteScheduler, LCMScheduler, AutoencoderTiny, UNet2DConditionModel, AutoencoderKL, utils
|
| 49 |
+
# Actually invoke the helper: a bare attribute reference is a no-op expression and leaves the progress bars enabled.
utils.logging.disable_progress_bar()
|
| 50 |
from transformers import CLIPTextModel
|
| 51 |
from huggingface_hub import hf_hub_download
|
| 52 |
from safetensors.torch import load_file
|
|
|
|
| 56 |
import av
|
| 57 |
|
| 58 |
def write_video(file_name, images, fps=17):
|
|
|
|
| 59 |
container = av.open(file_name, mode="w")
|
| 60 |
|
| 61 |
stream = container.add_stream("h264", rate=fps)
|
|
|
|
| 76 |
container.mux(packet)
|
| 77 |
# Close the file
|
| 78 |
container.close()
|
|
|
|
| 79 |
|
| 80 |
def imio_write_video(file_name, images, fps=15):
|
| 81 |
writer = imageio.get_writer(file_name, fps=fps)
|
|
|
|
| 127 |
|
| 128 |
@spaces.GPU()
|
| 129 |
def generate_gpu(in_im_embs):
|
|
|
|
| 130 |
in_im_embs = in_im_embs.to('cuda').unsqueeze(0).unsqueeze(0)
|
| 131 |
output = pipe(prompt='', guidance_scale=0, added_cond_kwargs={}, ip_adapter_image_embeds=[in_im_embs], num_inference_steps=STEPS)
|
|
|
|
| 132 |
im_emb, _ = pipe.encode_image(
|
| 133 |
output.frames[0][len(output.frames[0])//2], 'cuda', 1, output_hidden_state
|
| 134 |
)
|
|
|
|
| 135 |
im_emb = im_emb.detach().to('cpu').to(torch.float32)
|
| 136 |
return output, im_emb
|
| 137 |
|
|
|
|
| 164 |
embs.append(.01*torch.randn(1280))
|
| 165 |
ys.append(0)
|
| 166 |
ys.append(1)
|
|
|
|
| 167 |
|
| 168 |
indices = list(range(len(embs)))
|
| 169 |
# sample only as many negatives as there are positives
|
|
|
|
| 172 |
#lower = min(len(pos_indices), len(neg_indices))
|
| 173 |
#neg_indices = random.sample(neg_indices, lower)
|
| 174 |
#pos_indices = random.sample(pos_indices, lower)
|
|
|
|
| 175 |
|
| 176 |
|
| 177 |
# we may have just encountered a rare multi-threading diffusers issue (https://github.com/huggingface/diffusers/issues/5749);
|
| 178 |
# this ends up adding a rating but losing an embedding, it seems.
|
| 179 |
# let's take off a rating if so to continue without indexing errors.
|
| 180 |
if len(ys) > len(embs):
|
|
|
|
| 181 |
ys.pop(-1)
|
| 182 |
|
| 183 |
feature_embs = np.array(torch.stack([embs[i].squeeze().to('cpu') for i in indices]).to('cpu'))
|
|
|
|
| 185 |
#feature_embs = scaler.transform(feature_embs)
|
| 186 |
chosen_y = np.array([ys[i] for i in indices])
|
| 187 |
|
|
|
|
| 188 |
#lin_class = Ridge(fit_intercept=False).fit(feature_embs, chosen_y)
|
| 189 |
lin_class = SVC(max_iter=20, kernel='linear', C=.1, class_weight='balanced').fit(feature_embs, chosen_y)
|
| 190 |
coef_ = torch.tensor(lin_class.coef_, dtype=torch.double).detach().to('cpu')
|
| 191 |
coef_ = coef_ / coef_.abs().max() * 3
|
|
|
|
| 192 |
|
| 193 |
w = 1# if len(embs) % 2 == 0 else 0
|
| 194 |
im_emb = w * coef_.to(dtype=dtype)
|
|
|
|
| 196 |
|
| 197 |
|
| 198 |
def pluck_img(user_id, user_emb):
|
|
|
|
| 199 |
not_rated_rows = prevs_df[[i[1]['user:rating'].get(user_id, 'gone') == 'gone' for i in prevs_df.iterrows()]]
|
| 200 |
while len(not_rated_rows) == 0:
|
| 201 |
not_rated_rows = prevs_df[[i[1]['user:rating'].get(user_id, 'gone') == 'gone' for i in prevs_df.iterrows()]]
|
|
|
|
| 221 |
# not_rated_rows = prevs_df[[i[1]['user:rating'] == {' ': ' '} for i in prevs_df.iterrows()]]
|
| 222 |
rated_rows = prevs_df[[i[1]['user:rating'] != {' ': ' '} for i in prevs_df.iterrows()]]
|
| 223 |
time.sleep(.01)
|
|
|
|
| 224 |
|
| 225 |
user_id_list = set(rated_rows['latest_user_to_rate'].to_list())
|
| 226 |
for uid in user_id_list:
|
|
|
|
| 242 |
continue
|
| 243 |
|
| 244 |
if len(rated_rows) < 4:
|
|
|
|
| 245 |
continue
|
| 246 |
|
|
|
|
| 247 |
embs, ys = pluck_embs_ys(uid)
|
| 248 |
|
| 249 |
user_emb = get_user_emb(embs, ys)
|
| 250 |
img, embs = generate(user_emb)
|
|
|
|
| 251 |
if img:
|
| 252 |
tmp_df = pd.DataFrame(columns=['paths', 'embeddings', 'ips', 'user:rating', 'latest_user_to_rate'])
|
| 253 |
tmp_df['paths'] = [img]
|
|
|
|
| 262 |
cands['sum_bad_ratings'] = [sum([int(t==0) for t in i.values()]) for i in cands['user:rating']]
|
| 263 |
worst_row = cands.loc[cands['sum_bad_ratings']==cands['sum_bad_ratings'].max()].iloc[0]
|
| 264 |
worst_path = worst_row['paths']
|
|
|
|
| 265 |
if os.path.isfile(worst_path):
|
| 266 |
os.remove(worst_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 267 |
# only keep x images & embeddings & ips, then remove the most often disliked besides calibrating
|
| 268 |
prevs_df = prevs_df[prevs_df['paths'] != worst_path]
|
|
|
|
| 269 |
|
| 270 |
def pluck_embs_ys(user_id):
|
| 271 |
rated_rows = prevs_df[[i[1]['user:rating'].get(user_id, None) != None for i in prevs_df.iterrows()]]
|
|
|
|
| 278 |
|
| 279 |
embs = rated_rows['embeddings'].to_list()
|
| 280 |
ys = [i[user_id] for i in rated_rows['user:rating'].to_list()]
|
|
|
|
| 281 |
return embs, ys
|
| 282 |
|
| 283 |
def next_image(calibrate_prompts, user_id):
|
|
|
|
| 284 |
|
| 285 |
with torch.no_grad():
|
| 286 |
if len(calibrate_prompts) > 0:
|
|
|
|
| 287 |
cal_video = calibrate_prompts.pop(0)
|
| 288 |
image = prevs_df[prevs_df['paths'] == cal_video]['paths'].to_list()[0]
|
| 289 |
|
| 290 |
return image, calibrate_prompts
|
| 291 |
else:
|
|
|
|
| 292 |
embs, ys = pluck_embs_ys(user_id)
|
| 293 |
user_emb = get_user_emb(embs, ys)
|
| 294 |
image = pluck_img(user_id, user_emb)
|
|
|
|
| 331 |
# if we detected NSFW, leave that area of latent space regardless of how they rated chosen.
|
| 332 |
# TODO skip allowing rating & just continue
|
| 333 |
if img == None:
|
|
|
|
| 334 |
choice = 0
|
| 335 |
|
| 336 |
row_mask = [p.split('/')[-1] in img for p in prevs_df['paths'].to_list()]
|
|
|
|
| 400 |
Explore the latent space without text prompts based on your preferences. Learn more on [the write-up](https://rynmurdock.github.io/posts/2024/3/generative_recomenders/).
|
| 401 |
''', elem_id="description")
|
| 402 |
user_id = gr.State()
|
|
|
|
| 403 |
# calibration videos -- this is a misnomer now :D
|
| 404 |
calibrate_prompts = gr.State([
|
| 405 |
'./first.mp4',
|
|
|
|
| 461 |
log.setLevel(logging.ERROR)
|
| 462 |
|
| 463 |
scheduler = BackgroundScheduler()
|
| 464 |
+
scheduler.add_job(func=background_next_image, trigger="interval", seconds=.3)
|
| 465 |
scheduler.start()
|
| 466 |
|
| 467 |
#thread = threading.Thread(target=background_next_image,)
|