Spaces:
Running
Running
| import os.path | |
| import gradio as gr | |
| import json | |
| import requests | |
| import time | |
| from gradio_modal import Modal | |
| from io import BytesIO | |
# Location of the TryOn server. The port is appended to the base URL only
# when it differs from "80".
# NOTE(review): the host uses https while the "default" port here is "80" —
# confirm this is the intended way to say "no explicit port".
TRYON_SERVER_HOST = "https://prod.server.tryonlabs.ai"
TRYON_SERVER_PORT = "80"

TRYON_SERVER_URL = (
    TRYON_SERVER_HOST
    if TRYON_SERVER_PORT == "80"
    else f"{TRYON_SERVER_HOST}:{TRYON_SERVER_PORT}"
)

# Every versioned API endpoint is built by appending a path to this prefix.
TRYON_SERVER_API_URL = f"{TRYON_SERVER_URL}/api/v1/"
def start_model_swap(input_image, prompt, cls, seed, guidance_scale, num_results, strength, inference_steps):
    """Run a model-swap experiment on the TryOn server and return its result images.

    The function uploads ``input_image`` as an experiment image, creates a
    ``model_swap`` experiment that references it, then polls the experiment
    every 10 seconds until it is no longer reported as "running".

    Args:
        input_image: Image to process; it is serialised with
            ``input_image.save(buffer, 'png')`` (the UI passes a PIL image).
        prompt: Text prompt forwarded to the server.
        cls: Garment class to retain; sent as ``"<cls> garment"``.
        seed: Seed forwarded to the server.
        guidance_scale: Guidance scale forwarded to the server.
        num_results: Number of results requested from the server.
        strength: Strength forwarded to the server.
        inference_steps: Sent as ``num_inference_steps``.

    Returns:
        The ``result_images`` list reported by ``fetch_experiment_status``
        once the experiment succeeds.

    Raises:
        gr.Error: If no image or prompt is given, no token is stored, the
            token fails verification, either POST request is rejected, or
            the experiment ends in any state other than "success".
    """
    print("inputs:", input_image, prompt, cls, seed, guidance_scale, num_results, strength, inference_steps)

    if input_image is None:
        raise gr.Error("Select an image!")
    if prompt is None or prompt == "":
        raise gr.Error("Enter a prompt!")

    token = load_token()
    if token is None or token == "":
        raise gr.Error("You need to login first!")
    # Re-validate the stored token before spending time on the upload.
    login(token)

    auth_headers = {"Authorization": f"Bearer {token}"}

    # 1. create an experiment image
    byte_io = BytesIO()
    input_image.save(byte_io, 'png')
    byte_io.seek(0)  # rewind so the upload starts at the first byte

    r = requests.post(
        f"{TRYON_SERVER_API_URL}experiment_image/",
        files={"image": ('ei_image.png', byte_io, 'image/png')},
        data={"type": "model", "preprocess": "false"},
        headers=auth_headers,
    )
    if r.status_code not in (200, 201):
        print(f"Error: {r.text}")
        raise gr.Error(f"Failure: {r.text}")

    res = r.json()
    print("Experiment image created successfully", res)

    # 2. create an experiment
    params = {
        "prompt": prompt,
        "guidance_scale": guidance_scale,
        "strength": strength,
        "num_inference_steps": inference_steps,
        "seed": seed,
        "garment_class": f"{cls} garment",
        "negative_prompt": "(hands:1.15), disfigured, ugly, bad, immature"
                           ", cartoon, anime, 3d, painting, b&w, (ugly),"
                           " (pixelated), watermark, glossy, smooth, "
                           "earrings, necklace",
        "num_results": num_results,
    }
    r2 = requests.post(
        f"{TRYON_SERVER_API_URL}experiment/",
        data={
            "model_id": res['id'],
            "action": "model_swap",
            "params": json.dumps(params),
        },
        headers=auth_headers,
    )
    if r2.status_code not in (200, 201):
        print(f"Error: {r2.text}")
        raise gr.Error(f"Failure: {r2.text}")

    res2 = r2.json()
    print("Experiment created successfully", res2)
    experiment_id = res2['experiment']['id']

    # 3. keep checking the status of the experiment
    while True:
        time.sleep(10)
        status = fetch_experiment_status(experiment_id=experiment_id, token=token)
        # Tolerate a missing result so an unexpected server state surfaces
        # as a readable error below rather than a TypeError.
        state = (status or {}).get('status')
        print(f"Current status: {state}")
        if state != "running":
            break

    if state == "success":
        print("Experiment successful")
        print(f"Results:{status['result_images']}")
        return status['result_images']
    if state == "failed":
        print("Experiment failed")
        raise gr.Error("Experiment failed")

    # Previously this path returned None silently (or crashed on a None status).
    print(f"Error: unexpected experiment status {state!r}")
    raise gr.Error(f"Unexpected experiment status: {state}")
