Spaces:
Running
Running
File size: 11,251 Bytes
5ab0443 e2eb2c5 5ab0443 2bb163a 5ab0443 e2eb2c5 5ab0443 e2eb2c5 5ab0443 e2eb2c5 5ab0443 e2eb2c5 5ab0443 e2eb2c5 5ab0443 e2eb2c5 5ab0443 e2eb2c5 6e7d66e e2eb2c5 5ab0443 e2eb2c5 5ab0443 e2eb2c5 5ab0443 e2eb2c5 5ab0443 e2eb2c5 5ab0443 e2eb2c5 5ab0443 e2eb2c5 5ab0443 e2eb2c5 5ab0443 e2eb2c5 6e7d66e e2eb2c5 5ab0443 6e7d66e 5ab0443 e2eb2c5 5ab0443 e2eb2c5 5ab0443 e2eb2c5 5ab0443 e2eb2c5 5ab0443 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 | from fastapi import FastAPI, UploadFile, File, HTTPException
import io
import logging

import cv2
import gradio as gr
import numpy as np
import torch
from fastapi.responses import StreamingResponse
from PIL import Image
from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor
# Logging setup (so progress is clearly visible in the Space logs)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# -- Global state (filled in by the startup hook) ------------------------------
# Run inference on the GPU when torch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Segmentation processor/model; both stay None until the startup hook loads them.
processor = model = None

# Face-parsing label ids used to build the hair and ear masks.
# NOTE(review): hard-coded; confirm against model.config.id2label of the
# loaded checkpoint (ear ids in particular).
hair_class_id = 13
ear_class_ids = [7, 8]
# The app must exist before any decorator (e.g. @app.on_event below) refers to it.
app = FastAPI(
    title="Make Me Bald API π",
    description="Upload photo β Get realistic bald version! π§βπ¦²",
    version="1.0"
)


@app.on_event("startup")
async def startup_event():
    """Load the face-parsing processor and model once, at server startup.

    Populates the module-level ``processor`` and ``model`` globals, moves the
    model to ``device`` and puts it in eval mode. When the checkpoint's
    ``config.id2label`` contains "hair" and/or "l_ear"/"r_ear", those ids
    replace the hard-coded ``hair_class_id`` / ``ear_class_ids`` defaults
    (which remain as the fallback when the labels are absent).
    """
    global processor, model, hair_class_id, ear_class_ids
    logger.info(f"Loading SegFormer model on {device}...")
    processor = SegformerImageProcessor.from_pretrained("jonathandinu/face-parsing")
    model = SegformerForSemanticSegmentation.from_pretrained("jonathandinu/face-parsing")
    model.to(device)
    model.eval()

    # Prefer the checkpoint's own label map over magic numbers, so the masks
    # always target the classes actually named "hair" / "l_ear" / "r_ear".
    id2label = getattr(model.config, "id2label", None) or {}
    label_to_id = {str(name).lower(): int(idx) for idx, name in id2label.items()}
    if "hair" in label_to_id:
        hair_class_id = label_to_id["hair"]
    ear_ids = [label_to_id[name] for name in ("l_ear", "r_ear") if name in label_to_id]
    if ear_ids:
        ear_class_ids = ear_ids
    logger.info(f"Using hair class id {hair_class_id}, ear class ids {ear_class_ids}")
    logger.info("Model loaded successfully!")
def _segment_faces(working_np: np.ndarray, working_h: int, working_w: int) -> np.ndarray:
    """Run the loaded segmentation model and return an (working_h, working_w) label map."""
    pil_working = Image.fromarray(working_np)
    inputs = processor(images=pil_working, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)
        # Logits come out at the model's own resolution; upsample to the working size
        # before taking the per-pixel argmax.
        upsampled_logits = torch.nn.functional.interpolate(
            outputs.logits, size=(working_h, working_w), mode="bilinear", align_corners=False
        )
    return upsampled_logits.argmax(dim=1).squeeze(0).cpu().numpy()


