File size: 5,299 Bytes
25b0876
 
 
 
 
 
94166a6
 
85062cc
 
 
 
 
94166a6
 
 
 
25b0876
 
94166a6
 
85062cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94166a6
 
 
85062cc
94166a6
 
 
 
25b0876
 
 
 
 
94166a6
 
 
 
 
25b0876
 
 
 
94166a6
25b0876
 
 
94166a6
 
 
 
 
 
 
25b0876
 
94166a6
 
25b0876
94166a6
25b0876
 
 
94166a6
25b0876
94166a6
 
 
25b0876
 
 
94166a6
 
25b0876
 
94166a6
 
85062cc
94166a6
85062cc
 
94166a6
85062cc
25b0876
132b457
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25b0876
94166a6
 
 
 
25b0876
94166a6
 
25b0876
 
94166a6
 
25b0876
 
 
94166a6
85062cc
 
25b0876
94166a6
 
25b0876
 
 
94166a6
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import base64
import io
import logging
import os
import shutil
from pathlib import Path

import requests
from flask import Flask, request, jsonify
from flask_cors import CORS
from huggingface_hub import hf_hub_download
from PIL import Image

# Import the withoutBG library (correct way from Qiita article)
from withoutbg.core import WithoutBGOpenSource

# Configure process-wide logging at INFO level.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = Flask(__name__)
# Enable CORS on the app with flask-cors' default settings.
CORS(app)

# Directory where downloaded ONNX model files are cached. It can be
# overridden with the MODEL_DIR environment variable (useful when /app is not
# writable); the default is unchanged. Created eagerly at import time.
MODEL_DIR = Path(os.environ.get("MODEL_DIR", "/app/models"))
MODEL_DIR.mkdir(parents=True, exist_ok=True)

def _ensure_model_file(filename: str, repo_id: str = "withoutbg/focus") -> Path:
    """Return the local path of a model file, downloading it on first use.

    If ``MODEL_DIR / filename`` already exists it is returned as-is; otherwise
    the file is fetched from the Hugging Face Hub repository *repo_id* via
    ``hf_hub_download`` and copied into ``MODEL_DIR``.

    Args:
        filename: Name of the file inside the Hub repository.
        repo_id: Hub repository to download from (defaults to the
            withoutBG Focus repository).

    Returns:
        Path of the model file inside ``MODEL_DIR``.
    """
    target = MODEL_DIR / filename
    if target.exists():
        return target
    logger.info(f"πŸ“₯ Downloading model file: {filename}")
    downloaded = Path(hf_hub_download(repo_id=repo_id, filename=filename))
    # Copy to a temporary sibling and rename it into place. If the process
    # dies mid-copy, only the ".part" file is left behind, never a truncated
    # file at ``target`` that the exists() check above would accept forever.
    partial = target.with_name(target.name + ".part")
    shutil.copy2(downloaded, partial)
    partial.replace(target)
    logger.info(f"βœ… Model file downloaded: {filename}")
    return target

def _create_model() -> WithoutBGOpenSource:
    """Build the WithoutBG pipeline from locally cached ONNX weight files.

    Each of the four weight files (depth, ISNet, matting, refiner) is resolved
    through ``_ensure_model_file``, which downloads it when it is missing.
    """
    logger.info("πŸš€ Creating WithoutBG model...")
    # Constructor keyword -> weight file name. Dicts preserve insertion order,
    # so the files are resolved in the same sequence as the keywords.
    weight_files = {
        "depth_model_path": "depth_anything_v2_vits_slim.onnx",
        "isnet_model_path": "isnet.onnx",
        "matting_model_path": "focus_matting_1.0.0.onnx",
        "refiner_model_path": "focus_refiner_1.0.0.onnx",
    }
    resolved = {kwarg: _ensure_model_file(fname) for kwarg, fname in weight_files.items()}
    return WithoutBGOpenSource(**resolved)

# Initialize the model once at startup. A failure does not abort the import:
# it is logged and ``model`` is left as None, which the request handlers
# check before use.
try:
    logger.info("πŸš€ Loading withoutBG model...")
    model = _create_model()
    logger.info("βœ… Model loaded successfully!")
except Exception as e:
    # logger.exception also records the traceback, which logger.error dropped;
    # without it a startup failure (download/ONNX error) is hard to diagnose.
    logger.exception(f"❌ Failed to load model: {e}")
    model = None

@app.route('/', methods=['GET'])
def health_check():
    """Health check endpoint.

    Returns service metadata as JSON; ``status`` is ``'healthy'`` when the
    model was created successfully at startup and ``'unhealthy'`` otherwise.
    """
    return jsonify({
        'service': 'withoutBG API Server',
        # ``model`` is set to None when startup loading fails; test that
        # sentinel explicitly rather than the model object's truthiness.
        'status': 'healthy' if model is not None else 'unhealthy',
        'model': 'withoutBG Focus v1.0.0',
        'version': '1.0.0',
        'platform': 'Hugging Face Spaces'
    })

class _BadRequest(Exception):
    """Client-side input problem; reported to the caller as HTTP 400."""


