Rx Codex AI committed on
Commit
33d3090
·
verified ·
1 Parent(s): d2b9f1c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +84 -82
app.py CHANGED
@@ -1,83 +1,85 @@
1
- # app.py
2
- from fastapi import FastAPI, HTTPException
3
- from pydantic import BaseModel
4
- import torch
5
- from diffusers import AutoPipelineForText2Image
6
- from contextlib import asynccontextmanager
7
- import io
8
- import base64
9
- import os
10
-
11
- # --- Pydantic Models ---
12
- class ImageRequest(BaseModel):
13
- prompt: str
14
- negative_prompt: str = ""
15
- steps: int = 25
16
-
17
- class ImageResponse(BaseModel):
18
- image_base64: str
19
-
20
- # --- App State and Lifespan ---
21
- app_state = {}
22
-
23
- @asynccontextmanager
24
- async def lifespan(app: FastAPI):
25
- # Load the model on startup
26
- hf_token = os.getenv("HF_TOKEN")
27
- if not hf_token:
28
- raise RuntimeError("HF_TOKEN environment variable not set!")
29
-
30
- model_id = "rxmha125/sdxl-base-1.0-private" # <-- YOUR PRIVATE MODEL ID
31
- print(f"Loading model: {model_id}")
32
-
33
- pipe = AutoPipelineForText2Image.from_pretrained(
34
- model_id,
35
- torch_dtype=torch.float16,
36
- variant="fp16",
37
- use_safetensors=True,
38
- token=hf_token
39
- ).to("cuda")
40
-
41
- # Optimization for speed and memory
42
- pipe.enable_model_cpu_offload()
43
-
44
- app_state["pipe"] = pipe
45
- print("Model loaded successfully.")
46
- yield
47
- # Clean up on shutdown
48
- app_state.clear()
49
- print("Resources cleaned up.")
50
-
51
- # --- FastAPI App ---
52
- app = FastAPI(lifespan=lifespan)
53
-
54
- @app.get("/")
55
- def root():
56
- return {"status": "Text-to-Image API is running"}
57
-
58
- @app.post("/generate-image", response_model=ImageResponse)
59
- def generate_image(request: ImageRequest):
60
- if "pipe" not in app_state:
61
- raise HTTPException(status_code=503, detail="Model is not ready.")
62
-
63
- pipe = app_state["pipe"]
64
-
65
- print(f"Generating image for prompt: '{request.prompt}'")
66
- try:
67
- # Generate the image
68
- image = pipe(
69
- prompt=request.prompt,
70
- negative_prompt=request.negative_prompt,
71
- num_inference_steps=request.steps
72
- ).images[0]
73
-
74
- # Convert image to Base64
75
- buffer = io.BytesIO()
76
- image.save(buffer, format="PNG")
77
- img_str = base64.b64encode(buffer.getvalue()).decode("utf-8")
78
-
79
- return ImageResponse(image_base64=img_str)
80
-
81
- except Exception as e:
82
- print(f"Error during image generation: {e}")
 
 
83
  raise HTTPException(status_code=500, detail=str(e))
 
1
+ # app.py
2
+ from fastapi import FastAPI, HTTPException
3
+ from pydantic import BaseModel
4
+ import torch
5
+ from diffusers import AutoPipelineForText2Image
6
+ from contextlib import asynccontextmanager
7
+ import io
8
+ import base64
9
+ import os
10
+
11
+ # --- Pydantic Models ---
12
+ class ImageRequest(BaseModel):
13
+ prompt: str
14
+ negative_prompt: str = ""
15
+ steps: int = 25
16
+
17
+ class ImageResponse(BaseModel):
18
+ image_base64: str
19
+
20
+ # --- App State and Lifespan ---
21
+ app_state = {}
22
+
23
+ @asynccontextmanager
24
+ async def lifespan(app: FastAPI):
25
+ # Load the model on startup
26
+ hf_token = os.getenv("HF_TOKEN")
27
+ if not hf_token:
28
+ raise RuntimeError("HF_TOKEN environment variable not set!")
29
+
30
+ model_id = "rxmha125/sdxl-base-1.0-private" # Your private model ID
31
+ print(f"Loading model: {model_id}")
32
+
33
+ # --- *** THIS IS THE CORRECTED PART *** ---
34
+ # We removed variant="fp16" and use_safetensors=True
35
+ # to load the available .bin files instead of the missing .safetensors.
36
+ pipe = AutoPipelineForText2Image.from_pretrained(
37
+ model_id,
38
+ torch_dtype=torch.float16, # Keep for memory optimization
39
+ token=hf_token
40
+ ).to("cuda")
41
+ # --- *********************************** ---
42
+
43
+ # Optimization for speed and memory
44
+ pipe.enable_model_cpu_offload()
45
+
46
+ app_state["pipe"] = pipe
47
+ print("Model loaded successfully.")
48
+ yield
49
+ # Clean up on shutdown
50
+ app_state.clear()
51
+ print("Resources cleaned up.")
52
+
53
+ # --- FastAPI App ---
54
+ app = FastAPI(lifespan=lifespan)
55
+
56
+ @app.get("/")
57
+ def root():
58
+ return {"status": "Text-to-Image API is running"}
59
+
60
+ @app.post("/generate-image", response_model=ImageResponse)
61
+ def generate_image(request: ImageRequest):
62
+ if "pipe" not in app_state:
63
+ raise HTTPException(status_code=503, detail="Model is not ready.")
64
+
65
+ pipe = app_state["pipe"]
66
+
67
+ print(f"Generating image for prompt: '{request.prompt}'")
68
+ try:
69
+ # Generate the image
70
+ image = pipe(
71
+ prompt=request.prompt,
72
+ negative_prompt=request.negative_prompt,
73
+ num_inference_steps=request.steps
74
+ ).images[0]
75
+
76
+ # Convert image to Base64
77
+ buffer = io.BytesIO()
78
+ image.save(buffer, format="PNG")
79
+ img_str = base64.b64encode(buffer.getvalue()).decode("utf-8")
80
+
81
+ return ImageResponse(image_base64=img_str)
82
+
83
+ except Exception as e:
84
+ print(f"Error during image generation: {e}")
85
  raise HTTPException(status_code=500, detail=str(e))