JarvisLabs committed on
Commit
766c4fc
·
verified ·
1 Parent(s): c1ca4d2

Upload 2 files

Browse files
Files changed (2) hide show
  1. src/rep_api.py +380 -315
  2. src/utils.py +47 -0
src/rep_api.py CHANGED
@@ -1,315 +1,380 @@
1
- import replicate
2
- import os
3
- from src.utils import image_to_base64 , update_model_dicts, BB_uploadfile,numpy_to_base64
4
- from src.deepl import detect_and_translate
5
- import json
6
- import time
7
- style_json="model_dict.json"
8
- model_dict=json.load(open(style_json,"r"))
9
-
10
-
11
- def generate_image_control_net(prompt,lora_model,api_path,aspect_ratio,lora_scale,
12
- use_control_net,control_net_type,control_net_img,control_net_strength,
13
- num_outputs=1,guidance_scale=3.5,seed=None,
14
- ):
15
- print(prompt,lora_model,api_path,aspect_ratio,use_control_net)
16
- inputs = {
17
- "prompt": detect_and_translate(prompt),
18
- "output_format": "png",
19
- "num_outputs":num_outputs,
20
- "guidance_scale": guidance_scale,
21
- "output_quality": 100,
22
- }
23
- if seed is not None:
24
- inputs["seed"]=seed
25
-
26
-
27
- if use_control_net:
28
- api_path= "xlabs-ai/flux-dev-controlnet:f2c31c31d81278a91b2447a304dae654c64a5d5a70340fba811bb1cbd41019a2" #X labs control net replicate repo
29
- lora_url=model_dict[lora_model][1]
30
- assert control_net_img is not None, "Please add control net image"
31
- control_net_img=image_to_base64(control_net_img)
32
- inputs["lora_url"]=lora_url
33
- inputs["prompt"]+=", "+model_dict[lora_model][2]
34
- inputs["control_image"]=control_net_img
35
- inputs["control_type"]=control_net_type
36
- inputs["control_strengh"]=control_net_strength
37
- #other settings
38
- inputs["steps"]=30
39
- inputs["lora_strength"]=lora_scale
40
- inputs["negative_prompt"]= "low quality, ugly, distorted, artefacts"
41
- else:
42
- api_path=model_dict[lora_model][0]
43
- inputs["aspect_ratio"]=aspect_ratio
44
- inputs["prompt"]+=", "+model_dict[lora_model][2]
45
- inputs["num_inference_steps"]=28
46
- inputs["model"]="dev"
47
- inputs["lora_scale"]=lora_scale
48
-
49
- #run model
50
- output = replicate.run(
51
- api_path,
52
- input=inputs
53
- )
54
- print(output)
55
- return output[0]
56
-
57
-
58
-
59
-
60
- def generate_image_replicate(prompt,lora_model,api_path,aspect_ratio,model,lora_scale,
61
- num_outputs=1,guidance_scale=3.5,seed=None,
62
-
63
- ):
64
- print(prompt,lora_model,api_path,aspect_ratio)
65
-
66
- #if model=="dev":
67
- num_inference_steps=30
68
- if model=="schnell":
69
- num_inference_steps=5
70
-
71
- if lora_model is not None:
72
- api_path=model_dict[lora_model][0]
73
-
74
- inputs={
75
- "model": model,
76
- "prompt": detect_and_translate(prompt),
77
- "lora_scale":lora_scale,
78
- "aspect_ratio": aspect_ratio,
79
- "num_outputs":num_outputs,
80
- "num_inference_steps":num_inference_steps,
81
- "guidance_scale":guidance_scale,
82
- "output_format":"png",
83
- }
84
- if seed is not None:
85
- inputs["seed"]=seed
86
- output = replicate.run(
87
- api_path,
88
- input=inputs
89
- )
90
- print(output)
91
- return output[0]
92
- def replicate_bgcontrolnet(img,prompt,background_prompt, sampler_name= "DPM++ SDE Karras",
93
- negative_prompt="(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers:1.4), (deformed, distorted, disfigured:1.3), poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation"
94
- ):
95
- img=image_to_base64(img)
96
- prompt=prompt+" ," +background_prompt
97
- output=replicate.run(
98
- "wolverinn/realistic-background:9f020c55e037529bf20ed1cb799d7aa290404cfbd45157686717ffc7ee511eab",
99
- input={
100
- "seed": -1,
101
- "image":img,
102
- "prompt":detect_and_translate(prompt),
103
- "sampler_name":sampler_name,
104
- "negative_prompt":negative_prompt
105
- }
106
- )
107
-
108
- return output["image"]
109
-
110
- def replicate_caption_api(image,model,context_text):
111
- print(model,context_text)
112
- base64_image = image_to_base64(image)
113
- if model=="blip":
114
- output = replicate.run(
115
- "andreasjansson/blip-2:f677695e5e89f8b236e52ecd1d3f01beb44c34606419bcc19345e046d8f786f9",
116
- input={
117
- "image": base64_image,
118
- "caption": True,
119
- "question": context_text,
120
- "temperature": 1,
121
- "use_nucleus_sampling": False
122
- }
123
- )
124
- print(output)
125
-
126
- elif model=="llava-16":
127
- output = replicate.run(
128
- # "yorickvp/llava-13b:80537f9eead1a5bfa72d5ac6ea6414379be41d4d4f6679fd776e9535d1eb58bb",
129
- "yorickvp/llava-v1.6-34b:41ecfbfb261e6c1adf3ad896c9066ca98346996d7c4045c5bc944a79d430f174",
130
- input={
131
- "image": base64_image,
132
- "top_p": 1,
133
- "prompt": context_text,
134
- "max_tokens": 1024,
135
- "temperature": 0.2
136
- }
137
- )
138
- print(output)
139
- output = "".join(output)
140
-
141
- elif model=="img2prompt":
142
- output = replicate.run(
143
- "methexis-inc/img2prompt:50adaf2d3ad20a6f911a8a9e3ccf777b263b8596fbd2c8fc26e8888f8a0edbb5",
144
- input={
145
- "image":base64_image
146
- }
147
- )
148
- print(output)
149
- return output
150
-
151
- def update_replicate_api_key(api_key):
152
- os.environ["REPLICATE_API_TOKEN"] = api_key
153
- return f"Replicate API key updated: {api_key[:5]}..." if api_key else "Replicate API key cleared"
154
-
155
-
156
- def virtual_try_on(crop, seed, steps, category, garm_img, human_img, garment_des):
157
- output = replicate.run(
158
- "cuuupid/idm-vton:906425dbca90663ff5427624839572cc56ea7d380343d13e2a4c4b09d3f0c30f",
159
- input={
160
- "crop": crop,
161
- "seed": seed,
162
- "steps": steps,
163
- "category": category,
164
- # "force_dc": force_dc,
165
- "garm_img": numpy_to_base64( garm_img),
166
- "human_img": numpy_to_base64(human_img),
167
- #"mask_only": mask_only,
168
- "garment_des": garment_des
169
- }
170
- )
171
- print(output)
172
- return output
173
-
174
-
175
- from src.utils import create_zip
176
- from PIL import Image
177
-
178
-
179
- def process_images(files,model,context_text,token_string):
180
- images = []
181
- textbox =""
182
- for file in files:
183
- print(file)
184
- image = Image.open(file)
185
- if model=="None":
186
- caption="[Insert cap here]"
187
- else:
188
- caption = replicate_caption_api(image,model,context_text)
189
- textbox += f"Tags: {caption}, file: " + os.path.basename(file) + "\n"
190
- images.append(image)
191
- #texts.append(textbox)
192
- zip_path=create_zip(files,textbox,token_string)
193
-
194
- return images, textbox,zip_path
195
-
196
- def replicate_create_model(owner,name,visibility="private",hardware="gpu-a40-large"):
197
- try:
198
- model = replicate.models.create(
199
- owner=owner,
200
- name=name,
201
- visibility=visibility,
202
- hardware=hardware,
203
- )
204
- print(model)
205
- return True
206
- except Exception as e:
207
- print(e)
208
- if "A model with that name and owner already exists" in str(e):
209
- return True
210
- return False
211
-
212
-
213
-
214
- def traning_function(zip_path,training_model,training_destination,seed,token_string,max_train_steps,hf_repo_id=None,hf_token=None):
215
- ##Place holder for now
216
- BB_bucket_name="jarvisdataset"
217
- BB_defult="https://f005.backblazeb2.com/file/"
218
- if BB_defult not in zip_path:
219
- zip_path=BB_uploadfile(zip_path,os.path.basename(zip_path),BB_bucket_name)
220
- print(zip_path)
221
- training_logs = f"Using zip traning file at: {zip_path}\n"
222
- yield training_logs, None
223
- input={
224
- "steps": max_train_steps,
225
- "lora_rank": 16,
226
- "batch_size": 1,
227
- "autocaption": True,
228
- "trigger_word": token_string,
229
- "learning_rate": 0.0004,
230
- "seed": seed,
231
- "input_images": zip_path
232
- }
233
- print(training_destination)
234
- username,model_name=training_destination.split("/")
235
- assert replicate_create_model(username,model_name,visibility="private",hardware="gpu-a40-large"),"Error in creating model on replicate, check API key and username is correct "
236
-
237
- print(input)
238
- try:
239
- training = replicate.trainings.create(
240
- destination=training_destination,
241
- version="ostris/flux-dev-lora-trainer:1296f0ab2d695af5a1b5eeee6e8ec043145bef33f1675ce1a2cdb0f81ec43f02",
242
- input=input,
243
- )
244
-
245
- training_logs = f"Training started with model: {training_model}\n"
246
- training_logs += f"Destination: {training_destination}\n"
247
- training_logs += f"Seed: {seed}\n"
248
- training_logs += f"Token string: {token_string}\n"
249
- training_logs += f"Max train steps: {max_train_steps}\n"
250
-
251
- # Poll the training status
252
- while training.status != "succeeded":
253
- training.reload()
254
- training_logs += f"Training status: {training.status}\n"
255
- training_logs += f"{training.logs}\n"
256
- if training.status == "failed":
257
- training_logs += "Training failed!\n"
258
- return training_logs, training
259
-
260
- yield training_logs, None
261
- time.sleep(10) # Wait for 10 seconds before checking again
262
-
263
- training_logs += "Training completed!\n"
264
- if hf_repo_id and hf_token:
265
- training_logs += f"Uploading to Hugging Face repo: {hf_repo_id}\n"
266
- # Here you would implement the logic to upload to Hugging Face
267
-
268
- traning_finnal=training.output
269
-
270
- # In a real scenario, you might want to download and display some result images
271
- # For now, we'll just return the original images
272
- #images = [Image.open(file) for file in files]
273
- _= update_model_dicts(traning_finnal["version"],token_string,style_json="model_dict.json")
274
- traning_finnal["replicate_link"]="https://replicate.com/"+traning_finnal["version"].replace(":","/")
275
- yield training_logs, traning_finnal
276
-
277
- except Exception as e:
278
- yield f"An error occurred: {str(e)}", None
279
-
280
-
281
- def sam_segment(image,prompt,negative_prompt,adjustment_factor=-15):
282
- #img2 base64
283
- image = image_to_base64(image)
284
- output = replicate.run(
285
- "schananas/grounded_sam:ee871c19efb1941f55f66a3d7d960428c8a5afcb77449547fe8e5a3ab9ebc21c",
286
- input={
287
- "image": image,
288
- "mask_prompt": prompt,
289
- "adjustment_factor": adjustment_factor,
290
- "negative_mask_prompt":negative_prompt
291
- }
292
- )
293
- out_items={}
294
- for item in output:
295
- # https://replicate.com/schananas/grounded_sam/api#output-schema
296
- print(item)
297
- out_items[os.path.basename(item).split(".")[0]]=item
298
- return out_items
299
-
300
-
301
- def replicate_zest(img,material_img="https://replicate.delivery/pbxt/Kl23gJODaW7EuxrDzBG9dcgqRdMaYSWmBQ9UexnwPiL7AnIr/3.jpg"):
302
- if type(img)!=str:
303
- img=image_to_base64(img)
304
- if type(material_img)!=str:
305
- material_img=image_to_base64(material_img)
306
-
307
- output = replicate.run(
308
- "camenduru/zest:11abc0a411459327938957581151c642dd1bee4cefe443a9a63b230c4fbc0952",
309
- input={
310
- "input_image": img,
311
- "material_image":material_img
312
- }
313
- )
314
- print(output)
315
- return output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import replicate
2
+ import os
3
+ from src.utils import image_to_base64 ,open_image_from_url, update_model_dicts, BB_uploadfile,numpy_to_base64
4
+ from src.deepl import detect_and_translate
5
+ import json
6
+ import time
7
+ style_json="model_dict.json"
8
+ model_dict=json.load(open(style_json,"r"))
9
+
10
+
11
def generate_image_control_net(prompt,lora_model,api_path,aspect_ratio,lora_scale,
                               use_control_net,control_net_type,control_net_img,control_net_strength,
                               num_outputs=1,guidance_scale=3.5,seed=None,
                               ):
    """Generate an image on Replicate, optionally guided by an X-Labs ControlNet.

    Args:
        prompt: text prompt; translated to English via detect_and_translate.
        lora_model: key into the module-level model_dict (style LoRA).
        api_path: Replicate model ref; overridden by model_dict / the
            ControlNet model depending on use_control_net.
        aspect_ratio: aspect ratio string (plain LoRA path only).
        lora_scale: LoRA strength.
        use_control_net: if True, route through the X-Labs Flux ControlNet.
        control_net_type: ControlNet conditioning type (e.g. canny/depth).
        control_net_img: conditioning image; required when use_control_net.
        control_net_strength: ControlNet conditioning strength.
        num_outputs: number of images to request.
        guidance_scale: classifier-free guidance scale.
        seed: optional fixed seed.

    Returns:
        The first output item returned by replicate.run.
    """
    print(prompt, lora_model, api_path, aspect_ratio, use_control_net)
    # Inputs shared by both the ControlNet and the plain-LoRA paths.
    inputs = {
        "prompt": detect_and_translate(prompt),
        "output_format": "png",
        "num_outputs": num_outputs,
        "guidance_scale": guidance_scale,
        "output_quality": 100,
    }
    if seed is not None:
        inputs["seed"] = seed

    if use_control_net:
        # X-Labs ControlNet model on Replicate; the LoRA is applied via URL.
        api_path = "xlabs-ai/flux-dev-controlnet:f2c31c31d81278a91b2447a304dae654c64a5d5a70340fba811bb1cbd41019a2"
        lora_url = model_dict[lora_model][1]
        assert control_net_img is not None, "Please add control net image"
        control_net_img = image_to_base64(control_net_img)
        inputs["lora_url"] = lora_url
        inputs["prompt"] += ", " + model_dict[lora_model][2]
        inputs["control_image"] = control_net_img
        inputs["control_type"] = control_net_type
        # BUG FIX: the key was misspelled "control_strengh", so the strength
        # value was silently ignored by the API schema.
        inputs["control_strength"] = control_net_strength
        # other settings
        inputs["steps"] = 30
        inputs["lora_strength"] = lora_scale
        inputs["negative_prompt"] = "low quality, ugly, distorted, artefacts"
    else:
        # Plain LoRA generation: resolve the model ref from model_dict and
        # append the style trigger text to the prompt.
        api_path = model_dict[lora_model][0]
        inputs["aspect_ratio"] = aspect_ratio
        inputs["prompt"] += ", " + model_dict[lora_model][2]
        inputs["num_inference_steps"] = 28
        inputs["model"] = "dev"
        inputs["lora_scale"] = lora_scale

    # run model
    output = replicate.run(
        api_path,
        input=inputs
    )
    print(output)
    return output[0]