def _read_image_bytes(data: dict) -> io.BytesIO:
    """Return a buffer with the raw image bytes named by the JSON payload.

    ``image_url`` takes precedence over ``image_base64`` when both are given.
    ``image_base64`` may be a bare base64 string or a data URL
    (``data:<mime>;base64,<payload>``); everything up to the first comma is
    discarded.

    Raises:
        _BadRequest: neither key is present, the URL cannot be fetched, or
            ``image_base64`` is not a decodable base64 string.
    """
    if 'image_url' in data:
        # NOTE(review): fetches an arbitrary caller-supplied URL with no host
        # allow-list and no size cap (SSRF / memory-exhaustion risk) -- confirm
        # this is acceptable for the deployment.
        logger.info(f"πŸ“₯ Downloading image from URL: {data['image_url']}")
        try:
            response = requests.get(data['image_url'], timeout=30)
            response.raise_for_status()
        except requests.RequestException as e:
            raise _BadRequest(str(e)) from e
        return io.BytesIO(response.content)

    if 'image_base64' in data:
        logger.info("πŸ“₯ Decoding base64 image")
        image_base64 = data['image_base64']
        if not isinstance(image_base64, str):
            raise _BadRequest('image_base64 must be a string')
        # Drop a data-URL header; split only on the first comma so nothing
        # after the header is discarded.
        if ',' in image_base64:
            image_base64 = image_base64.split(',', 1)[1]
        try:
            return io.BytesIO(base64.b64decode(image_base64))
        except ValueError as e:  # binascii.Error subclasses ValueError
            raise _BadRequest(str(e)) from e

    raise _BadRequest('Either image_url or image_base64 is required')


def _open_image(image_data: io.BytesIO) -> Image.Image:
    """Decode *image_data* with Pillow; raise _BadRequest if it is not an image."""
    try:
        img = Image.open(image_data)
        # Image.open is lazy; force decoding now so a truncated or corrupt
        # file is reported as bad input instead of surfacing later as a 500.
        img.load()
    except (OSError, Image.DecompressionBombError) as e:
        raise _BadRequest(str(e)) from e
    return img


def _flatten_onto_white(image: Image.Image) -> Image.Image:
    """Alpha-composite *image* over an opaque white canvas and return it as RGB."""
    if image.mode != 'RGBA':
        image = image.convert('RGBA')
        logger.info("πŸ”„ Converted to RGBA mode")
    logger.info("🎨 Creating white background composite...")
    white_bg = Image.new('RGBA', image.size, (255, 255, 255, 255))
    return Image.alpha_composite(white_bg, image).convert('RGB')


def _to_png_data_url(image: Image.Image) -> str:
    """Encode *image* as PNG and wrap it in a ``data:image/png;base64,`` URL."""
    buffer = io.BytesIO()
    image.save(buffer, format='PNG')
    encoded = base64.b64encode(buffer.getvalue()).decode('utf-8')
    return f'data:image/png;base64,{encoded}'


@app.route('/api/remove-bg', methods=['POST'])
def remove_background():
    """Remove the background from an image and flatten it onto white.

    Expects a JSON object with either ``image_url`` or ``image_base64``.
    On success responds with ``{'success': True, 'image_data': <PNG data URL>}``.
    Client input problems (missing/non-object JSON, unfetchable URL, invalid
    base64, undecodable image) return 400. An uninitialised model or a
    processing failure returns 500. Error bodies are
    ``{'success': False, 'error': <message>}``.
    """
    try:
        # silent=True yields None for a missing or unparsable JSON body
        # instead of raising, so the client gets the 400 below, not a 500.
        data = request.get_json(silent=True)

        if not data:
            return jsonify({'success': False, 'error': 'No JSON data provided'}), 400
        if not isinstance(data, dict):
            return jsonify({'success': False, 'error': 'JSON body must be an object'}), 400

        if model is None:
            return jsonify({'success': False, 'error': 'Model not initialized'}), 500

        img = _open_image(_read_image_bytes(data))
        logger.info(f"πŸ–ΌοΈ Image loaded: {img.size}, mode: {img.mode}")

        # Remove background using withoutBG (Qiita article method)
        logger.info("πŸ”„ Removing background with WithoutBGOpenSource...")
        result = model.remove_background(img)
        logger.info(f"βœ… Background removed! Result mode: {result.mode}, Size: {result.size}")

        flattened = _flatten_onto_white(result)
        logger.info(f"βœ… Final image mode: {flattened.mode}")

        return jsonify({
            'success': True,
            'image_data': _to_png_data_url(flattened),
        })

    except _BadRequest as e:
        logger.warning(f"Bad request: {e}")
        return jsonify({'success': False, 'error': str(e)}), 400
    except Exception as e:
        # logger.exception logs the traceback through the logging system
        # instead of printing it separately to stderr.
        logger.exception(f"❌ Error: {e}")
        return jsonify({
            'success': False,
            'error': str(e)
        }), 500

if __name__ == '__main__':
    # Direct-execution entry point: serve on all interfaces, using the port
    # from the PORT environment variable or 7860 when it is unset.
    import os

    listen_port = int(os.getenv('PORT', 7860))
    logger.info(f"πŸš€ Starting withoutBG API Server on port {listen_port}...")
    app.run(debug=False, host='0.0.0.0', port=listen_port)