AkashKumarave committed on
Commit
24c793a
·
verified ·
1 Parent(s): 66ae8cd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +61 -45
app.py CHANGED
@@ -1,55 +1,71 @@
1
  import torch
2
- import cv2
3
  import numpy as np
4
- import gradio as gr
5
- from diffusers import StableDiffusionPipeline, DDIMScheduler, IPAdapterFaceIDPipeline
6
- from insightface.app import FaceAnalysis
7
- from huggingface_hub import hf_hub_download
 
 
 
8
 
9
- device = "cpu"
10
- dtype = torch.float32
11
 
12
- # Initialize Face Detection
13
- face_app = FaceAnalysis(name="buffalo_l", providers=["CPUExecutionProvider"])
14
- face_app.prepare(ctx_id=0, det_size=(320, 320))
 
15
 
16
- # Download IP-Adapter FaceID Weights
17
- ip_adapter_path = hf_hub_download(
18
- repo_id="h94/IP-Adapter",
19
- filename="ip-adapter_sd15_faceid.bin",
20
- subfolder="models"
21
  )
 
 
 
22
 
23
- # Load Pipeline
24
- pipe = IPAdapterFaceIDPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=dtype)
25
- pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
26
- pipe.load_ip_adapter(ip_adapter_path)
27
- pipe.to(device)
28
-
29
- def generate(upload_image, prompt):
30
- if upload_image is None:
31
- return "No Image Uploaded", None
32
 
33
- img = cv2.cvtColor(np.array(upload_image), cv2.COLOR_RGB2BGR)
34
- faces = face_app.get(img)
35
 
36
- if len(faces) == 0:
37
- return "No Face Detected", None
38
-
39
- face_emb = torch.tensor(faces[0].normed_embedding).unsqueeze(0).to(device)
40
- result = pipe(prompt=prompt, face_embeds=face_emb, num_inference_steps=30).images[0]
41
- return "Image Generated Successfully", result
42
-
43
- demo = gr.Interface(
44
- generate,
45
- inputs=[
46
- gr.Image(type="pil", label="Upload Image"),
47
- gr.Textbox(label="Prompt")
48
- ],
49
- outputs=[
50
- gr.Textbox(label="Status"),
51
- gr.Image(label="Generated Image")
52
- ]
53
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
- demo.launch()
 
 
import torch
import numpy as np
from flask import Flask, request, jsonify
from diffusers import DiffusionPipeline
from PIL import Image
import base64
import io
import gc
import os

app = Flask(__name__)

# CPU-only deployment: use every available core for both intra-op and
# inter-op parallelism (fall back to 4 if os.cpu_count() is unknown).
device = "cpu"
torch.set_num_threads(max(1, os.cpu_count() or 4))
torch.set_num_interop_threads(max(1, os.cpu_count() or 4))

# Load SDXL-Turbo once at startup.
# NOTE: float32, not float16 — many CPU kernels have no half-precision
# implementation, so fp16 on CPU either raises at runtime or falls back
# to very slow emulation.
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/sdxl-turbo",
    use_safetensors=True,
    torch_dtype=torch.float32,
)
pipe = pipe.to(device)
# Removed from the previous revision:
#   pipe.eval()                         -> DiffusionPipeline has no .eval();
#                                          sub-modules load in eval mode already.
#   pipe.unet.enable_model_cpu_offload() -> offload is a *pipeline* method
#                                           (requires accelerate + an
#                                           accelerator device); invalid on the
#                                           UNet and pointless on CPU-only.
def infer(prompt, steps=1, seed=0):
    """Generate one 512x512 image for *prompt* and return (image, seed_used).

    Args:
        prompt: text prompt; must be non-empty and reasonably short.
        steps: diffusion steps; defaults to 1 (SDXL-Turbo is distilled for
            single-step sampling).
        seed: RNG seed; 0 means "draw a random seed".

    Returns:
        (PIL.Image, int seed actually used) on success, or
        (error message str, 0) on a bad prompt — callers must check the
        type of the first element.
    """
    # Cheap length guard. NOTE(review): this counts whitespace-separated
    # words, not CLIP tokens — a tokenizer-based check would be exact;
    # kept as an approximation of the 77-token limit.
    if not prompt or len(prompt.split()) > 77:
        return "Prompt missing or exceeds 77 tokens!", 0

    # Resolve the seed *before* building the generator so the value we
    # return is the one we actually used. (Previously, when seed == 0 a
    # random seed was drawn inline but the function still returned 0, so
    # the reported seed could never reproduce the image.)
    if seed == 0:
        seed = int(np.random.randint(1, 2**32 - 1))
    generator = torch.Generator(device=device).manual_seed(seed)

    with torch.no_grad():  # inference only; skip autograd bookkeeping
        image = pipe(
            prompt=prompt,
            num_inference_steps=steps,
            guidance_scale=0.0,  # turbo models are trained for CFG-free sampling
            height=512,
            width=512,
            generator=generator,
            output_type="pil",
            num_images_per_prompt=1,
        ).images[0]

    gc.collect()  # release intermediate tensors promptly on a small CPU box
    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # only meaningful when a CUDA device exists
    return image, seed
@app.route('/generate', methods=['POST'])
def generate():
    """POST /generate — form fields: prompt (required), steps, seed.

    Returns JSON {'image': <base64 PNG>, 'seed': <int>} on success, or
    {'error': ..., 'seed': ...} with HTTP 400 on bad input.
    """
    prompt = request.form.get('prompt')
    # Malformed numeric fields previously bubbled up as ValueError and
    # surfaced as an opaque HTTP 500; report them as a client error.
    try:
        steps = int(request.form.get('steps', 1))  # default: 1 step (turbo)
        seed = int(request.form.get('seed', 0))    # default: 0 = random seed
    except (TypeError, ValueError):
        return jsonify({'error': 'steps and seed must be integers', 'seed': 0}), 400

    result, seed_used = infer(prompt, steps, seed)

    if isinstance(result, str):  # infer() signals errors with a message string
        return jsonify({'error': result, 'seed': seed_used}), 400

    # Serialize the PIL image as base64-encoded PNG for the JSON payload.
    buffered = io.BytesIO()
    result.save(buffered, format="PNG")
    img_str = base64.b64encode(buffered.getvalue()).decode('utf-8')

    return jsonify({'image': img_str, 'seed': seed_used})
if __name__ == '__main__':
    # Bind on all interfaces so the container/Space is reachable externally.
    # Allow the port to be overridden via $PORT (common in hosted
    # environments); falls back to the original hard-coded 8000.
    app.run(host='0.0.0.0', port=int(os.environ.get('PORT', 8000)))