56
+
57
+
58
+
59
+
60
def generate_image_replicate(prompt,lora_model,api_path,aspect_ratio,model,lora_scale,
                             num_outputs=1,guidance_scale=3.5,seed=None,
                             ):
    """Run a Flux text-to-image model on Replicate and return the first output.

    The prompt is machine-translated to English first; a named LoRA style in
    model_dict overrides whatever api_path was passed in.
    """
    print(prompt, lora_model, api_path, aspect_ratio)

    # "schnell" is the fast variant and needs far fewer denoising steps.
    steps = 5 if model == "schnell" else 30

    if lora_model is not None:
        api_path = model_dict[lora_model][0]

    payload = {
        "model": model,
        "prompt": detect_and_translate(prompt),
        "lora_scale": lora_scale,
        "aspect_ratio": aspect_ratio,
        "num_outputs": num_outputs,
        "num_inference_steps": steps,
        "guidance_scale": guidance_scale,
        "output_format": "png",
    }
    if seed is not None:
        payload["seed"] = seed

    result = replicate.run(api_path, input=payload)
    print(result)
    return result[0]
92
def replicate_bgcontrolnet(img,prompt,background_prompt, sampler_name= "DPM++ SDE Karras",
                           negative_prompt="(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers:1.4), (deformed, distorted, disfigured:1.3), poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation"
                           ):
    """Render a realistic background for *img* with the wolverinn model.

    The subject prompt and background prompt are concatenated and sent
    untranslated; returns the URL/handle of the generated image.
    """
    encoded = image_to_base64(img)
    combined_prompt = prompt + " ," + background_prompt
    result = replicate.run(
        "wolverinn/realistic-background:9f020c55e037529bf20ed1cb799d7aa290404cfbd45157686717ffc7ee511eab",
        input={
            "seed": -1,
            "image": encoded,
            "prompt": combined_prompt,
            "sampler_name": sampler_name,
            "negative_prompt": negative_prompt,
        },
    )
    return result["image"]
