CarolineM5 commited on
Commit
76bee53
·
verified ·
1 Parent(s): c4a50be

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +14 -66
  2. inference.py +105 -0
app.py CHANGED
@@ -9,67 +9,12 @@ import gradio as gr
9
  from PIL import Image
10
  import torch
11
  import torchvision.transforms as T
 
12
 
13
- # --- 1) IMPORTER VOTRE CODE D'INFERENCE ---
14
- # Par exemple, si vous avez un fichier inference.py qui définit une fonction `infer_wood(fibers_map, rings_map)`
15
- # vous pouvez faire :
16
- #
17
- # from inference import infer_wood
18
- #
19
- # Et vous assurez que `infer_wood` prend en entrée deux objets PIL.Image
20
- # (cartographie fibres et cartographie cernes) et renvoie une PIL.Image résultat.
21
- #
22
- # Si vous n'avez pas encore ce fichier, créez une fonction de type :
23
 
24
- # def infer_wood(fibers_img: Image.Image, rings_img: Image.Image) -> Image.Image:
25
- # """
26
- # Exemple de squelette de fonction d'inférence.
27
- # -> Remplacez tout ce qui est à l'intérieur par votre propre pipeline (prétraitement, appel du modèle, post-traitement).
28
- # """
29
- # # --- Pré-traitement (adapté à votre modèle) ---
30
- # # Par exemple :
31
- # preprocess = T.Compose([
32
- # T.Resize((256, 256)),
33
- # T.ToTensor(),
34
- # # T.Normalize(mean=[...], std=[...]) # si votre modèle a été entraîné avec normalisation
35
- # ])
36
- # x1 = preprocess(fibers_img).unsqueeze(0).to(torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
37
- # x2 = preprocess(rings_img).unsqueeze(0).to(torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
38
-
39
- # # --- Chargement / usage du modèle (ici, c'est un exemple générique) ---
40
- # # Imaginons que vous aviez déjà chargé votre modèle quelque part globalement :
41
- # # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
42
- # # model = VotreModeleWood().to(device)
43
- # # model.load_state_dict(torch.load('chemin/vers/votre_modele.pth', map_location=device))
44
- # # model.eval()
45
- # #
46
- # # Ici, on concatène les deux cartes pour former l’entrée (adaptez selon votre archi).
47
-
48
- # input_tensor = torch.cat([x1, x2], dim=1) # par exemple (1, C1+C2, H, W)
49
-
50
- # with torch.no_grad():
51
- # pred_tensor = model(input_tensor) # suppose que `model` est déjà défini globalement et chargé
52
-
53
- # # --- Post-traitement pour revenir à PIL.Image ---
54
- # postprocess = T.ToPILImage()
55
- # output_img = postprocess(pred_tensor.squeeze(0).cpu().clamp(0, 1))
56
- # return output_img
57
-
58
- def infer_wood(fibers_img: Image.Image, rings_img: Image.Image):
59
- return rings_img
60
-
61
-
62
- # --- 2) CHARGEMENT GLOBAL DU MODÈLE (optionnel) ---
63
- # Vous pouvez charger votre modèle une seule fois, ici, en dehors de la fonction infer_wood,
64
- # afin que Gradio ne fasse pas recharger à chaque appel. Par exemple :
65
 
66
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
67
 
68
- # Exemple :
69
- # from models.votre_modele import VotreModeleWood
70
- # model = VotreModeleWood().to(device)
71
- # model.load_state_dict(torch.load('models/chemin_du_modele.pth', map_location=device))
72
- # model.eval()
73
 
74
 
75
  # --- 3) FONCTION GRADIO D’INTERFACE ---
@@ -81,8 +26,11 @@ def gradio_generate(fibers_map: Image.Image, rings_map: Image.Image) -> Image.Im
81
  # Vérifier que les deux images sont bien en mode RGB (ou adapter si besoin)
82
  fibers_map = fibers_map.convert("RGB")
83
  rings_map = rings_map.convert("RGB")
84
- # Appel de votre code d'inférence
85
- result_img = infer_wood(fibers_map, rings_map)
 
 
 
86
  return result_img
87
 
88
 
@@ -90,17 +38,17 @@ def gradio_generate(fibers_map: Image.Image, rings_map: Image.Image) -> Image.Im
90
  iface = gr.Interface(
91
  fn=gradio_generate,
92
  inputs=[
93
- gr.Image(type="pil", label="Cartographie d’orientation des fibres"),
94
- gr.Image(type="pil", label="Cartographie des limites de cernes")
95
  ],
96
- outputs=gr.Image(type="pil", label="Bois photoréalistique généré"),
97
- title="Générateur de bois photoréalistique",
98
  description="""
99
- Téléversez :
100
- 1) une image de cartographie d’orientation des fibres,
101
- 2) une image de cartographie des limites de cernes.
102
 
103
- Le modèle renverra un rendu photoréalistique de bois que vous pouvez télécharger.
104
  """
105
  )
106
 
 
9
  from PIL import Image
10
  import torch
11
  import torchvision.transforms as T
12
+ from inference import inference
13
 
 
 
 
 
 
 
 
 
 
 
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
17
 
 
 
 
 
 
18
 
19
 
20
  # --- 3) FONCTION GRADIO D’INTERFACE ---
 
26
  # Vérifier que les deux images sont bien en mode RGB (ou adapter si besoin)