def _ear_protection_mask(parsing: np.ndarray, hair_mask: np.ndarray, working_w: int) -> np.ndarray:
    """Return a uint8 mask of pixels around the ears that must NOT be inpainted.

    All pixels of every id in ``ear_class_ids`` are combined (so both ears share
    one bounding box). Returns an all-zero mask when no ear pixels are present.
    """
    ears_mask = np.zeros_like(hair_mask)
    for cls in ear_class_ids:
        ears_mask[parsing == cls] = 1
    ear_y, ear_x = np.where(ears_mask)
    ears_protected = np.zeros_like(hair_mask)
    if len(ear_y) == 0:
        return ears_protected

    ear_top_y = ear_y.min()
    ear_height = ear_y.max() - ear_top_y + 1
    # Tall elliptical dilation grows the protected zone mostly vertically.
    kernel_v = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (11, 30))
    ears_protected = cv2.dilate(ears_mask, kernel_v, iterations=2)

    # Add a thin band just above the ears, widened horizontally by 35% of the ear span.
    top_margin = max(8, int(ear_height * 0.12))
    top_start = max(0, ear_top_y - top_margin)
    ear_x_min, ear_x_max = ear_x.min(), ear_x.max()
    ear_width = ear_x_max - ear_x_min + 1
    x_margin = int(ear_width * 0.35)
    protected_left = max(0, ear_x_min - x_margin)
    protected_right = min(working_w, ear_x_max + x_margin)
    limited_top_mask = np.zeros_like(ears_mask)
    limited_top_mask[top_start:ear_top_y + 8, protected_left:protected_right] = 1
    kernel_h = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (17, 5))
    limited_top_mask = cv2.dilate(limited_top_mask, kernel_h, iterations=1)
    ears_protected = np.logical_or(ears_protected, limited_top_mask).astype(np.uint8)

    # Hair well above the ear line is never protected, so it still gets removed.
    hair_above_ears = np.zeros_like(hair_mask)
    above_ear_line = max(0, ear_top_y - int(ear_height * 0.65))
    hair_above_ears[:above_ear_line, :] = hair_mask[:above_ear_line, :]
    ears_protected[hair_above_ears == 1] = 0
    return ears_protected


def _refine_hair_mask(hair_mask: np.ndarray, ears_protected: np.ndarray, working_h: int) -> np.ndarray:
    """Subtract the ear zone from the hair mask and clean up its edges (uint8 0/1 result)."""
    hair_mask_final = hair_mask.copy()
    hair_mask_final[ears_protected == 1] = 0
    top_rows = int(working_h * 0.25)
    # With meaningful hair in the top quarter of the frame, restore it there even
    # where the ear-protection zone had cleared it.
    if hair_mask[:top_rows, :].sum() > 60:
        hair_mask_final[:top_rows, :] = np.maximum(hair_mask_final[:top_rows, :], hair_mask[:top_rows, :])

    # Close small holes, grow slightly, then re-threshold a light blur for sharp edges.
    kernel_s = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7))
    hair_mask_final = cv2.morphologyEx(hair_mask_final, cv2.MORPH_CLOSE, kernel_s, iterations=1)
    hair_mask_final = cv2.dilate(hair_mask_final, kernel_s, iterations=1)
    blurred = cv2.GaussianBlur(hair_mask_final.astype(np.float32), (5, 5), 1.0)
    hair_mask_final = (blurred > 0.45).astype(np.uint8)  # higher threshold -> sharper edges
    kernel_edge = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
    return cv2.dilate(hair_mask_final, kernel_edge, iterations=1)


def _extend_hair_mask(hair_mask_final: np.ndarray, ears_protected: np.ndarray, working_h: int) -> np.ndarray:
    """Build the larger inpainting mask used when the detected hair area is very big.

    Covers the dilated hair plus the whole top 48% of the frame, minus the ear
    zone, and never anything in the bottom 25% of the frame.
    """
    big_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (21, 21))
    extended = cv2.dilate(hair_mask_final, big_kernel, iterations=1)
    upper = np.zeros_like(hair_mask_final)
    upper[:int(working_h * 0.48), :] = 1
    extended = np.logical_or(extended, upper).astype(np.uint8)
    extended[ears_protected == 1] = 0
    kernel_s = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7))
    extended = cv2.morphologyEx(extended, cv2.MORPH_CLOSE, kernel_s, iterations=1)
    extended[int(working_h * 0.75):, :] = 0
    return extended


def _match_skin_tone(result_small: np.ndarray, working_np: np.ndarray, final_mask: np.ndarray,
                     working_h: int, working_w: int) -> np.ndarray:
    """Nudge the inpainted pixels' mean colour toward a colour sampled from the image centre.

    The target is the mean of per-region medians of two central crops of the
    ORIGINAL working image; 45% of the mean difference is applied. ``result_small``
    is modified in place and also returned.
    """
    regions = [(0.20, 0.35, 0.35, 0.65), (0.35, 0.50, 0.35, 0.65)]
    colors = []
    for y1r, y2r, x1r, x2r in regions:
        y1, y2 = int(working_h * y1r), int(working_h * y2r)
        x1, x2 = int(working_w * x1r), int(working_w * x2r)
        if y2 > y1 + 40 and x2 > x1 + 80:  # skip crops too small to give a stable median
            crop = working_np[y1:y2, x1:x2]
            if crop.size > 0:
                colors.append(np.median(crop, axis=(0, 1)).astype(np.float32))
    if not colors:
        return result_small

    target_color = np.mean(colors, axis=0)
    strength = 0.45  # partial correction to avoid artifacts
    bald_pixels = final_mask == 1
    bald_area = result_small[bald_pixels].astype(np.float32)
    if len(bald_area) > 200:
        diff = target_color - bald_area.mean(axis=0)
        result_small[bald_pixels] = np.clip(bald_area + diff * strength, 0, 255).astype(np.uint8)
    return result_small