109
+
110
def replicate_caption_api(image,model,context_text):
    """Caption *image* with one of the supported Replicate captioning models.

    Args:
        image: image accepted by image_to_base64.
        model: one of "blip", "llava-16", "img2prompt".
        context_text: question/prompt for models that take one (blip, llava).

    Returns:
        The caption; llava-16 output is joined into a single string, the
        other models' outputs are returned as produced.

    Raises:
        ValueError: if *model* is not one of the supported names.
    """
    print(model, context_text)
    base64_image = image_to_base64(image)
    if model == "blip":
        output = replicate.run(
            "andreasjansson/blip-2:f677695e5e89f8b236e52ecd1d3f01beb44c34606419bcc19345e046d8f786f9",
            input={
                "image": base64_image,
                "caption": True,
                "question": context_text,
                "temperature": 1,
                "use_nucleus_sampling": False
            }
        )
        print(output)
    elif model == "llava-16":
        output = replicate.run(
            # "yorickvp/llava-13b:80537f9eead1a5bfa72d5ac6ea6414379be41d4d4f6679fd776e9535d1eb58bb",
            "yorickvp/llava-v1.6-34b:41ecfbfb261e6c1adf3ad896c9066ca98346996d7c4045c5bc944a79d430f174",
            input={
                "image": base64_image,
                "top_p": 1,
                "prompt": context_text,
                "max_tokens": 1024,
                "temperature": 0.2
            }
        )
        print(output)
        output = "".join(output)
    elif model == "img2prompt":
        output = replicate.run(
            "methexis-inc/img2prompt:50adaf2d3ad20a6f911a8a9e3ccf777b263b8596fbd2c8fc26e8888f8a0edbb5",
            input={
                "image": base64_image
            }
        )
        print(output)
    else:
        # BUG FIX: an unrecognized model previously fell through and raised
        # UnboundLocalError at `return output`; fail with a clear error.
        raise ValueError(f"Unsupported caption model: {model!r}")
    return output
