Upload 2 files
- app.py +14 -66
- inference.py +105 -0
app.py
CHANGED
@@ -9,67 +9,12 @@ import gradio as gr
 from PIL import Image
 import torch
 import torchvision.transforms as T
+from inference import inference
 
-# --- 1) IMPORT YOUR INFERENCE CODE ---
-# For example, if you have an inference.py file that defines a function `infer_wood(fibers_map, rings_map)`,
-# you can do:
-#
-# from inference import infer_wood
-#
-# Make sure that `infer_wood` takes two PIL.Image objects as input
-# (fibre map and growth-ring map) and returns a PIL.Image as its result.
-#
-# If you do not have this file yet, create a function of the form:
-
-# def infer_wood(fibers_img: Image.Image, rings_img: Image.Image) -> Image.Image:
-#     """
-#     Example skeleton of an inference function.
-#     -> Replace everything inside with your own pipeline (preprocessing, model call, postprocessing).
-#     """
-#     # --- Preprocessing (adapted to your model) ---
-#     # For example:
-#     preprocess = T.Compose([
-#         T.Resize((256, 256)),
-#         T.ToTensor(),
-#         # T.Normalize(mean=[...], std=[...])  # if your model was trained with normalisation
-#     ])
-#     x1 = preprocess(fibers_img).unsqueeze(0).to(torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
-#     x2 = preprocess(rings_img).unsqueeze(0).to(torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
-
-#     # --- Loading / using the model (a generic example) ---
-#     # Suppose you had already loaded your model globally somewhere:
-#     # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-#     # model = VotreModeleWood().to(device)
-#     # model.load_state_dict(torch.load('chemin/vers/votre_modele.pth', map_location=device))
-#     # model.eval()
-#     #
-#     # Here the two maps are concatenated to form the input (adapt to your architecture).
-
-#     input_tensor = torch.cat([x1, x2], dim=1)  # e.g. (1, C1+C2, H, W)
-
-#     with torch.no_grad():
-#         pred_tensor = model(input_tensor)  # assumes `model` is already defined globally and loaded
-
-#     # --- Postprocessing back to a PIL.Image ---
-#     postprocess = T.ToPILImage()
-#     output_img = postprocess(pred_tensor.squeeze(0).cpu().clamp(0, 1))
-#     return output_img
-
-def infer_wood(fibers_img: Image.Image, rings_img: Image.Image):
-    return rings_img
-
-
-# --- 2) GLOBAL MODEL LOADING (optional) ---
-# You can load your model once, here, outside the infer_wood function,
-# so that Gradio does not reload it on every call. For example:
 
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
-# Example:
-# from models.votre_modele import VotreModeleWood
-# model = VotreModeleWood().to(device)
-# model.load_state_dict(torch.load('models/chemin_du_modele.pth', map_location=device))
-# model.eval()
 
 # --- 3) GRADIO INTERFACE FUNCTION ---
@@ -81,8 +26,11 @@ def gradio_generate(fibers_map: Image.Image, rings_map: Image.Image) -> Image.Image:
     # Check that both images are in RGB mode (adapt if needed)
     fibers_map = fibers_map.convert("RGB")
     rings_map = rings_map.convert("RGB")
-
-
+
+    model_id = "CarolineM5/InstructPix2Pix_WithoutPrompt/model_LR"
+
+    result_img = inference(model_id, device, rings_map, fibers_map)
+
     return result_img
 
 
@@ -90,17 +38,17 @@ def gradio_generate(fibers_map: Image.Image, rings_map: Image.Image) -> Image.Image:
 iface = gr.Interface(
     fn=gradio_generate,
     inputs=[
-        gr.Image(type="pil", label="
-        gr.Image(type="pil", label="
+        gr.Image(type="pil", label="Fibre orientation map"),
+        gr.Image(type="pil", label="Growth ring map")
     ],
-    outputs=gr.Image(type="pil", label="
-    title="
+    outputs=gr.Image(type="pil", label="Photorealistic wood generated"),
+    title="Photorealistic wood generator",
     description="""
-
-    1)
-    2)
+    Upload:
+    1) a fibre orientation mapping image,
+    2) a tree-ring boundary mapping image.
 
-
+    The model will return a photorealistic rendering of the wood that you can download.
     """
 )
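To sanity-check the new wiring end to end, a minimal smoke-test sketch is given below. It assumes app.py and inference.py sit side by side, that the weights under model_id are reachable, and that importing app does not call iface.launch() at import time; the synthetic maps are placeholders, not realistic inputs.

import numpy as np
from PIL import Image
from app import gradio_generate

# Two synthetic 512x512 grayscale maps standing in for real fibre/ring maps.
fibers_map = Image.fromarray(np.full((512, 512), 230, dtype=np.uint8))
rings_map = Image.fromarray(((np.arange(512)[None, :] % 64) * 4).astype(np.uint8).repeat(512, axis=0))

# gradio_generate converts its inputs to RGB itself, so grayscale is fine here.
result = gradio_generate(fibers_map, rings_map)
result.save("smoke_test.png")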
inference.py
ADDED
@@ -0,0 +1,105 @@
# -*- coding: utf-8 -*-
"""
Created on Wed Jun 11 09:51:38 2025

@author: camaac
"""

import PIL.Image
import PIL.ImageOps
import torch
import torch.nn as nn
import numpy as np
from diffusers import StableDiffusionInstructPix2PixPipeline, UNet2DModel, AutoencoderKL, DDPMScheduler
from transformers import CLIPTokenizer, CLIPTextModel, CLIPImageProcessor


class UNetNoCondWrapper(nn.Module):
    """Adapts an unconditional UNet2DModel to the conditional call signature
    expected by StableDiffusionInstructPix2PixPipeline: the text-conditioning
    arguments are accepted and silently dropped."""

    def __init__(self, base_unet: UNet2DModel):
        super().__init__()
        self.unet = base_unet

    def forward(
        self,
        sample,
        timestep,
        encoder_hidden_states=None,
        added_cond_kwargs=None,
        cross_attention_kwargs=None,
        return_dict=False,
        **kwargs
    ):
        return self.unet(sample, timestep, return_dict=return_dict, **kwargs)

    def __getattr__(self, name):
        if name in ("unet", "forward", "__getstate__", "__setstate__"):
            return super().__getattr__(name)
        return getattr(self.unet, name)

    def save_pretrained(self, save_directory, **kwargs):
        # Delegate to the real UNet2DModel instance.
        return self.unet.save_pretrained(save_directory, **kwargs)


def inference(model_id, device, img1, img2):
    # 1) Load the pipeline components individually.
    vae = AutoencoderKL.from_pretrained(f"{model_id}/vae").to(device)
    scheduler = DDPMScheduler.from_pretrained(f"{model_id}/scheduler")
    tokenizer = CLIPTokenizer.from_pretrained(f"{model_id}/tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(f"{model_id}/text_encoder").to(device)
    feature_extractor = CLIPImageProcessor.from_pretrained(f"{model_id}/feature_extractor")

    # 2) Load the unconditional UNet and wrap it.
    base_unet = UNet2DModel.from_pretrained(f"{model_id}/unet").to(device)
    wrapped_unet = UNetNoCondWrapper(base_unet).to(device)

    # 3) Assemble the pipeline manually.
    pipe = StableDiffusionInstructPix2PixPipeline(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=wrapped_unet,
        scheduler=scheduler,
        safety_checker=None,
        feature_extractor=feature_extractor,
    )
    pipe = pipe.to(torch.float16).to(device)  # half precision (assumes a CUDA device)

    # Seed the generator on the same device the pipeline runs on
    # (hard-coding "cuda" here would fail on a CPU-only Space).
    generator = torch.Generator(device).manual_seed(0)

    img1 = img1.resize((512, 512))
    img2 = img2.resize((512, 512))

    # Keep a single channel from each map.
    img1_np = np.array(img1)
    if len(img1_np.shape) > 2:
        img1_np = img1_np[:, :, 0]

    img2_np = np.array(img2)
    if len(img2_np.shape) > 2:
        img2_np = img2_np[:, :, 0]

    # Binarise and invert the first map, then pack both maps into one RGB image.
    img1_np[img1_np > 200] = 255
    img1_np[img1_np <= 200] = 0
    img1_np = 255 - img1_np
    img_np = np.stack([img1_np, img2_np, img2_np], axis=2)

    image = PIL.Image.fromarray(img_np)
    image = PIL.ImageOps.exif_transpose(image)

    num_inference_steps = 20
    image_guidance_scale = 1.9
    guidance_scale = 10

    # The safety checker is already disabled on the pipeline itself,
    # so it is not passed to the call.
    edited_image = pipe(
        prompt=[""],
        image=image,
        num_inference_steps=num_inference_steps,
        image_guidance_scale=image_guidance_scale,
        guidance_scale=guidance_scale,
        generator=generator,
        num_images_per_prompt=1
    ).images[0]  # .images is a list; return the single generated PIL.Image

    return edited_image