onkarsus13 commited on
Commit
4d75ac0
·
verified ·
1 Parent(s): 6a767d1

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +84 -1
README.md CHANGED
@@ -8,5 +8,88 @@ For pretraining we used the Atlas dataset from Johns Hopkins
8
  Code for generation:
9
 
10
  ```
11
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  ```
 
8
  Code for generation:
9
 
10
  ```
11
+ from diffusers import StableDiffusion3ControlNetPipeline, SD3ControlNetModel, UniPCMultistepScheduler, LDMSuperResolutionPipeline
12
+ import torch
13
+ from PIL import Image
14
+ import numpy as np
15
+ from transformers import T5Tokenizer
16
+ import torch.nn as nn
17
+ import os
18
+
19
# Pin inference to the first GPU only.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
20
+
21
+
22
# RGB colour used to draw each class index in the semantic-segmentation mask
# (Atlas abdominal-CT dataset, 10 classes including background).
class_dict_atlas = {
    0: (0, 0, 0),
    1: (255, 60, 0),
    2: (255, 60, 232),
    3: (134, 79, 117),
    4: (125, 0, 190),
    5: (117, 200, 191),
    6: (230, 91, 101),
    7: (255, 0, 155),
    8: (75, 205, 155),
    9: (100, 37, 200),
}

# Human-readable organ name per class index; used to assemble the text prompt.
name_class_dict = {
    0: "background",
    1: "aorta",
    2: "kidney_left",
    3: "liver",
    4: "postcava",
    5: "stomach",
    6: "gall_bladder",
    7: "kidney_right",
    8: "pancreas",
    9: "spleen",
}


def rgb_to_onehot(rgb_arr, color_dict=class_dict_atlas):
    """Convert an (H, W, 3) RGB mask into an (H, W, num_classes) one-hot array.

    Args:
        rgb_arr: integer array of shape (H, W, 3) whose pixel colours come
            from the values of ``color_dict``.
        color_dict: mapping ``class_index -> (r, g, b)``; defaults to the
            Atlas palette above.

    Returns:
        ``np.int8`` array of shape (H, W, len(color_dict)); channel ``i``
        is 1 where the pixel colour equals the colour of class ``i``.
    """
    num_classes = len(color_dict)
    onehot = np.zeros(rgb_arr.shape[:2] + (num_classes,), dtype=np.int8)
    # Iterate (class_id, colour) pairs directly.  The original
    # ``for i, cls in enumerate(color_dict)`` then ``color_dict[i]`` silently
    # assumed the keys are exactly 0..num_classes-1 in insertion order;
    # sorting the items keeps channel order deterministic for any int keys.
    for channel, (_, color) in enumerate(sorted(color_dict.items())):
        # Broadcasting replaces the original flatten/compare/reshape round trip.
        onehot[:, :, channel] = np.all(rgb_arr == np.asarray(color), axis=-1)
    return onehot
55
+
56
+
57
+
58
# --- Inference: semantic mask -> CT image via SD3 + ControlNet --------------

# Load the ControlNet-conditioned SD3 pipeline in fp16.
# NOTE(review): ``safety_checker`` / ``feature_extractor`` are not SD3
# pipeline components; presumably ``from_pretrained`` just ignores them —
# confirm against the installed diffusers version.
pipe = StableDiffusion3ControlNetPipeline.from_pretrained(
    "onkarsus13/Semantic-Control-Stable-diffusion-3-M-Mask2CT-Atlas",
    torch_dtype=torch.float16,
    safety_checker=None,
    feature_extractor=None,
)

# The repo ships the T5 tokenizer for SD3's third text encoder; load it
# explicitly from its subfolder and attach it to the pipeline.
pipe.tokenizer_3 = T5Tokenizer.from_pretrained(
    "onkarsus13/Semantic-Control-Stable-diffusion-3-M-Mask2CT-Atlas",
    subfolder='tokenizer_3',
)

pipe.to('cuda')
# NOTE(review): enable_model_cpu_offload() manages device placement itself,
# so moving the whole pipeline to CUDA first is redundant; diffusers advises
# against combining the two — consider dropping the .to('cuda') above.
pipe.enable_model_cpu_offload()

# Fixed seed so the generation is reproducible.
generator = torch.Generator(device="cuda").manual_seed(1)

# Semantic mask that conditions the generation (colours from class_dict_atlas).
mask = Image.open("<Give mask image for semantic guidance>")
orig_size = mask.size  # (width, height); used to restore the output size

# Recover the per-pixel class ids, then the set of classes present in the mask.
mask_rgb = np.asarray(mask.convert("RGB"))
class_map = rgb_to_onehot(mask_rgb).argmax(-1)
unique_ids = np.unique(class_map)

# Build the prompt once (the original computed this expression twice, for the
# print and for the pipe call).  'containg' [sic] is kept verbatim —
# presumably it matches the prompts the model was trained with; verify
# before "fixing" the typo.
prompt = 'CT image containg ' + " ".join([name_class_dict[i] for i in unique_ids])
print(prompt)

image = pipe(
    prompt=prompt,
    control_image=mask.convert('RGB'),
    height=128,
    width=128,
    num_inference_steps=50,
    generator=generator,
    controlnet_conditioning_scale=1.0,
).images[0]

# Resize the 128x128 sample back to the mask's original resolution and save.
image.resize(orig_size).save('result.png')
94
+
95
  ```