150
+
151
def update_replicate_api_key(api_key):
    """Store *api_key* in the REPLICATE_API_TOKEN env var and report the result.

    Returns a status string showing the first five characters of the key, or
    a "cleared" message when *api_key* is empty.
    """
    os.environ["REPLICATE_API_TOKEN"] = api_key
    if not api_key:
        return "Replicate API key cleared"
    return f"Replicate API key updated: {api_key[:5]}..."
154
+
155
+
156
def virtual_try_on(crop, seed, steps, category, garm_img, human_img, garment_des):
    """Run the IDM-VTON virtual try-on model on Replicate.

    Both images are numpy arrays converted to base64 before upload; returns
    the raw model output.
    """
    payload = {
        "crop": crop,
        "seed": seed,
        "steps": steps,
        "category": category,
        # "force_dc": force_dc,
        "garm_img": numpy_to_base64( garm_img),
        "human_img": numpy_to_base64(human_img),
        #"mask_only": mask_only,
        "garment_des": garment_des
    }
    result = replicate.run(
        "cuuupid/idm-vton:906425dbca90663ff5427624839572cc56ea7d380343d13e2a4c4b09d3f0c30f",
        input=payload,
    )
    print(result)
    return result
173
+
174
+
175
+ from src.utils import create_zip
176
+ from PIL import Image
177
+
178
+
179
def process_images(files,model,context_text,token_string):
    """Caption each image file, collect the tag lines, and zip everything up.

    Returns (list of opened PIL images, concatenated tag text, zip path).
    """
    images = []
    tag_lines = []
    for path in files:
        print(path)
        img = Image.open(path)
        # The literal string "None" disables captioning; use a placeholder.
        if model == "None":
            caption = "[Insert cap here]"
        else:
            caption = replicate_caption_api(img, model, context_text)
        tag_lines.append(f"Tags: {caption}, file: " + os.path.basename(path) + "\n")
        images.append(img)
    textbox = "".join(tag_lines)
    zip_path = create_zip(files, textbox, token_string)

    return images, textbox, zip_path
