|
|
import gradio as gr |
|
|
import requests |
|
|
import os |
|
|
import tempfile |
|
|
import time |
|
|
import json |
|
|
from PIL import Image |
|
|
import io |
|
|
import base64 |
|
|
import urllib.request |
|
|
from google.cloud import storage |
|
|
from google.oauth2 import service_account |
|
|
import json |
|
|
import shutil |
|
|
import logging |
|
|
|
|
|
|
|
|
# Root logging: INFO and above, with timestamp, logger name and level on every line.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)

# Face-swap backend location, overridable through FACESWAP_API_HOST / FACESWAP_API_PORT.
API_HOST = os.environ.get("FACESWAP_API_HOST", "162.243.118.138")
API_PORT = os.environ.get("FACESWAP_API_PORT", "8080")
API_URL = "http://{}:{}/faceswap".format(API_HOST, API_PORT)

logger.info("Using API endpoint: %s", API_URL)
|
|
|
|
|
def setup_gcs_auth():
    """Build Google Cloud Storage credentials from the SERVICE_ACCOUNT env variable.

    SERVICE_ACCOUNT is expected to hold a service-account key as a JSON string.

    Returns:
        Service-account credentials, or None when the variable is unset/empty or
        its contents cannot be turned into credentials (the error is logged).
    """
    raw_key = os.environ.get("SERVICE_ACCOUNT")
    if not raw_key:
        return None

    try:
        key_info = json.loads(raw_key)
        return service_account.Credentials.from_service_account_info(key_info)
    except Exception as e:
        logger.error(f"Error setting up GCS auth: {e}")
        return None


