dragonranvir committed on
Commit
eeb36ab
·
verified ·
1 Parent(s): 5c09835

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -28
app.py CHANGED
@@ -1,35 +1,47 @@
1
import os

import torch
from diffusers import StableDiffusionPipeline

# The gated-model token is injected by the Space at runtime (Settings → Secrets).
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN is None:
    raise RuntimeError("HF_TOKEN not found. Add it in HF Spaces → Secrets.")

model_id = "prompthero/openjourney"

# Build the pipeline for CPU-only execution: float32 weights (half precision
# is unsupported on CPU) and no safety checker to keep the RAM footprint down.
pipe = StableDiffusionPipeline.from_pretrained(
    model_id,
    token=HF_TOKEN,
    torch_dtype=torch.float32,  # REQUIRED for CPU
    safety_checker=None,  # reduce RAM
)
pipe.to("cpu")

prompt = "retro serie of different cars with different colors and shapes, mdjrny-v4 style"

print("Starting image generation (CPU-only, this is slow)...")

# Keep resolution and step count small — CPU inference time grows quickly
# with either, so these limits are deliberate.
result = pipe(
    prompt,
    height=256,  # DO NOT increase
    width=256,
    num_inference_steps=10,  # DO NOT increase
    guidance_scale=7.0,
)
image = result.images[0]

image.save("output.png")
print("Saved output.png")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Standard library
import base64
import io

# Third-party
import torch
from diffusers import StableDiffusionPipeline
from fastapi import FastAPI
from PIL import Image  # NOTE(review): not referenced below — the pipeline already returns PIL images; confirm before removing
from pydantic import BaseModel

# FastAPI application serving the image-generation endpoint.
app = FastAPI()

# OpenJourney v4 checkpoint on the Hugging Face Hub.
MODEL_ID = "prompthero/openjourney-v4"
 
11
# Load the diffusion pipeline once at import time so every request reuses it.
# Half precision only makes sense on GPU; the safety checker is dropped to
# save memory (and requires_safety_checker=False silences the warning).
cuda_available = torch.cuda.is_available()

pipe = StableDiffusionPipeline.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16 if cuda_available else torch.float32,
    safety_checker=None,
    requires_safety_checker=False,
)

device = "cuda" if cuda_available else "cpu"
pipe = pipe.to(device)

# VERY IMPORTANT for speed & stability: slice attention and VAE work into
# smaller chunks so peak memory stays low on modest hardware.
pipe.enable_attention_slicing()
pipe.enable_vae_slicing()
25
class GenerateRequest(BaseModel):
    """JSON body accepted by POST /generate.

    Only `prompt` is required; the remaining fields fall back to a
    20-step 512x512 render when omitted.
    """

    # Text prompt passed straight to the diffusion pipeline.
    prompt: str
    # Number of denoising steps (num_inference_steps).
    steps: int = 20
    # Output image dimensions in pixels.
    width: int = 512
    height: int = 512
31
@app.post("/generate")
def generate(req: GenerateRequest):
    """Render `req.prompt` with the module-level pipeline and return the image.

    Returns a JSON object {"status": "ok", "image_base64": "<base64 PNG>"}.
    """
    # inference_mode disables autograd bookkeeping for the forward pass —
    # lower memory and latency, identical output for generation-only use.
    with torch.inference_mode():
        image = pipe(
            req.prompt,
            num_inference_steps=req.steps,
            width=req.width,
            height=req.height
        ).images[0]

    # Serialize the PIL image to PNG in memory, then base64 for JSON transport.
    buf = io.BytesIO()
    image.save(buf, format="PNG")
    img_base64 = base64.b64encode(buf.getvalue()).decode()

    return {
        "status": "ok",
        "image_base64": img_base64
    }