195
+
196
def replicate_create_model(owner,name,visibility="private",hardware="gpu-a40-large"):
    """Create a model on Replicate.

    Returns True on success or when a model with the same owner/name already
    exists; False on any other failure (error is printed, not raised).
    """
    try:
        created = replicate.models.create(
            owner=owner,
            name=name,
            visibility=visibility,
            hardware=hardware,
        )
    except Exception as err:
        print(err)
        # An already-existing model counts as success for our purposes.
        return "A model with that name and owner already exists" in str(err)
    print(created)
    return True
211
+
212
+
213
+
214
def traning_function(zip_path,training_model,training_destination,seed,token_string,max_train_steps,hf_repo_id=None,hf_token=None):
    """Kick off a Flux LoRA training on Replicate and stream progress logs.

    Generator yielding ``(training_logs, result)`` tuples: ``result`` is None
    while training is in progress, and on success the training output dict
    augmented with a ``"replicate_link"`` key.

    Args:
        zip_path: path/URL of the training-images zip; uploaded to Backblaze
            first unless it is already hosted there.
        training_model: label used only in the log text.
        training_destination: "username/model_name" on Replicate; the model
            is created if it does not exist.
        seed: training seed.
        token_string: trigger word for the LoRA.
        max_train_steps: number of training steps.
        hf_repo_id: optional Hugging Face repo id (upload not implemented).
        hf_token: optional Hugging Face token (upload not implemented).
    """
    ##Place holder for now
    BB_bucket_name="jarvisdataset"
    BB_defult="https://f005.backblazeb2.com/file/"
    # Upload the zip to Backblaze unless it is already served from there.
    if BB_defult not in zip_path:
        zip_path=BB_uploadfile(zip_path,os.path.basename(zip_path),BB_bucket_name)
    print(zip_path)
    training_logs = f"Using zip traning file at: {zip_path}\n"
    yield training_logs, None
    # NOTE: `input` shadows the builtin; kept unchanged in this doc-only pass.
    input={
        "steps": max_train_steps,
        "lora_rank": 16,
        "batch_size": 1,
        "autocaption": True,
        "trigger_word": token_string,
        "learning_rate": 0.0004,
        "seed": seed,
        "input_images": zip_path
    }
    print(training_destination)
    username,model_name=training_destination.split("/")
    # NOTE(review): assert is stripped under `python -O`; consider raising instead.
    assert replicate_create_model(username,model_name,visibility="private",hardware="gpu-a40-large"),"Error in creating model on replicate, check API key and username is correct "

    print(input)
    try:
        training = replicate.trainings.create(
            destination=training_destination,
            version="ostris/flux-dev-lora-trainer:1296f0ab2d695af5a1b5eeee6e8ec043145bef33f1675ce1a2cdb0f81ec43f02",
            input=input,
        )

        training_logs = f"Training started with model: {training_model}\n"
        training_logs += f"Destination: {training_destination}\n"
        training_logs += f"Seed: {seed}\n"
        training_logs += f"Token string: {token_string}\n"
        training_logs += f"Max train steps: {max_train_steps}\n"

        # Poll the training status
        while training.status != "succeeded":
            training.reload()
            training_logs += f"Training status: {training.status}\n"
            training_logs += f"{training.logs}\n"
            if training.status == "failed":
                training_logs += "Training failed!\n"
                # NOTE(review): a `return` with values inside a generator only
                # stops iteration — a `for` consumer never sees these values.
                # Confirm callers handle StopIteration.value, or yield first.
                return training_logs, training

            yield training_logs, None
            time.sleep(10) # Wait for 10 seconds before checking again

        training_logs += "Training completed!\n"
        if hf_repo_id and hf_token:
            training_logs += f"Uploading to Hugging Face repo: {hf_repo_id}\n"
            # Here you would implement the logic to upload to Hugging Face

        traning_finnal=training.output

        # In a real scenario, you might want to download and display some result images
        # For now, we'll just return the original images
        #images = [Image.open(file) for file in files]
        # Register the freshly trained LoRA version in the shared model_dict.
        _= update_model_dicts(traning_finnal["version"],token_string,style_json="model_dict.json")
        traning_finnal["replicate_link"]="https://replicate.com/"+traning_finnal["version"].replace(":","/")
        yield training_logs, traning_finnal

    except Exception as e:
        yield f"An error occurred: {str(e)}", None
