AkashKumarave commited on
Commit
cea630a
·
verified ·
1 Parent(s): 378da96

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +74 -56
app.py CHANGED
@@ -1,67 +1,85 @@
1
  import gradio as gr
2
  import numpy as np
3
- import cv2
4
- import onnxruntime
5
- from insightface.app import FaceAnalysis
6
- from pathlib import Path
7
 
8
- # Initialize Face Analysis
9
- face_analyzer = FaceAnalysis(name="buffalo_l")
10
- face_analyzer.prepare(ctx_id=0, det_size=(640, 640))
11
 
12
- # Load Face Swapper Model
13
- MODEL_PATH = Path("models/inswapper_128.onnx")
14
- if not MODEL_PATH.exists():
15
- raise FileNotFoundError("Model file inswapper_128.onnx not found.")
 
 
16
 
17
- session = onnxruntime.InferenceSession(str(MODEL_PATH))
 
18
 
19
- def swap_faces(source_img, target_img):
20
- """Perform face swapping using the ONNX model."""
21
- try:
22
- # Convert images to correct format
23
- source_img = cv2.cvtColor(np.array(source_img), cv2.COLOR_RGB2BGR)
24
- target_img = cv2.cvtColor(np.array(target_img), cv2.COLOR_RGB2BGR)
 
 
 
 
 
 
 
25
 
26
- # Detect faces
27
- source_faces = face_analyzer.get(source_img)
28
- target_faces = face_analyzer.get(target_img)
 
 
29
 
30
- if not source_faces or not target_faces:
31
- return "No faces detected in one or both images."
32
- if len(source_faces) > 1 or len(target_faces) > 1:
33
- return "Multiple faces detected; only one face per image is supported."
 
 
34
 
35
- source_face = source_faces[0]
36
- target_face = target_faces[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
- # Prepare input data for ONNX model
39
- input_data = {
40
- "target_image": target_img,
41
- "target_face": target_face.embedding,
42
- "source_face": source_face.embedding
43
- }
44
-
45
- # Run the ONNX model
46
- result = session.run(None, input_data)[0]
47
-
48
- # Convert result to image format
49
- result_img = np.clip(result * 255, 0, 255).astype(np.uint8)
50
- result_img = cv2.cvtColor(result_img, cv2.COLOR_BGR2RGB)
51
-
52
- return result_img
53
- except Exception as e:
54
- return f"Face swap failed: {e}"
55
-
56
- # Gradio UI
57
- with gr.Blocks() as demo:
58
- gr.Markdown("# Face Swap Tool 🚀")
59
- with gr.Row():
60
- input_source = gr.Image(label="Source Face", type="pil")
61
- input_target = gr.Image(label="Target Image", type="pil")
62
- btn_swap = gr.Button("Swap Faces")
63
- output_image = gr.Image(label="Swapped Face")
64
- btn_swap.click(swap_faces, inputs=[input_source, input_target], outputs=output_image)
65
-
66
- # Launch Gradio App
67
  demo.launch()
 
1
  import gradio as gr
2
  import numpy as np
3
+ import random
4
+ import torch
5
+ from diffusers import DiffusionPipeline
 
6
 
7
+ # Ensure the model runs on CPU
8
+ device = "cpu"
9
+ dtype = torch.float32 # Use float32 for CPU compatibility
10
 
11
+ # Load model from Hugging Face (it will cache locally in Hugging Face Spaces)
12
+ pipe = DiffusionPipeline.from_pretrained(
13
+ "black-forest-labs/FLUX.1-schnell",
14
+ torch_dtype=dtype,
15
+ low_cpu_mem_usage=True
16
+ ).to(device)
17
 
18
+ MAX_SEED = np.iinfo(np.int32).max
19
+ MAX_IMAGE_SIZE = 1024
20
 
21
+ def infer(prompt, seed=42, randomize_seed=False, width=512, height=512, num_inference_steps=4):
22
+ if randomize_seed:
23
+ seed = random.randint(0, MAX_SEED)
24
+ generator = torch.Generator(device=device).manual_seed(seed)
25
+ image = pipe(
26
+ prompt=prompt,
27
+ width=width,
28
+ height=height,
29
+ num_inference_steps=num_inference_steps,
30
+ generator=generator,
31
+ guidance_scale=0.0
32
+ ).images[0]
33
+ return image, seed
34
 
35
+ examples = [
36
+ "a tiny astronaut hatching from an egg on the moon",
37
+ "a cat holding a sign that says hello world",
38
+ "an anime illustration of a wiener schnitzel",
39
+ ]
40
 
41
+ css="""
42
+ #col-container {
43
+ margin: 0 auto;
44
+ max-width: 520px;
45
+ }
46
+ """
47
 
48
+ with gr.Blocks(css=css) as demo:
49
+ with gr.Column(elem_id="col-container"):
50
+ gr.Markdown("""# FLUX.1 [schnell]
51
+ 12B param rectified flow transformer distilled from FLUX.1 [pro]
52
+ """)
53
+
54
+ with gr.Row():
55
+ prompt = gr.Text(label="Prompt", show_label=False, max_lines=1, placeholder="Enter your prompt", container=False)
56
+ run_button = gr.Button("Run", scale=0)
57
+
58
+ result = gr.Image(label="Result", show_label=False)
59
+
60
+ with gr.Accordion("Advanced Settings", open=False):
61
+ seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
62
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
63
+
64
+ with gr.Row():
65
+ width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=512)
66
+ height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=512)
67
+
68
+ num_inference_steps = gr.Slider(label="Number of inference steps", minimum=1, maximum=50, step=1, value=4)
69
+
70
+ gr.Examples(
71
+ examples=examples,
72
+ fn=infer,
73
+ inputs=[prompt],
74
+ outputs=[result, seed],
75
+ cache_examples="lazy"
76
+ )
77
+
78
+ gr.on(
79
+ triggers=[run_button.click, prompt.submit],
80
+ fn=infer,
81
+ inputs=[prompt, seed, randomize_seed, width, height, num_inference_steps],
82
+ outputs=[result, seed]
83
+ )
84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  demo.launch()