def make_realistic_bald(image_bytes: bytes) -> bytes:
    """Return a JPEG-encoded copy of the photo with the detected hair inpainted away.

    Pipeline: decode -> downscale so the longest side is at most 2048 px ->
    face parsing -> hair mask minus an ear-protection zone -> OpenCV Telea
    inpainting -> optional skin-tone nudge -> whole-frame sharpen -> upscale
    back to the original size -> JPEG (quality 92).

    Args:
        image_bytes: raw bytes of any image format PIL can decode.

    Returns:
        JPEG bytes with the same width and height as the input image.

    Raises:
        ValueError: wrapping ANY failure (undecodable image, model not loaded,
            encoding failure, ...) with the message "Processing error: <cause>".
    """
    try:
        if processor is None or model is None:
            raise RuntimeError("Segmentation model is not loaded yet")

        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        orig_w, orig_h = image.size
        original_np = np.array(image)
        logger.info(f"Processing image: {orig_w}x{orig_h}")

        # Work on a downscaled copy for speed and memory.
        MAX_PROCESS_DIM = 2048
        scale_factor = 1.0
        working_np = original_np
        working_h, working_w = orig_h, orig_w
        if max(orig_w, orig_h) > MAX_PROCESS_DIM:
            scale_factor = MAX_PROCESS_DIM / max(orig_w, orig_h)
            # Clamp to >= 1 px so extreme aspect ratios cannot yield a zero-sized image.
            working_w = max(1, int(orig_w * scale_factor))
            working_h = max(1, int(orig_h * scale_factor))
            # `interpolation` must be a keyword: cv2.resize's third positional
            # parameter is `dst`, so passing the flag positionally silently drops it.
            working_np = cv2.resize(original_np, (working_w, working_h), interpolation=cv2.INTER_AREA)
        working_bgr = cv2.cvtColor(working_np, cv2.COLOR_RGB2BGR)

        parsing = _segment_faces(working_np, working_h, working_w)
        hair_mask = (parsing == hair_class_id).astype(np.uint8)
        ears_protected = _ear_protection_mask(parsing, hair_mask, working_w)
        hair_mask_final = _refine_hair_mask(hair_mask, ears_protected, working_h)

        hair_pixels = int(np.sum(hair_mask_final))
        logger.info(f"Hair pixels detected (resized): {hair_pixels:,}")

        use_extended_mask = hair_pixels > 380000
        if use_extended_mask:
            logger.info("Very large hair area β using extended mask")
            final_mask = _extend_hair_mask(hair_mask_final, ears_protected, working_h)
        else:
            final_mask = hair_mask_final

        # Smaller radius reduces halos; bigger areas need a larger neighbourhood.
        radius = 8 if use_extended_mask or hair_pixels > 300000 else 5
        logger.info(f"Inpainting with radius={radius}")
        inpainted_bgr = cv2.inpaint(working_bgr, final_mask * 255, inpaintRadius=radius, flags=cv2.INPAINT_TELEA)
        inpainted_rgb = cv2.cvtColor(inpainted_bgr, cv2.COLOR_BGR2RGB)
        result_small = working_np.copy()
        bald_pixels = final_mask == 1
        result_small[bald_pixels] = inpainted_rgb[bald_pixels]

        if use_extended_mask or hair_pixels > 200000:
            logger.info("Applying light skin color correction")
            result_small = _match_skin_tone(result_small, working_np, final_mask, working_h, working_w)

        # NOTE(review): this sharpen is applied to the entire frame, not only the
        # inpainted region — confirm that is intended.
        sharpen_kernel = np.array([[-1, -1, -1], [-1, 9, -1], [-1, -1, -1]], dtype=np.float32)
        result_small = cv2.filter2D(result_small, -1, sharpen_kernel)

        if scale_factor < 1.0:
            logger.info("Upscaling to original size")
            result = cv2.resize(result_small, (orig_w, orig_h), interpolation=cv2.INTER_LANCZOS4)
        else:
            result = result_small

        ok, buffer = cv2.imencode(
            '.jpg', cv2.cvtColor(result, cv2.COLOR_RGB2BGR), [int(cv2.IMWRITE_JPEG_QUALITY), 92]
        )
        if not ok:
            raise RuntimeError("JPEG encoding failed")
        return buffer.tobytes()
    except Exception as e:
        logger.error(f"Bald processing failed: {str(e)}", exc_info=True)
        raise ValueError(f"Processing error: {str(e)}") from e