def fetch_experiment_status(experiment_id, token):
    """Query the TryOn server for the current state of an experiment.

    Args:
        experiment_id: Identifier of the experiment to look up.
        token: Access token sent as a Bearer credential.

    Returns:
        A dict that always contains a ``"status"`` key whose value is
        ``"running"``, ``"success"`` or ``"failed"``. On ``"success"`` it
        also contains ``"result_images"``: the primary result URL followed
        by the URLs of every entry in the experiment's ``results`` list.
        A non-200 response, or a status string this function does not
        recognise, is reported as ``"failed"`` so callers never receive
        ``None``.
    """
    print(f"experiment id:{experiment_id}")
    response = requests.get(
        f"{TRYON_SERVER_API_URL}experiment/{experiment_id}/",
        headers={"Authorization": f"Bearer {token}"},
    )
    if response.status_code != 200:
        print(f"Error: {response.text}")
        return {"status": "failed"}

    payload = response.json()  # parse once and reuse below
    status = payload['status']

    if status == "running":
        return {"status": "running"}

    if status == "success":
        experiment = payload['experiment']
        result_images = [f"{TRYON_SERVER_URL}/{experiment['result']['image_url']}"]
        result_images.extend(
            f"{TRYON_SERVER_URL}/{result['image_url']}"
            for result in experiment['results']
        )
        return {"status": "success", "result_images": result_images}

    if status != "failed":
        # Unknown state: log it and fall back to "failed" instead of
        # implicitly returning None, which crashed the polling caller.
        print(f"Error: unexpected experiment status {status!r}")
    return {"status": "failed"}
def get_user_credits(token):
    """Return the credit balance of the user that owns ``token``.

    Args:
        token: Access token sent as a Bearer credential. ``None`` and the
            empty string both mean "no token".

    Returns:
        The ``credits`` value from the server's response, or ``None`` when
        no token is supplied or the server does not answer with HTTP 200.
    """
    # load_token() yields None when nothing is stored; treat that like ""
    # rather than sending "Bearer None" to the server.
    if not token:
        return None

    response = requests.get(
        f"{TRYON_SERVER_API_URL}user/get/",
        headers={"Authorization": f"Bearer {token}"},
    )
    if response.status_code != 200:
        print(f"Error: {response.text}")
        return None
    return response.json()['credits']
def load_token():
    """Read the stored access token from the ``.token`` file.

    Returns:
        The value stored under the ``"token"`` key, or ``None`` when the
        file does not exist, is not valid JSON, or has no ``"token"`` key.
    """
    try:
        with open(".token", "r") as token_file:
            return json.load(token_file)['token']
    except FileNotFoundError:
        return None
    except (ValueError, KeyError, TypeError):
        # A damaged token file is treated as "not logged in" so the app can
        # still start and the user can log in again, overwriting the file.
        return None
def save_token(access_token):
    """Write ``access_token`` to the ``.token`` file as ``{"token": ...}``.

    Args:
        access_token: Token to store.

    Raises:
        gr.Error: If ``access_token`` is the empty string.
    """
    if access_token == "":
        raise gr.Error("No token provided!")

    payload = {"token": access_token}
    with open(".token", "w") as token_file:
        json.dump(payload, token_file)
def is_logged_in():
    """Return True when ``load_token()`` yields a token that is neither None nor ""."""
    stored_token = load_token()
    return stored_token is not None and stored_token != ""
def login(token):
    """Verify ``token`` with the server and store it on success.

    Args:
        token: Access token to verify.

    Returns:
        True once the token has been verified and saved.

    Raises:
        gr.Error: If the verification endpoint does not answer with HTTP 200.
    """
    print("logging in...")
    # validate token
    verification = requests.post(
        f"{TRYON_SERVER_URL}/api/token/verify/",
        data={"token": token},
    )
    if verification.status_code != 200:
        raise gr.Error("Login failed")

    save_token(token)
    return True
def logout():
    """Blank the stored token and return the logged-out UI state.

    Returns:
        ``[False, ""]`` — the UI wires these to its ``logged_in`` and
        ``user_token`` outputs.
    """
    print("logged out")
    cleared = {"token": ""}
    with open(".token", "w") as token_file:
        json.dump(cleared, token_file)
    return [False, ""]