279
+
280
+
281
def sam_segment(image,prompt,negative_prompt,adjustment_factor=-15):
    """Segment *image* with Grounded-SAM on Replicate.

    Returns a dict mapping each output file's stem (basename without
    extension) to its item, per the model's output schema.
    """
    encoded = image_to_base64(image)
    results = replicate.run(
        "schananas/grounded_sam:ee871c19efb1941f55f66a3d7d960428c8a5afcb77449547fe8e5a3ab9ebc21c",
        input={
            "image": encoded,
            "mask_prompt": prompt,
            "adjustment_factor": adjustment_factor,
            "negative_mask_prompt": negative_prompt,
        },
    )
    # https://replicate.com/schananas/grounded_sam/api#output-schema
    keyed = {}
    for item in results:
        print(item)
        stem = os.path.basename(item).split(".")[0]
        keyed[stem] = item
    return keyed
299
+
300
+
301
def replicate_zest(img,material_img="https://replicate.delivery/pbxt/Kl23gJODaW7EuxrDzBG9dcgqRdMaYSWmBQ9UexnwPiL7AnIr/3.jpg"):
    """Transfer the material of *material_img* onto *img* via the ZeST model.

    Either argument may be a URL string (passed through unchanged) or an
    image object (converted to base64 first).

    Returns:
        The raw replicate.run output.
    """
    # isinstance() is the idiomatic type check (was: type(x) != str).
    if not isinstance(img, str):
        img = image_to_base64(img)
    if not isinstance(material_img, str):
        material_img = image_to_base64(material_img)

    output = replicate.run(
        "camenduru/zest:11abc0a411459327938957581151c642dd1bee4cefe443a9a63b230c4fbc0952",
        input={
            "input_image": img,
            "material_image": material_img
        }
    )
    print(output)
    return output