@app.post("/make-bald/")
async def bald_endpoint(file: UploadFile = File(...)):
    """Accept a multipart image upload (form field ``file``) and return the bald JPEG.

    Responses:
        200: ``image/jpeg`` attachment named ``bald_version.jpg``.
        400: upload is not ``image/*``, or ``make_realistic_bald`` raised ValueError;
             detail is "NO_HAIR_DETECTED" when the error message mentions NO_HAIR.
        500: any other unexpected error.
    """
    # Local import keeps this edit self-contained; starlette's helper ships with FastAPI.
    from fastapi.concurrency import run_in_threadpool

    logger.info("=== REQUEST AAYI /make-bald/ PE ===")
    # UploadFile.size can be None (or absent on older versions); don't crash on logging.
    upload_size = getattr(file, "size", None)
    size_kb = f"{upload_size / 1024:.2f}" if upload_size is not None else "unknown"
    logger.info(f"Filename: {file.filename} | Content-Type: {file.content_type} | Size: {size_kb} KB")

    # content_type may be None when the client omits it; treat that as non-image.
    if not (file.content_type or "").startswith("image/"):
        raise HTTPException(status_code=400, detail="Sirf image file upload kar bhai! (jpeg/png etc.)")

    try:
        contents = await file.read()
        logger.info(f"Image read successful, size: {len(contents) / 1024:.2f} KB")
        # The pipeline is synchronous and CPU/GPU-bound; run it in a worker thread
        # so it does not block the event loop for every other request.
        bald_bytes = await run_in_threadpool(make_realistic_bald, contents)
        logger.info(f"Bald processing done, output size: {len(bald_bytes) / 1024:.2f} KB")
    except ValueError as ve:
        error_detail = str(ve).strip()
        logger.warning(f"ValueError: {error_detail}")
        if "NO_HAIR" in error_detail.upper():
            raise HTTPException(status_code=400, detail="NO_HAIR_DETECTED") from ve
        raise HTTPException(status_code=400, detail=error_detail or "Processing mein kuch galat hua") from ve
    except Exception as e:
        logger.error(f"Unexpected error: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Server error: {str(e)}") from e

    return StreamingResponse(
        io.BytesIO(bald_bytes),
        media_type="image/jpeg",
        headers={"Content-Disposition": "attachment; filename=bald_version.jpg"},
    )
@app.get("/")
def home():
    """Return short JSON usage instructions for the API.

    NOTE(review): this route is registered on "/" and the Gradio UI is also
    mounted on "/" later in the file; confirm which one should serve GET /.
    """
    usage = "POST request bhejo /make-bald/ pe with form-data key 'file' aur image attach karo."
    curl_example = "curl -X POST -F 'file=@your_photo.jpg' https://seniordev22-space.hf.space/make-bald/ -o bald.jpg"
    return dict(
        message="Bald banne aaya? π",
        how_to_use=usage,
        example=curl_example,
    )
# Placeholder Gradio page for the Hugging Face Space; the real work is the REST endpoint.
def dummy_fn():
    """Return the static hint text displayed by the Gradio info page."""
    hint = "API chal raha hai! cURL ya Postman se /make-bald/ pe POST karo."
    return hint


gr_interface = gr.Interface(
    title="Make Me Bald API π§βπ¦²",
    description=(
        "Ye sirf info page hai. Actual bald banane ke liye:\n\n"
        "curl -X POST -F 'file=@photo.jpg' https://seniordev22-space.hf.space/make-bald/ -o bald.jpg"
    ),
    fn=dummy_fn,
    inputs=None,
    outputs="text",
)

# Mount the Gradio UI on the root path (for Hugging Face Spaces compatibility).
app = gr.mount_gradio_app(app, gr_interface, path="/")
if __name__ == "__main__":
    # Direct-run entry point: serve the combined FastAPI + Gradio app with uvicorn.
    from uvicorn import run as serve

    # Bind all interfaces; 7860 is presumably the port the Space expects — confirm.
    bind_host, bind_port = "0.0.0.0", 7860
    serve(app, host=bind_host, port=bind_port)