AkashKumarave committed on
Commit
9257314
·
verified ·
1 Parent(s): ab51ed0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -29
app.py CHANGED
@@ -1,32 +1,68 @@
1
- import gradio as gr
2
  import torch
3
- import modin.pandas as pd
4
  import numpy as np
5
- from diffusers import DiffusionPipeline
6
-
7
- device = "cuda" if torch.cuda.is_available() else "cpu"
8
-
9
- if torch.cuda.is_available():
10
- torch.cuda.max_memory_allocated(device=device)
11
- torch.cuda.empty_cache()
12
- pipe = DiffusionPipeline.from_pretrained("stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16", use_safetensors=True)
13
- pipe.enable_xformers_memory_efficient_attention()
14
- pipe = pipe.to(device)
15
- torch.cuda.empty_cache()
16
- else:
17
- pipe = DiffusionPipeline.from_pretrained("stabilityai/sdxl-turbo", use_safetensors=True)
18
- pipe = pipe.to(device)
19
-
20
- def genie (prompt, steps, seed):
21
- generator = np.random.seed(0) if seed == 0 else torch.manual_seed(seed)
22
- int_image = pipe(prompt=prompt, generator=generator, num_inference_steps=steps, guidance_scale=0.0).images[0]
23
- return int_image
 
 
 
 
 
 
 
 
 
 
24
 
25
- gr.Interface(fn=genie, inputs=[gr.Textbox(label='What you want the AI to generate. 77 Token Limit.'),
26
- gr.Slider(1, maximum=5, value=2, step=1, label='Number of Iterations'),
27
- gr.Slider(minimum=0, step=1, maximum=999999999999999999, randomize=True),
28
- ],
29
- outputs='image',
30
- title="Stable Diffusion Turbo CPU or GPU",
31
- description="SDXL Turbo CPU or GPU. Currently running on CPU. <br><br><b>WARNING: This model is capable of producing NSFW (Softcore) images.</b>",
32
- article = "If You Enjoyed this Demo and would like to Donate, you can send to any of these Wallets. <br>BTC: bc1qzdm9j73mj8ucwwtsjx4x4ylyfvr6kp7svzjn84 <br>3LWRoKYx6bCLnUrKEdnPo3FCSPQUSFDjFP <br>DOGE: DK6LRc4gfefdCTRk9xPD239N31jh9GjKez <br>SHIB (BEP20): 0xbE8f2f3B71DFEB84E5F7E3aae1909d60658aB891 <br>PayPal: https://www.paypal.me/ManjushriBodhisattva <br>ETH: 0xbE8f2f3B71DFEB84E5F7E3aae1909d60658aB891 <br>Code Monkey: <a href=\"https://huggingface.co/Manjushri\">Manjushri</a>").launch(debug=True, max_threads=80)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import torch
 
2
  import numpy as np
3
+ from flask import Flask, request, jsonify
4
+ from diffusers import DiffusionPipeline, DDIMScheduler
5
+ from PIL import Image
6
+ import base64
7
+ import io
8
+ import gc
9
+ import os
10
+
11
+ app = Flask(__name__)
12
+
13
+ # Device set to CPU and optimize threading
14
+ device = "cpu"
15
+ torch.set_num_threads(max(1, os.cpu_count() or 4))
16
+ torch.set_num_interop_threads(max(1, os.cpu_count() or 4))
17
+
18
+ # Load model with accelerate and disable safety checker
19
+ pipe = DiffusionPipeline.from_pretrained(
20
+ "stabilityai/sdxl-turbo",
21
+ use_safetensors=True,
22
+ low_cpu_mem_usage=True
23
+ )
24
+ pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) # Faster scheduler
25
+ pipe.safety_checker = None # Disable safety checker
26
+ pipe = pipe.to(device)
27
+ pipe.unet.enable_model_cpu_offload()
28
+
29
+ def infer(prompt, steps=1, seed=0):
30
+ if not prompt or len(prompt.split()) > 77:
31
+ return "Prompt missing or exceeds 77 tokens!", 0
32
 
33
+ generator = torch.Generator(device=device).manual_seed(seed) if seed != 0 else torch.Generator(device=device).manual_seed(np.random.randint(0, 2**32 - 1))
34
+
35
+ with torch.no_grad():
36
+ image = pipe(
37
+ prompt=prompt,
38
+ num_inference_steps=steps,
39
+ guidance_scale=0.0,
40
+ height=512,
41
+ width=512,
42
+ generator=generator,
43
+ output_type="pil",
44
+ num_images_per_prompt=1
45
+ ).images[0]
46
+
47
+ gc.collect()
48
+ return image, seed
49
+
50
+ @app.route('/generate', methods=['POST'])
51
+ def generate():
52
+ prompt = request.form.get('prompt')
53
+ steps = int(request.form.get('steps', 1))
54
+ seed = int(request.form.get('seed', 0))
55
+
56
+ result, seed_used = infer(prompt, steps, seed)
57
+
58
+ if isinstance(result, str):
59
+ return jsonify({'error': result, 'seed': seed_used}), 400
60
+
61
+ buffered = io.BytesIO()
62
+ result.save(buffered, format="PNG")
63
+ img_str = base64.b64encode(buffered.getvalue()).decode('utf-8')
64
+
65
+ return jsonify({'image': img_str, 'seed': seed_used})
66
+
67
+ if __name__ == '__main__':
68
+ app.run(host='0.0.0.0', port=8000)