CarolineM5 commited on
Commit
73160c6
·
verified ·
1 Parent(s): 8b336ad

Upload inference.py

Browse files
Files changed (1) hide show
  1. inference.py +12 -6
inference.py CHANGED
@@ -43,14 +43,20 @@ class UNetNoCondWrapper(nn.Module):
43
  def inference(model_id,device, img1, img2):
44
 
45
 
46
- vae = AutoencoderKL.from_pretrained(f"{model_id}/vae").to(device)
47
- scheduler = DDPMScheduler.from_pretrained(f"{model_id}/scheduler")
48
- tokenizer = CLIPTokenizer.from_pretrained(f"{model_id}/tokenizer")
49
- text_encoder = CLIPTextModel.from_pretrained(f"{model_id}/text_encoder").to(device)
50
- feature_extractor = CLIPImageProcessor.from_pretrained(f"{model_id}/feature_extractor")
 
 
 
 
 
 
51
 
52
  # 2) Chargez votre UNet non‑conditionné et wrappez‑le
53
- base_unet = UNet2DModel.from_pretrained(f"{model_id}/unet").to(device)
54
  wrapped_unet = UNetNoCondWrapper(base_unet).to(device)
55
 
56
  # 3) Construisez la pipeline manuellement
 
43
  def inference(model_id,device, img1, img2):
44
 
45
 
46
+ # vae = AutoencoderKL.from_pretrained(f"{model_id}/vae").to(device)
47
+ # scheduler = DDPMScheduler.from_pretrained(f"{model_id}/scheduler")
48
+ # tokenizer = CLIPTokenizer.from_pretrained(f"{model_id}/tokenizer")
49
+ # text_encoder = CLIPTextModel.from_pretrained(f"{model_id}/text_encoder").to(device)
50
+ # feature_extractor = CLIPImageProcessor.from_pretrained(f"{model_id}/feature_extractor")
51
+
52
+ vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae").to(device)
53
+ scheduler = DDPMScheduler.from_pretrained(model_id, subfolder="scheduler")
54
+ tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
55
+ text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder").to(device)
56
+ feature_extractor = CLIPImageProcessor.from_pretrained(model_id, subfolder="feature_extractor")
57
 
58
  # 2) Chargez votre UNet non‑conditionné et wrappez‑le
59
+ base_unet = UNet2DModel.from_pretrained(model_id, subfolder="unet").to(device)
60
  wrapped_unet = UNetNoCondWrapper(base_unet).to(device)
61
 
62
  # 3) Construisez la pipeline manuellement