# Custom stylesheet handed to gr.Blocks(css=...). Each selector targets an
# elem_id assigned in the layout below: the centred page container, the
# right-aligned credits/login row, the login modal body and the compact
# login/logout button.
css = """
#col-container {
margin: 0 auto;
max-width: 1024px;
}
#credits-col-container{
display:flex;
justify-content: right;
align-items: center;
font-size: 24px;
margin-right: 1rem;
}
#login-modal{
max-width: 728px;
margin: 0 auto;
margin-top: 1rem;
margin-bottom: 1rem;
}
#login-logout-btn{
display:inline;
max-width: 124px;
}
"""
# Build the Gradio UI: a login modal, a header with login/logout controls,
# the model-swap form, and the event wiring between them.
# NOTE(review): the original indentation was not recoverable, so the exact
# nesting of rows/columns below is reconstructed — confirm against the
# rendered layout.
with gr.Blocks(css=css, theme=gr.themes.Default()) as demo:
    # Seed the state components from the token file, read via the shared helpers.
    initially_logged_in = is_logged_in()
    print("is logged in:", initially_logged_in)
    logged_in = gr.State(initially_logged_in)
    user_token = gr.State(load_token() or "")

    with Modal(visible=False) as modal:
        # Registered with gr.render so the form is actually built (the bare
        # function was never called) and rebuilt whenever user_token changes.
        # Requires a Gradio release that provides gr.render.
        @gr.render(inputs=user_token)
        def rerender1(user_token1):
            with gr.Column(elem_id="login-modal"):
                access_token = gr.Textbox(
                    label="Token",
                    lines=1,
                    value=user_token1,
                    type="password",
                    placeholder="Enter your access token here!",
                    info="Visit https://playground.tryonlabs.ai to retrieve your access token."
                )
                login_submit_btn = gr.Button("Login", scale=1, variant='primary')
                # login() raises gr.Error on failure, so the outputs are only
                # updated after a successful verification.
                login_submit_btn.click(
                    fn=lambda entered_token: (login(entered_token), Modal(visible=False), entered_token),
                    inputs=[access_token], outputs=[logged_in, modal, user_token],
                    concurrency_limit=1)

    with gr.Row(elem_id="col-container"):
        with gr.Column():
            gr.Markdown(f"""
            # Model Swap AI
            ## by TryOn Labs (https://www.tryonlabs.ai)
            Swap a human model with a artificial model generated by Artificial Model while keeping the garment intact.
            """)

        # Rebuilt whenever logged_in changes: shows either the Login button
        # or the credit balance plus a Logout button. The parameter is named
        # so it does not shadow the module-level is_logged_in() function.
        @gr.render(inputs=logged_in)
        def rerender(is_authenticated):
            with gr.Column():
                if not is_authenticated:
                    with gr.Row(elem_id="credits-col-container"):
                        login_btn = gr.Button(value="Login", variant='primary', elem_id="login-logout-btn", size="sm")
                        login_btn.click(lambda: Modal(visible=True), None, modal)
                else:
                    user_credits = get_user_credits(load_token())
                    print("user_credits", user_credits)
                    gr.HTML(f"""<div><p id="credits-col-container">Your Credits:
                    {user_credits if user_credits is not None else "0"}</p>
                    <p style="text-align: right;">Visit <a href="https://playground.tryonlabs.ai">
                    TryOn AI Playground</a> to acquire more credits</p></div>""")
                    with gr.Row(elem_id="credits-col-container"):
                        logout_btn = gr.Button(value="Logout", scale=1, variant='primary', size="sm",
                                               elem_id="login-logout-btn")
                        logout_btn.click(fn=logout, inputs=None, outputs=[logged_in, user_token], concurrency_limit=1)

    with gr.Column(elem_id="col-container"):
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(label="Original image", type='pil', height="400px", show_label=True)
                prompt = gr.Textbox(
                    label="Prompt",
                    lines=3,
                    placeholder="Enter your prompt here!",
                )
                dropdown = gr.Dropdown(["upper", "lower", "dress"], value="upper", label="Retain garment",
                                       info="Select the garment type you want to retain in the generated image!")
            gallery = gr.Gallery(
                label="Generated images", show_label=True, elem_id="gallery"
                , columns=[3], rows=[1], object_fit="contain", height="auto")
            # output_image = gr.Image(label="Swapped model", type='pil', height="400px", show_label=True,
            # show_download_button=True)

        with gr.Accordion("Advanced Settings", open=False):
            with gr.Row():
                seed = gr.Number(label="Seed", value=-1, interactive=True, minimum=-1)
                guidance_scale = gr.Number(label="Guidance Scale", value=7.5, interactive=True, minimum=0.0,
                                           maximum=10.0,
                                           step=0.1)
                num_results = gr.Number(label="Number of results", value=2, minimum=1, maximum=5)
            with gr.Row():
                strength = gr.Slider(0.00, 1.00, value=0.99, label="Strength",
                                     info="Choose between 0.00 and 1.00", step=0.01, interactive=True)
                inference_steps = gr.Number(label="Inference Steps", value=20, interactive=True, minimum=1, step=1)

        with gr.Row():
            submit_button = gr.Button("Submit", variant='primary', scale=1)
            reset_button = gr.ClearButton(value="Reset", scale=1)

    # Submit runs the experiment and shows the returned image URLs in the gallery.
    gr.on(
        triggers=[submit_button.click],
        fn=start_model_swap,
        inputs=[input_image, prompt, dropdown, seed, guidance_scale, num_results, strength, inference_steps],
        outputs=[gallery]
    )

    # Reset restores every input to the default it was created with; the
    # tuple order must match the outputs list.
    reset_button.click(
        fn=lambda: (None, None, "upper", None, -1, 7.5, 2, 0.99, 20),
        inputs=[],
        outputs=[input_image, prompt, dropdown, gallery, seed, guidance_scale,
                 num_results, strength, inference_steps],
        concurrency_limit=1,
    )

if __name__ == '__main__':
    demo.launch()