speedartificialintelligence1122 committed · Commit 2844088 · verified · 1 Parent(s): 814e862

Update app.py

Files changed (1): app.py (+16 -22)
app.py CHANGED
@@ -1,40 +1,34 @@
+import os
+os.environ["HF_HOME"] = "/tmp/hf_cache"
+os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
+os.environ["HF_DATASETS_CACHE"] = "/tmp/hf_cache"
+
 from fastapi import FastAPI
 from pydantic import BaseModel
 from diffusers import StableDiffusionPipeline
 import torch
 
-import uuid
-import base64
 from io import BytesIO
 from PIL import Image
+import base64
 
 app = FastAPI()
 
-# Load the model (NO fp16 issues now)
-pipe = StableDiffusionPipeline.from_pretrained(
-    "runwayml/stable-diffusion-v1-5"
-)
-pipe = pipe.to("cpu")  # Or use .to("cuda") if you're on GPU
+# Load the model safely
+pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
+pipe = pipe.to("cpu")
 
-# For receiving prompts from the frontend
 class Prompt(BaseModel):
     text: str
 
 @app.get("/")
-def greet_json():
-    return {"message": "Text to Image generation ready!"}
+def greet():
+    return {"message": "Model ready"}
 
 @app.post("/generate")
-def generate_image(prompt: Prompt):
+def generate(prompt: Prompt):
     image = pipe(prompt.text).images[0]
-
-    # Convert image to base64 to send over HTTP
-    buffered = BytesIO()
-    image.save(buffered, format="PNG")
-    img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
-
-    return {
-        "image_base64": img_str,
-        "status": "success",
-        "prompt": prompt.text
-    }
+    buffer = BytesIO()
+    image.save(buffer, format="PNG")
+    img_str = base64.b64encode(buffer.getvalue()).decode("utf-8")
+    return {"image_base64": img_str, "prompt": prompt.text}