Adedoyinjames commited on
Commit
b8fd2a3
·
verified ·
1 Parent(s): 45e6e3a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -8
app.py CHANGED
@@ -1,24 +1,34 @@
1
  import io
 
2
  import torch
3
  from fastapi import FastAPI
4
  from pydantic import BaseModel
5
  from diffusers import FluxPipeline
6
  import gradio as gr
7
  from PIL import Image
 
8
 
9
  # Initialize FastAPI
10
  app = FastAPI()
11
 
 
 
 
 
 
 
 
12
  # Load Model Optimized for CPU
13
- # NOTE: "black-forest-labs/FLUX.1-schnell" is huge.
14
- # For HF Free Tier, consider a quantized version like "sayakpaul/flux.1-schnell-8bit"
15
  model_id = "black-forest-labs/FLUX.1-schnell"
16
 
 
17
  pipe = FluxPipeline.from_pretrained(
18
  model_id,
19
- torch_dtype=torch.bfloat16
 
20
  )
21
- # Vital for Free Tier: Moves parts of the model to CPU/Disk as needed
 
22
  pipe.enable_model_cpu_offload()
23
 
24
  class PromptRequest(BaseModel):
@@ -26,15 +36,16 @@ class PromptRequest(BaseModel):
26
 
27
  @app.post("/generate")
28
  def generate_api(request: PromptRequest):
 
29
  image = pipe(
30
  request.prompt,
31
- num_inference_steps=4, # Schnell is optimized for 4 steps
32
  guidance_scale=0.0
33
  ).images[0]
34
 
35
  img_byte_arr = io.BytesIO()
36
  image.save(img_byte_arr, format='PNG')
37
- return {"image": img_byte_arr.getvalue().hex()} # Or return as StreamingResponse
38
 
39
  def gradio_generate(prompt):
40
  return pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0]
@@ -43,14 +54,15 @@ def gradio_generate(prompt):
43
  with gr.Blocks() as demo:
44
  gr.Markdown("# FLUX.1 [schnell] CPU Explorer")
45
  with gr.Row():
46
- input_text = gr.Textbox(label="Enter Prompt")
47
  output_img = gr.Image(label="Generated Image")
48
  btn = gr.Button("Generate")
49
  btn.click(fn=gradio_generate, inputs=input_text, outputs=output_img)
50
 
51
- # Mount FastAPI into Gradio for Hugging Face compatibility
52
  app = gr.mount_gradio_app(app, demo, path="/")
53
 
54
  if __name__ == "__main__":
55
  import uvicorn
 
56
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
1
  import io
2
+ import os
3
  import torch
4
  from fastapi import FastAPI
5
  from pydantic import BaseModel
6
  from diffusers import FluxPipeline
7
  import gradio as gr
8
  from PIL import Image
9
+ from huggingface_hub import login
10
 
11
  # Initialize FastAPI
12
  app = FastAPI()
13
 
14
+ # 1. Login using the Secret stored in the Space settings
15
+ hf_token = os.getenv("HF_TOKEN")
16
+ if hf_token:
17
+ login(token=hf_token)
18
+ else:
19
+ print("Warning: HF_TOKEN not found in Secrets. Gated models may fail.")
20
+
21
  # Load Model Optimized for CPU
 
 
22
  model_id = "black-forest-labs/FLUX.1-schnell"
23
 
24
+ # Using float32 or bfloat16 for CPU compatibility
25
  pipe = FluxPipeline.from_pretrained(
26
  model_id,
27
+ torch_dtype=torch.bfloat16,
28
+ use_auth_token=True
29
  )
30
+
31
+ # Enable CPU offloading to stay within the ~16GB RAM limit
32
  pipe.enable_model_cpu_offload()
33
 
34
  class PromptRequest(BaseModel):
 
36
 
37
  @app.post("/generate")
38
  def generate_api(request: PromptRequest):
39
+ # num_inference_steps=4 is the sweet spot for Schnell
40
  image = pipe(
41
  request.prompt,
42
+ num_inference_steps=4,
43
  guidance_scale=0.0
44
  ).images[0]
45
 
46
  img_byte_arr = io.BytesIO()
47
  image.save(img_byte_arr, format='PNG')
48
+ return {"image": img_byte_arr.getvalue().hex()}
49
 
50
  def gradio_generate(prompt):
51
  return pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0]
 
54
  with gr.Blocks() as demo:
55
  gr.Markdown("# FLUX.1 [schnell] CPU Explorer")
56
  with gr.Row():
57
+ input_text = gr.Textbox(label="Enter Prompt", placeholder="A futuristic city in the style of cyberpunk...")
58
  output_img = gr.Image(label="Generated Image")
59
  btn = gr.Button("Generate")
60
  btn.click(fn=gradio_generate, inputs=input_text, outputs=output_img)
61
 
62
+ # Mount FastAPI into Gradio
63
  app = gr.mount_gradio_app(app, demo, path="/")
64
 
65
  if __name__ == "__main__":
66
  import uvicorn
67
+ # Port 7860 is required for Hugging Face Spaces
68
  uvicorn.run(app, host="0.0.0.0", port=7860)