Johdw commited on
Commit
a6e9875
·
verified ·
1 Parent(s): 255bc5e

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +42 -20
main.py CHANGED
@@ -89,47 +89,69 @@ def create_the_final_results(fabric: Image.Image, person: Image.Image, mask: Ima
89
  return results
90
 
91
  def load_image_from_base64(s: str, m: str = 'RGB'):
92
- if "," not in s: return None
93
- try: return Image.open(io.BytesIO(base64.b64decode(s.split(",")[1]))).convert(m)
94
- except: return None
 
 
 
 
 
 
 
95
 
96
  # === API ENDPOINTS (CORRECT AND GUARANTEED) ===
97
 
98
  @app.get("/")
99
- def root(): return {"status": "API server is running. Model will load on first call."}
100
- class ApiInput(BaseModel): person_base64: str; fabric_base64: str; mask_base64: Optional[str] = None
 
 
 
 
 
101
 
102
  @app.post("/generate")
103
  async def api_generate(request: Request, inputs: ApiInput):
104
- print( "api_generate processing.." );
105
  load_model()
 
106
  API_KEY = os.environ.get("API_KEY")
107
- if request.headers.get("x-api-key") != API_KEY: raise HTTPException(status_code=401, detail="Unauthorized")
108
- person = load_image_from_base64(inputs.person_base64); fabric = load_image_from_base64(inputs.fabric_base64)
109
- if person is None or fabric is None: raise HTTPException(status_code=400, detail="Could not decode base64.")
 
 
 
 
 
 
110
 
111
- # Process at high resolution, just like the Colab notebook.
112
  TARGET_SIZE = (1024, 1024)
113
  person_resized = person.resize(TARGET_SIZE, Image.Resampling.LANCZOS)
114
 
 
115
  if inputs.mask_base64:
116
- print( "api_generate if start processing.." );
117
  mask = load_image_from_base64(inputs.mask_base64, mode='L')
118
- if mask is None: raise HTTPException(status_code=400, detail="Could not decode mask base64.")
119
- # mask = mask.resize(TARGET_SIZE, Image.Resampling.LANCZOS)
120
- print( "api_generate if end processing.." );
121
- else:
 
122
  mask = generate_ultimate_mask(person_resized)
123
- print( "api_generate else processing.." );
124
-
125
  final_results = create_the_final_results(fabric, person_resized, mask)
126
 
 
127
  def to_base64(img):
128
- print( "to_base64 processing.." );
129
  img_display = img.resize((512, 512), Image.Resampling.LANCZOS)
130
- buf = io.BytesIO(); img_display.save(buf, format="PNG");
 
131
  return f"data:image/png;base64,{base64.b64encode(buf.getvalue()).decode('utf-8')}"
132
 
 
133
  response_data = {
134
  'ultimate_image': to_base64(final_results['ultimate_image']),
135
  'fine_weave_image': to_base64(final_results['fine_weave_image']),
@@ -137,5 +159,5 @@ async def api_generate(request: Request, inputs: ApiInput):
137
  'creative_variation_image': to_base64(final_results['creative_variation_image']),
138
  'mask_image': to_base64(mask)
139
  }
140
- print( "All Processing End" );
141
  return response_data
 
89
  return results
90
 
91
  def load_image_from_base64(s: str, m: str = 'RGB'):
92
+ """Decodes a Base64 string and opens it as a PIL Image."""
93
+ if "," not in s:
94
+ return None
95
+ try:
96
+ img_data = base64.b64decode(s.split(",")[1])
97
+ img = Image.open(io.BytesIO(img_data))
98
+ return img.convert(m)
99
+ except Exception as e:
100
+ print(f"Error loading image from base64: {e}")
101
+ return None
102
 
103
  # === API ENDPOINTS (CORRECT AND GUARANTEED) ===
104
 
105
  @app.get("/")
106
+ def root():
107
+ return {"status": "API server is running. Model will load on first call."}
108
+
109
+ class ApiInput(BaseModel):
110
+ person_base64: str
111
+ fabric_base64: str
112
+ mask_base64: Optional[str] = None
113
 
114
  @app.post("/generate")
115
  async def api_generate(request: Request, inputs: ApiInput):
116
+ # Ensure the model is loaded
117
  load_model()
118
+
119
  API_KEY = os.environ.get("API_KEY")
120
+ if request.headers.get("x-api-key") != API_KEY:
121
+ raise HTTPException(status_code=401, detail="Unauthorized")
122
+
123
+ # Load person and fabric images from base64
124
+ person = load_image_from_base64(inputs.person_base64)
125
+ fabric = load_image_from_base64(inputs.fabric_base64)
126
+
127
+ if person is None or fabric is None:
128
+ raise HTTPException(status_code=400, detail="Could not decode base64 images.")
129
 
130
+ # Resize person image to a standard size
131
  TARGET_SIZE = (1024, 1024)
132
  person_resized = person.resize(TARGET_SIZE, Image.Resampling.LANCZOS)
133
 
134
+ # Handle mask image if provided
135
  if inputs.mask_base64:
 
136
  mask = load_image_from_base64(inputs.mask_base64, mode='L')
137
+ if mask is None:
138
+ raise HTTPException(status_code=400, detail="Could not decode mask base64.")
139
+ mask = mask.resize(TARGET_SIZE, Image.Resampling.LANCZOS)
140
+ else:
141
+ # If no mask is provided, generate one
142
  mask = generate_ultimate_mask(person_resized)
143
+
144
+ # Process and create the final results
145
  final_results = create_the_final_results(fabric, person_resized, mask)
146
 
147
+ # Convert image to base64 for the response
148
  def to_base64(img):
 
149
  img_display = img.resize((512, 512), Image.Resampling.LANCZOS)
150
+ buf = io.BytesIO()
151
+ img_display.save(buf, format="PNG")
152
  return f"data:image/png;base64,{base64.b64encode(buf.getvalue()).decode('utf-8')}"
153
 
154
+ # Prepare the response data with images converted to base64
155
  response_data = {
156
  'ultimate_image': to_base64(final_results['ultimate_image']),
157
  'fine_weave_image': to_base64(final_results['fine_weave_image']),
 
159
  'creative_variation_image': to_base64(final_results['creative_variation_image']),
160
  'mask_image': to_base64(mask)
161
  }
162
+
163
  return response_data