316
+
317
+
318
+ from src.utils import resize_image,find_closest_valid_dimension
319
+
320
+
321
# Valid values for the ic-light-background "light_source" input used by
# replicate_iclight_BG; the first option derives the lighting direction
# from the supplied background image.
light_source_options=[
    "Use Background Image",
    "Left Light",
    "Right Light",
    "Top Light",
    "Bottom Light",
    "Ambient"
]
329
+
330
+
331
def replicate_iclight_BG(img,prompt,bg_img,light_source="Use Background Image",
                         negative_prompt="(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers:1.4), (deformed, distorted, disfigured:1.3), poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation"
                         ):
    """Relight *img* against *bg_img* with the IC-Light background model.

    Both images are snapped to the nearest valid dimension (multiples the
    model accepts) and resized before upload. *light_source* must be one of
    light_source_options. Returns the first output item.

    Args:
        img: subject image (PIL-like; must expose .size and .resize via
            resize_image).
        prompt: text prompt for the relighting.
        bg_img: background image (PIL-like).
        light_source: lighting mode; see light_source_options.
        negative_prompt: negative prompt forwarded unchanged.
    """
    # NOTE(review): assert is stripped under `python -O`; consider ValueError.
    assert light_source in light_source_options, "Please select a correct ligh source option"
    #if type(img)!=str:
    #    img=open_image_from_url(img)
    #if type(bg_img)!=str:
    #    bg_img=open_image_from_url(bg_img)

    width, height = img.size
    print(width,height)
    #print()
    # Snap the subject to the closest model-supported dimensions and resize.
    target_width = find_closest_valid_dimension(width)
    target_height = find_closest_valid_dimension(height)
    resized_img = resize_image(img, target_width, target_height)
    img=image_to_base64(resized_img)

    #if light_source=="Use Background Image":
    # Same snapping for the background image.
    bg_width, bg_height = bg_img.size
    target_width = find_closest_valid_dimension(bg_width)
    target_height = find_closest_valid_dimension(bg_height)
    print(bg_img)
    bg_img = resize_image(bg_img, target_width, target_height)
    bg_img=image_to_base64(bg_img)


    output=replicate.run(
        "zsxkib/ic-light-background:60015df78a8a795470da6494822982140d57b150b9ef14354e79302ff89f69e3",
        input={
            "cfg": 2,
            "steps": 25,
            # NOTE(review): these are the ORIGINAL subject dimensions, not the
            # snapped target_width/target_height computed above — confirm the
            # model accepts arbitrary values here or whether this is a bug.
            "width": width,
            "height": height,
            "prompt": prompt,
            "light_source": light_source,
            "highres_scale": 1.5,
            "output_format": "png",
            "subject_image": img,
            "compute_normal": False,
            "output_quality": 100,
            "appended_prompt": "best quality",
            "highres_denoise": 0.5,
            "negative_prompt": "lowres, bad anatomy, bad hands, cropped, worst quality",
            "background_image": bg_img,
            "number_of_images": 1
        }

    )
    return output[0]