# Resolved once at import time. When this is None, download_from_gcs falls back to a
# storage.Client() built with the library's default credentials.
gcs_credentials = setup_gcs_auth()
|
|
|
|
|
def download_from_gcs(gcs_url): |
|
|
"""Download a file from GCS using authenticated client""" |
|
|
try: |
|
|
|
|
|
if gcs_url.startswith("https://storage.googleapis.com/"): |
|
|
path = gcs_url.replace("https://storage.googleapis.com/", "") |
|
|
bucket_name, blob_path = path.split("/", 1) |
|
|
|
|
|
storage_client = storage.Client(credentials=gcs_credentials) if gcs_credentials else storage.Client() |
|
|
bucket = storage_client.bucket(bucket_name) |
|
|
blob = bucket.blob(blob_path) |
|
|
|
|
|
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".gif") |
|
|
blob.download_to_filename(temp_file.name) |
|
|
|
|
|
return temp_file.name |
|
|
else: |
|
|
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".gif") |
|
|
urllib.request.urlretrieve(gcs_url, temp_file.name) |
|
|
return temp_file.name |
|
|
except Exception as e: |
|
|
logger.error(f"Error downloading from GCS: {e}") |
|
|
return None |
|
|
|
|
|
def faceswap_process(gif_file, face_file): |
|
|
"""Process faceswap by calling the API""" |
|
|
if gif_file is None or face_file is None: |
|
|
return None, "Please upload both a GIF and a face image." |
|
|
|
|
|
temp_gif = tempfile.NamedTemporaryFile(suffix='.gif', delete=False) |
|
|
temp_face = tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) |
|
|
|
|
|
try: |
|
|
|
|
|
if isinstance(gif_file, str): |
|
|
|
|
|
shutil.copy(gif_file, temp_gif.name) |
|
|
elif hasattr(gif_file, 'save'): |
|
|
|
|
|
gif_file.save(temp_gif.name, 'GIF') |
|
|
elif hasattr(gif_file, 'read'): |
|
|
|
|
|
with open(temp_gif.name, 'wb') as f: |
|
|
f.write(gif_file.read()) |
|
|
else: |
|
|
|
|
|
from PIL import Image |
|
|
Image.fromarray(gif_file).save(temp_gif.name, 'GIF') |
|
|
|
|
|
|
|
|
if isinstance(face_file, str): |
|
|
|
|
|
shutil.copy(face_file, temp_face.name) |
|
|
elif hasattr(face_file, 'save'): |
|
|
|
|
|
face_file.save(temp_face.name, 'JPEG') |
|
|
elif hasattr(face_file, 'read'): |
|
|
|
|
|
with open(temp_face.name, 'wb') as f: |
|
|
f.write(face_file.read()) |
|
|
else: |
|
|
|
|
|
from PIL import Image |
|
|
Image.fromarray(face_file).save(temp_face.name, 'JPEG') |
|
|
|
|
|
files = { |
|
|
'gif_file': ('input.gif', open(temp_gif.name, 'rb'), 'image/gif'), |
|
|
'face_file': ('face.jpg', open(temp_face.name, 'rb'), 'image/jpeg') |
|
|
} |
|
|
|
|
|
start_time = time.time() |
|
|
response = requests.post(API_URL, files=files) |
|
|
|
|
|
|
|
|
os.unlink(temp_gif.name) |
|
|
os.unlink(temp_face.name) |
|
|
|
|
|
if response.status_code == 200: |
|
|
result = response.json() |
|
|
|
|
|
if result.get('status') == 'success': |
|
|
output_url = result.get('url') |
|
|
time_taken = result.get('time_taken', "Unknown") |
|
|
|
|
|
|
|
|
local_gif_path = download_from_gcs(output_url) |
|
|
|
|
|
if local_gif_path: |
|
|
with open(local_gif_path, 'rb') as f: |
|
|
gif_data = f.read() |
|
|
|
|
|
gif_base64 = base64.b64encode(gif_data).decode('utf-8') |
|
|
|
|
|
os.unlink(local_gif_path) |
|
|
|
|
|
return f"data:image/gif;base64,{gif_base64}", f"✅ Faceswap completed in {time_taken}!" |
|
|
else: |
|
|
return None, "❌ Error downloading the result GIF" |
|
|
else: |
|
|
return None, f"❌ Error: {result.get('message', 'Unknown error')}" |
|
|
else: |
|
|
return None, f"❌ API Error: Status code {response.status_code}" |
|
|
|
|
|
except Exception as e: |
|
|
try: |
|
|
os.unlink(temp_gif.name) |
|
|
os.unlink(temp_face.name) |
|
|
except: |
|
|
pass |
|
|
|
|
|
logger.exception("Error in faceswap_process") |
|
|
return None, f"❌ Error: {str(e)}" |
|
|
|
|
|
|
|
|
# Page styles passed to gr.Blocks(css=...).
#   .centered-title      - the header HTML at the top of the page
#   .container           - the outer gr.Column (elem_classes="container"), width-capped and centered
#   .output-container    - the result column; scales the result <img> to fit
#   .loading-*           - the spinner placeholder HTML shown while a swap is running
custom_css = """
.centered-title {
    text-align: center;
    margin-bottom: 1.5rem;
}
.container {
    max-width: 900px;
    margin: 0 auto;
}
.output-container img {
    max-width: 100%;
    height: auto;
    display: block;
    margin: 0 auto;
}
/* Loading animation for output area only */
.loading-container {
    text-align: center;
    padding: 30px;
}
.loading-spinner {
    display: inline-block;
    width: 50px;
    height: 50px;
    border: 5px solid rgba(0, 0, 0, 0.1);
    border-radius: 50%;
    border-top-color: #3498db;
    animation: spin 1s ease-in-out infinite;
}
@keyframes spin {
    to { transform: rotate(360deg); }
}
.loading-text {
    margin-top: 15px;
    font-weight: bold;
    color: #555;
}
"""
|
|
|
|
|
# Gradio UI: two image inputs and a submit button on the left, result HTML and a
# status line on the right.
with gr.Blocks(title="Easel x GifSwap", css=custom_css) as demo:
    with gr.Column(elem_classes="container"):
        gr.HTML(
            """
            <div class="centered-title">
                <h1>Easel x GifSwap</h1>
                <p>Upload a GIF and a face image to swap faces in the GIF.</p>
            </div>
            """
        )

        with gr.Row():
            with gr.Column():
                # NOTE(review): with image_mode="RGB", gr.Image may re-encode an upload
                # whose mode differs (GIF frames are usually palette mode) before handing
                # back the path, which could drop animation frames. Confirm against the
                # installed Gradio version; a gr.File input may be needed for the GIF.
                gif_input = gr.Image(label="Upload GIF", type="filepath",
                                     image_mode="RGB", sources=["upload"],
                                     elem_id="gif_input")
                face_input = gr.Image(label="Upload Face Image", type="filepath",
                                      image_mode="RGB", sources=["upload"],
                                      elem_id="face_input")
                # Starts disabled; check_inputs enables it once both images are present.
                submit_btn = gr.Button("Swap Face", variant="primary", interactive=False)

            with gr.Column(elem_classes="output-container"):
                output_html = gr.HTML(label="Output GIF")
                output_text = gr.Textbox(label="Status")

    def process_and_display(gif_file, face_file):
        """Run the face swap and return (result HTML, status text, button update).

        faceswap_process already reports missing inputs with the user-facing
        message, so no separate check is needed here. The button is re-enabled in
        every outcome.
        """
        base64_data, message = faceswap_process(gif_file, face_file)
        if base64_data:
            # An animated GIF plays and loops by itself in a plain <img>;
            # "autoplay"/"loop" are not <img> attributes, so none are set.
            html = f"""
            <div style="text-align: center;">
                <img src="{base64_data}" style="max-width:100%; height:auto;" alt="Faceswap Result">
            </div>
            """
        else:
            html = ""
        return html, message, gr.update(interactive=True)

    def on_submit_click():
        """Show the spinner placeholder and disable the button while the swap runs."""
        loading_html = """
        <div class="loading-container">
            <div class="loading-spinner"></div>
            <div class="loading-text">Generating face-swapped GIF...</div>
        </div>
        """
        return loading_html, "Processing...", gr.update(interactive=False)

    # Two chained steps: the loading state is rendered first, then the slow swap runs.
    submit_btn.click(
        fn=on_submit_click,
        inputs=None,
        outputs=[output_html, output_text, submit_btn],
    ).then(
        fn=process_and_display,
        inputs=[gif_input, face_input],
        outputs=[output_html, output_text, submit_btn],
    )

    def check_inputs(gif, face):
        """Enable the submit button only when both images have been provided."""
        return gr.update(interactive=gif is not None and face is not None)

    # Re-evaluate the button state whenever either input changes.
    for image_input in (gif_input, face_input):
        image_input.change(
            fn=check_inputs,
            inputs=[gif_input, face_input],
            outputs=[submit_btn],
        )
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Listen on all network interfaces (not only localhost), port 7860.
    launch_kwargs = dict(server_name="0.0.0.0", server_port=7860)
    demo.launch(**launch_kwargs)