27
  fibers_map = fibers_map.convert("RGB")
28
  rings_map = rings_map.convert("RGB")
29
+
30
+ model_id = "CarolineM5/InstructPix2Pix_WithoutPrompt/model_LR"
31
+
32
+ result_img = inference(model_id, device, rings_map, fibers_map)
33
+
34
  return result_img
35
 
36
 
 
38
  iface = gr.Interface(
39
  fn=gradio_generate,
40
  inputs=[
41
+ gr.Image(type="pil", label="Fibre orientation map"),
42
+ gr.Image(type="pil", label="Growth ring map")
43
  ],
44
+ outputs=gr.Image(type="pil", label="Photorealistic wood generated"),
45
+ title="Photorealistic wood generator",
46
  description="""
47
+ Upload :
48
+ 1) a fibre orientation mapping image,
49
+ 2) a tree-ring boundary mapping image.
50
 
51
+ The model will return a photo-realistic rendering of the wood that you can download.
52
  """
53
  )
54
 
inference.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Created on Wed Jun 11 09:51:38 2025
4
+
5
+ @author: camaac
6
+ """
7
+
8
+ import PIL
9
+ import torch
10
+ from diffusers import StableDiffusionInstructPix2PixPipeline, UNet2DModel, AutoencoderKL, DDPMScheduler
11
+ import numpy as np
12
+ import torch.nn as nn
13
+ from transformers import CLIPTokenizer, CLIPTextModel, CLIPImageProcessor
14
+
15
+
16
+ class UNetNoCondWrapper(nn.Module):
17
+ def __init__(self, base_unet: UNet2DModel):
18
+ super().__init__()
19
+ self.unet = base_unet
20
+
21
+ def forward(
22
+ self,
23
+ sample,
24
+ timestep,
25
+ encoder_hidden_states=None,
26
+ added_cond_kwargs=None,
27
+ cross_attention_kwargs=None,
28
+ return_dict=False,
29
+ **kwargs
30
+ ):
31
+
32
+ return self.unet(sample, timestep, return_dict=return_dict, **kwargs)
33
+
34
+ def __getattr__(self, name):
35
+ if name in ("unet", "forward", "__getstate__", "__setstate__"):
36
+ return super().__getattr__(name)
37
+ return getattr(self.unet, name)
38
+
39
+ def save_pretrained(self, save_directory, **kwargs):
40
+ # délègue à la vraie instance UNet2DModel
41
+ return self.unet.save_pretrained(save_directory, **kwargs)
42
+
43
+ def inference(model_id,device, img1, img2):
44
+
45
+
46
+ vae = AutoencoderKL.from_pretrained(f"{model_id}/vae").to(device)
47
+ scheduler = DDPMScheduler.from_pretrained(f"{model_id}/scheduler")
48
+ tokenizer = CLIPTokenizer.from_pretrained(f"{model_id}/tokenizer")
49
+ text_encoder = CLIPTextModel.from_pretrained(f"{model_id}/text_encoder").to(device)
50
+ feature_extractor = CLIPImageProcessor.from_pretrained(f"{model_id}/feature_extractor")
51
+
52
+ # 2) Chargez votre UNet non‑conditionné et wrappez‑le
53
+ base_unet = UNet2DModel.from_pretrained(f"{model_id}/unet").to(device)
54
+ wrapped_unet = UNetNoCondWrapper(base_unet).to(device)
55
+
56
+ # 3) Construisez la pipeline manuellement
57
+ pipe = StableDiffusionInstructPix2PixPipeline(
58
+ vae=vae,
59
+ text_encoder=text_encoder,
60
+ tokenizer=tokenizer,
61
+ unet=wrapped_unet,
62
+ scheduler=scheduler,
63
+ safety_checker=None,
64
+ feature_extractor=feature_extractor,
65
+ )
66
+ pipe = pipe.to(torch.float16).to(device)
67
+
68
+ generator = torch.Generator("cuda").manual_seed(0)
69
+
70
+
71
+ img1 = img1.resize((512, 512))
72
+ img2 = img2.resize((512, 512))
73
+
74
+ img1_np = np.array(img1)
75
+ if len(img1_np.shape) > 2:
76
+ img1_np = img1_np[:, :, 0]
77
+
78
+ img2_np = np.array(img2)
79
+ if len(img2_np.shape) > 2:
80
+ img2_np = img2_np[:, :, 0]
81
+
82
+ img1_np[img1_np > 200] = 255
83
+ img1_np[img1_np <= 200] = 0
84
+ img1_np = 255-img1_np
85
+ img_np = np.stack([img1_np, img2_np, img2_np], axis=2)
86
+
87
+ image = PIL.Image.fromarray(img_np)
88
+ image = PIL.ImageOps.exif_transpose(image)
89
+
90
+ num_inference_steps = 20
91
+ image_guidance_scale = 1.9
92
+ guidance_scale = 10
93
+
94
+ edited_image = pipe(
95
+ prompt=[""] ,
96
+ image=image,
97
+ num_inference_steps=num_inference_steps,
98
+ image_guidance_scale=image_guidance_scale,
99
+ guidance_scale=guidance_scale,
100
+ generator=generator,
101
+ safety_checker=None,
102
+ num_images_per_prompt=1
103
+ ).images
104
+
105
+ return edited_image