src/utils.py CHANGED
@@ -12,6 +12,53 @@ import json
12
  import gradio as gr
13
 
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  def convert_to_pil(img):
16
  if isinstance(img, np.ndarray):
17
  img = Image.fromarray(img)
 
12
  import gradio as gr
13
 
14
 
15
def resize_image(img, target_width, target_height):
    """Resize *img* toward the target size while keeping its aspect ratio.

    Only one target dimension is honoured: for landscape inputs (wider than
    tall) the width becomes ``target_width`` and the height follows from the
    aspect ratio; otherwise the height becomes ``target_height`` and the
    width follows.

    Args:
        img: PIL Image (or any object exposing ``.size`` and ``.resize``).
        target_width: desired width for landscape inputs.
        target_height: desired height for portrait/square inputs.

    Returns:
        The resized image object.
    """
    w, h = img.size
    ratio = w / h

    if w > h:
        new_size = (target_width, int(target_width / ratio))
    else:
        new_size = (int(target_height * ratio), target_height)

    return img.resize(new_size)
40
+
41
+ # Example usage:
42
+ # Assuming img is your PIL Image object
43
+ # target_width = 512
44
+ # target_height = 512
45
+ # resized_img = resize_image(img, target_width, target_height)
46
+ # resized_img.show()
47
+
48
def find_closest_valid_dimension(dimension, valid_dimensions=(256, 320, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960, 1024)):
    """Return the entry of *valid_dimensions* closest to *dimension*.

    Ties resolve to the earlier (smaller) entry, matching min()'s
    first-wins behaviour.

    Args:
        dimension: the target dimension in pixels.
        valid_dimensions: candidate dimensions. The default was changed from
            a list to a tuple so the default argument is immutable (the
            accepted values themselves are unchanged).

    Returns:
        The closest valid dimension.
    """
    return min(valid_dimensions, key=lambda v: abs(v - dimension))
60
+
61
+
62
  def convert_to_pil(img):
63
  if isinstance(img, np.ndarray):
64
  img = Image.fromarray(img)