AkashKumarave committed on
Commit
0d3d566
·
verified ·
1 Parent(s): 1d64ede

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -15
app.py CHANGED
@@ -1,4 +1,6 @@
1
- from fastapi import FastAPI, File, UploadFile, HTTPException, Form # Add Form import
 
 
2
  from pydantic import BaseModel
3
  import logging
4
  import torch
@@ -16,8 +18,18 @@ logging.basicConfig(level=logging.INFO)
16
  logger = logging.getLogger(__name__)
17
 
18
  app = FastAPI()
 
 
 
 
 
 
 
 
 
 
19
  pipe = None
20
- device = "cpu"
21
 
22
  def initialize_pipeline():
23
  global pipe
@@ -28,11 +40,11 @@ def initialize_pipeline():
28
  pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
29
  model_id,
30
  scheduler=scheduler,
31
- torch_dtype=torch.float32,
32
  low_cpu_mem_usage=True
33
  )
34
  pipe = pipe.to(device)
35
- logger.info("Stable Diffusion pipeline initialized successfully.")
36
  except Exception as e:
37
  logger.error(f"Failed to initialize pipeline: {str(e)}", exc_info=True)
38
  raise
@@ -98,7 +110,7 @@ def overlay_face(generated_img, face_img, face_coords):
98
  def image_to_base64(image: Image.Image) -> str:
99
  try:
100
  buffered = BytesIO()
101
- image.save(buffered, format="PNG")
102
  img_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
103
  logger.info("Image converted to base64 successfully.")
104
  return img_base64
@@ -111,21 +123,33 @@ async def predict(
111
  prompt: str = Form(...),
112
  image: UploadFile = File(...),
113
  negative_prompt: str = Form("low quality, blurry"),
114
- seed: int = Form(66),
115
- guidance_scale: float = Form(7.5),
116
- num_inference_steps: int = Form(10),
117
- strength: float = Form(0.75)
118
  ):
119
  global pipe
120
  try:
121
  if pipe is None:
 
122
  raise HTTPException(status_code=500, detail="Pipeline not initialized.")
123
 
124
  logger.info(f"Received inference request with prompt: {prompt}")
125
 
 
 
 
 
 
 
 
 
 
 
126
  # Load and process uploaded image
127
  logger.info("Loading uploaded image...")
128
- ref_image = Image.open(image.file).convert("RGB")
 
129
 
130
  # Extract face
131
  logger.info("Extracting face...")
@@ -159,11 +183,9 @@ async def predict(
159
  result_base64 = image_to_base64(final_img)
160
 
161
  logger.info("Inference completed successfully.")
162
- return {
163
- "original_image": "uploaded_file",
164
- "prompt": prompt,
165
- "result_image": f"data:image/png;base64,{result_base64}"
166
- }
167
  except HTTPException as e:
168
  logger.error(f"HTTP Exception: {str(e)}")
169
  raise
 
1
+ from fastapi import FastAPI, File, UploadFile, HTTPException, Form
2
+ from fastapi.responses import JSONResponse
3
+ from fastapi.middleware.cors import CORSMiddleware
4
  from pydantic import BaseModel
5
  import logging
6
  import torch
 
18
  logger = logging.getLogger(__name__)
19
 
20
  app = FastAPI()
21
+
22
+ # Add CORS middleware to allow requests from Framer
23
+ app.add_middleware(
24
+ CORSMiddleware,
25
+ allow_origins=["*"], # In production, restrict to your Framer domain
26
+ allow_credentials=True,
27
+ allow_methods=["*"],
28
+ allow_headers=["*"],
29
+ )
30
+
31
  pipe = None
32
+ device = "cuda" if torch.cuda.is_available() else "cpu"
33
 
34
  def initialize_pipeline():
35
  global pipe
 
40
  pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
41
  model_id,
42
  scheduler=scheduler,
43
+ torch_dtype=torch.float16 if device == "cuda" else torch.float32,
44
  low_cpu_mem_usage=True
45
  )
46
  pipe = pipe.to(device)
47
+ logger.info(f"Stable Diffusion pipeline initialized successfully on {device}.")
48
  except Exception as e:
49
  logger.error(f"Failed to initialize pipeline: {str(e)}", exc_info=True)
50
  raise
 
110
  def image_to_base64(image: Image.Image) -> str:
111
  try:
112
  buffered = BytesIO()
113
+ image.save(buffered, format="JPEG") # Changed to JPEG to match Framer client expectation
114
  img_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
115
  logger.info("Image converted to base64 successfully.")
116
  return img_base64
 
123
  prompt: str = Form(...),
124
  image: UploadFile = File(...),
125
  negative_prompt: str = Form("low quality, blurry"),
126
+ seed: str = Form("66"), # Changed to str to match Framer client
127
+ guidance_scale: str = Form("7.5"), # Changed to str to match Framer client
128
+ num_inference_steps: str = Form("10"), # Changed to str to match Framer client
129
+ strength: str = Form("0.75") # Changed to str to match Framer client
130
  ):
131
  global pipe
132
  try:
133
  if pipe is None:
134
+ logger.error("Pipeline not initialized.")
135
  raise HTTPException(status_code=500, detail="Pipeline not initialized.")
136
 
137
  logger.info(f"Received inference request with prompt: {prompt}")
138
 
139
+ # Convert string parameters to appropriate types
140
+ try:
141
+ seed = int(seed)
142
+ guidance_scale = float(guidance_scale)
143
+ num_inference_steps = int(num_inference_steps)
144
+ strength = float(strength)
145
+ except ValueError as e:
146
+ logger.error(f"Invalid parameter format: {str(e)}")
147
+ raise HTTPException(status_code=400, detail=f"Invalid parameter format: {str(e)}")
148
+
149
  # Load and process uploaded image
150
  logger.info("Loading uploaded image...")
151
+ image_data = await image.read()
152
+ ref_image = Image.open(BytesIO(image_data)).convert("RGB")
153
 
154
  # Extract face
155
  logger.info("Extracting face...")
 
183
  result_base64 = image_to_base64(final_img)
184
 
185
  logger.info("Inference completed successfully.")
186
+ return JSONResponse({
187
+ "result_image": f"data:image/jpeg;base64,{result_base64}" # Match Framer client expectation
188
+ })
 
 
189
  except HTTPException as e:
190
  logger.error(f"HTTP Exception: {str(e)}")
191
  raise