yusufgundogdu's picture
Update app.py
78b02fb verified
raw
history blame
7.24 kB
import os
import sys
import io
import logging
from pathlib import Path
from datetime import datetime
from flask import Flask, request, send_file, jsonify
from PIL import Image, ImageDraw
import torch
import torchvision.transforms as transforms
from database import init_db, close_db, get_db_path
# Proje kök dizinini Python path'ine ekle
sys.path.insert(0, str(Path(__file__).parent))
from demo_generate.halftone_method import apply_halftone
from demo_generate.animegan_method import apply_animegan # Yeni eklenen
from demo_generate.rembg_method import remove_background # Yeni eklenen
# --- Process environment -------------------------------------------------
# Cache locations for torch.hub downloads (TORCH_HOME) and, presumably, rembg's
# U2-Net weights (U2NET_HOME), plus a flag that disables Numba's JIT.
# NOTE(review): these assignments run *after* the demo_generate imports above.
# NUMBA_DISABLE_JIT is read when numba is first imported; if any of those
# modules pull in numba at import time (e.g. via rembg), this flag has no
# effect. Consider moving this block above the imports — TODO confirm.
_TORCH_CACHE_DIR = '/tmp/torch_cache'
_U2NET_CACHE_DIR = '/tmp/.u2net'
os.environ.update({
    'NUMBA_DISABLE_JIT': '1',
    'TORCH_HOME': _TORCH_CACHE_DIR,
    'U2NET_HOME': _U2NET_CACHE_DIR,
})

# Make sure both cache directories exist before anything tries to write there.
for _cache_dir in (_TORCH_CACHE_DIR, _U2NET_CACHE_DIR):
    os.makedirs(_cache_dir, exist_ok=True)

# --- Logging -------------------------------------------------------------
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# --- Flask application ---------------------------------------------------
app = Flask(__name__)
# Project helper from database.py; presumably wires the database up to `app`.
init_db(app)
# --- Model Initialization ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# AnimeGAN2 generator with the 'face_paint_512_v2' weights, fetched through
# torch.hub. It is loaded once at import time; any failure is logged and
# re-raised so the app does not start without the model that process_image
# (and therefore /generate) relies on.
# NOTE(review): trust_repo=True executes hubconf.py from the unpinned 'main'
# branch of a third-party repo; pinning a commit/tag would make this
# reproducible and safer.
try:
    logger.info("Loading AnimeGAN model...")
    model = torch.hub.load(
        'bryandlee/animegan2-pytorch:main',
        'generator',
        pretrained='face_paint_512_v2',
        trust_repo=True,
    ).to(device)
    model.eval()  # inference mode for layers such as dropout/batch-norm
except Exception as e:
    # logger.exception keeps the traceback, which the f-string error log dropped.
    logger.exception("Model loading failed: %s", e)
    raise
else:
    logger.info("Model loaded successfully")
# --- Image Processing Functions ---
def process_image(image):
    """Convert an image to anime style with the module-level AnimeGAN ``model``.

    Args:
        image: PIL image. It is converted to RGB first, because the
            normalisation below expects exactly three channels (RGBA or
            greyscale inputs would otherwise fail).

    Returns:
        An RGB PIL image. The input is resized to 512x512 without preserving
        its aspect ratio before inference.
    """
    transform = transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.ToTensor(),
        # Maps pixel values from [0, 1] to [-1, 1].
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])
    batch = transform(image.convert("RGB")).unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(batch)
    # Undo the [-1, 1] scaling. Clamp first: ToPILImage multiplies float
    # tensors by 255 and casts to uint8, so values slightly outside the range
    # would overflow instead of saturating.
    output = output.squeeze(0).clamp(-1.0, 1.0) * 0.5 + 0.5
    return transforms.ToPILImage()(output.cpu())
def remove_background_original(image_data):
    """Strip the background from encoded image bytes using ``rembg``.

    This is the original in-file implementation; the /remove-bg-v2 route uses
    ``demo_generate.rembg_method.remove_background`` instead.

    Args:
        image_data: raw encoded image bytes (any format Pillow can open).

    Returns:
        The result of ``rembg.remove`` for a PIL input (a PIL image, since the
        caller saves it as PNG).
    """
    # rembg is imported inside the function, so it is only loaded on first use.
    from rembg import remove

    source = Image.open(io.BytesIO(image_data))
    # Normalise to RGB; images already in RGB mode are passed through as-is.
    rgb_image = source if source.mode == 'RGB' else source.convert('RGB')
    return remove(rgb_image)
# --- API Routes ---
@app.route('/')
def home():
    """Respond to GET / with a plain-text banner naming the service."""
    banner = "Anime Processing API - Türkçe"
    return banner
# Yeni eklenen endpoint: /generate-v2
@app.route('/generate-v2', methods=['POST'])
def generate_v2():
    """Stylise an upload with ``demo_generate.animegan_method.apply_animegan``.

    Expects a multipart upload in the ``image`` field and responds with PNG.

    Responses:
        200: the processed image as PNG.
        400: missing ``image`` field, empty filename, or bytes Pillow cannot
            decode.
        500: any other failure; the JSON body carries the exception text.
    """
    from PIL import UnidentifiedImageError

    try:
        if 'image' not in request.files:
            return jsonify({'error': 'Resim yüklenmedi'}), 400
        file = request.files['image']
        if not file.filename:
            return jsonify({'error': 'Boş dosya'}), 400
        image = Image.open(io.BytesIO(file.read())).convert("RGB")
        processed_img = apply_animegan(image)
        img_io = io.BytesIO()
        processed_img.save(img_io, 'PNG')
        img_io.seek(0)
        return send_file(img_io, mimetype='image/png')
    except UnidentifiedImageError:
        # Undecodable upload (wrong format, truncated, zero bytes) is a client
        # error, not a server fault.
        return jsonify({'error': 'Geçersiz resim dosyası'}), 400
    except Exception as e:
        # logger.exception also records the traceback.
        logger.exception("Hata: %s", e)
        return jsonify({'error': str(e)}), 500
# Orijinal endpoint'ler aynen korundu
@app.route('/generate', methods=['POST'])
def generate():
    """Stylise an upload with the module-level AnimeGAN model via ``process_image``.

    Expects a multipart upload in the ``image`` field and responds with PNG.

    Responses:
        200: the processed image as PNG.
        400: missing ``image`` field, empty filename, or bytes Pillow cannot
            decode.
        500: any other failure; the JSON body carries the exception text.
    """
    from PIL import UnidentifiedImageError

    try:
        if 'image' not in request.files:
            return jsonify({'error': 'Resim yüklenmedi'}), 400
        file = request.files['image']
        if not file.filename:
            return jsonify({'error': 'Boş dosya'}), 400
        image = Image.open(io.BytesIO(file.read())).convert("RGB")
        processed_img = process_image(image)
        img_io = io.BytesIO()
        processed_img.save(img_io, 'PNG')
        img_io.seek(0)
        return send_file(img_io, mimetype='image/png')
    except UnidentifiedImageError:
        # Undecodable upload (wrong format, truncated, zero bytes) is a client
        # error, not a server fault.
        return jsonify({'error': 'Geçersiz resim dosyası'}), 400
    except Exception as e:
        # logger.exception also records the traceback.
        logger.exception("Hata: %s", e)
        return jsonify({'error': str(e)}), 500
# Yeni eklenen endpoint: /remove-bg-v2
@app.route('/remove-bg-v2', methods=['POST'])
def bg_remove_v2():
    """Remove the background of an upload with ``demo_generate.rembg_method``.

    Expects a multipart upload in the ``image`` field of at most 10 MB and
    responds with PNG.

    Responses:
        200: the processed image as PNG.
        400: missing ``image`` field, empty filename, file larger than 10 MB,
            or (assuming remove_background decodes with Pillow) bytes that
            cannot be decoded as an image.
        500: any other failure; the JSON body carries the exception text.
    """
    from PIL import UnidentifiedImageError

    max_bytes = 10 * 1024 * 1024
    try:
        if 'image' not in request.files:
            return jsonify({'error': 'Resim yüklenmedi'}), 400
        file = request.files['image']
        # Reject uploads without a filename, matching the /generate endpoints.
        if not file.filename:
            return jsonify({'error': 'Boş dosya'}), 400
        # Measure the upload by seeking to its end, then rewind for reading.
        # NOTE(review): the body has already been received at this point;
        # Flask's MAX_CONTENT_LENGTH would reject oversized requests earlier.
        file.seek(0, os.SEEK_END)
        file_size = file.tell()
        file.seek(0)
        if file_size > max_bytes:
            return jsonify({'error': 'Dosya boyutu 10MB üzerinde'}), 400
        output = remove_background(file.read())
        img_io = io.BytesIO()
        output.save(img_io, 'PNG')
        img_io.seek(0)
        return send_file(img_io, mimetype='image/png')
    except UnidentifiedImageError:
        # Undecodable upload is a client error, not a server fault.
        return jsonify({'error': 'Geçersiz resim dosyası'}), 400
    except Exception as e:
        # logger.exception also records the traceback.
        logger.exception("Arkaplan kaldırma hatası: %s", e)
        return jsonify({'error': str(e)}), 500
# Orijinal endpoint aynen korundu
@app.route('/remove-bg', methods=['POST'])
def bg_remove():
    """Remove the background of an upload with ``remove_background_original``.

    Expects a multipart upload in the ``image`` field of at most 10 MB and
    responds with PNG.

    Responses:
        200: the processed image as PNG.
        400: missing ``image`` field, empty filename, file larger than 10 MB,
            or bytes Pillow cannot decode.
        500: any other failure; the JSON body carries the exception text.
    """
    from PIL import UnidentifiedImageError

    max_bytes = 10 * 1024 * 1024
    try:
        if 'image' not in request.files:
            return jsonify({'error': 'Resim yüklenmedi'}), 400
        file = request.files['image']
        # Reject uploads without a filename, matching the /generate endpoints.
        if not file.filename:
            return jsonify({'error': 'Boş dosya'}), 400
        # Measure the upload by seeking to its end, then rewind for reading.
        # NOTE(review): the body has already been received at this point;
        # Flask's MAX_CONTENT_LENGTH would reject oversized requests earlier.
        file.seek(0, os.SEEK_END)
        file_size = file.tell()
        file.seek(0)
        if file_size > max_bytes:
            return jsonify({'error': 'Dosya boyutu 10MB üzerinde'}), 400
        output = remove_background_original(file.read())
        img_io = io.BytesIO()
        output.save(img_io, 'PNG')
        img_io.seek(0)
        return send_file(img_io, mimetype='image/png')
    except UnidentifiedImageError:
        # Raised by Image.open inside remove_background_original for bytes that
        # are not a decodable image: a client error, not a server fault.
        return jsonify({'error': 'Geçersiz resim dosyası'}), 400
    except Exception as e:
        # logger.exception also records the traceback.
        logger.exception("Arkaplan kaldırma hatası: %s", e)
        return jsonify({'error': str(e)}), 500
# Orijinal halftone endpoint'i aynen korundu
@app.route('/halftone', methods=['POST'])
def halftone_route():
    """Apply a halftone effect to an upload via ``apply_halftone``.

    Expects a multipart upload in the ``image`` field and an optional integer
    form field ``dot_size`` (default 10, clamped into [5, 20]). Responds with
    PNG.

    Responses:
        200: the processed image as PNG.
        400: missing ``image`` field, empty filename, non-integer
            ``dot_size``, or bytes Pillow cannot decode.
        500: any other failure; the JSON body carries the exception text.
    """
    from PIL import UnidentifiedImageError

    try:
        if 'image' not in request.files:
            return jsonify({'error': 'Resim yüklenmedi'}), 400
        file = request.files['image']
        # Reject uploads without a filename, matching the /generate endpoints.
        if not file.filename:
            return jsonify({'error': 'Boş dosya'}), 400
        try:
            dot_size = int(request.form.get('dot_size', 10))
        except ValueError:
            # Malformed client input such as "abc" or "12.5"; this used to be a 500.
            return jsonify({'error': 'Geçersiz dot_size değeri'}), 400
        dot_size = max(5, min(20, dot_size))
        image = Image.open(io.BytesIO(file.read()))
        output = apply_halftone(image, dot_size)
        img_io = io.BytesIO()
        output.save(img_io, 'PNG')
        img_io.seek(0)
        return send_file(img_io, mimetype='image/png')
    except UnidentifiedImageError:
        # Undecodable upload is a client error, not a server fault.
        return jsonify({'error': 'Geçersiz resim dosyası'}), 400
    except Exception as e:
        # logger.exception also records the traceback.
        logger.exception("Halftone hatası: %s", e)
        return jsonify({'error': str(e)}), 500
# --- Database Routes --- (Aynen korundu)
@app.route('/users', methods=['GET'])
def get_users_route():
    """Handle GET /users by delegating to ``get_methods.get_users``."""
    # Imported inside the handler, so get_methods is loaded on first request.
    from get_methods import get_users as _list_users
    return _list_users()
@app.route('/user/<string:udid>', methods=['GET'])
def get_user_route(udid):
    """Handle GET /user/<udid> by delegating to ``get_methods.get_user``."""
    # Imported inside the handler, so get_methods is loaded on first request.
    from get_methods import get_user as _fetch_user
    return _fetch_user(udid)
@app.route('/add-user', methods=['POST'])
def add_user_route():
    """Handle POST /add-user by delegating to ``post_methods.add_user``.

    ``add_user`` takes no arguments, so it presumably reads the request itself.
    """
    # Imported inside the handler, so post_methods is loaded on first request.
    from post_methods import add_user as _create_user
    return _create_user()
@app.route('/consume/<string:udid>', methods=['POST'])
def consume_user_route(udid):
    """Handle POST /consume/<udid> by delegating to ``consume_method.consume_user``."""
    # Imported inside the handler, so consume_method is loaded on first request.
    from consume_method import consume_user as _consume
    return _consume(udid)
# --- Teardown --- (Aynen korundu)
@app.teardown_appcontext
def shutdown_session(exception=None):
    """Call ``close_db`` whenever Flask tears down an application context.

    ``exception`` is the unhandled error Flask passes in, if any; it is
    intentionally ignored here.
    """
    close_db()
if __name__ == '__main__':
    # Ensure the directory that will hold the database file exists.
    # Path(...).parent is '.' for a bare filename, whereas the former
    # os.makedirs(os.path.dirname(...)) raised FileNotFoundError on ''.
    db_path = get_db_path()
    Path(db_path).parent.mkdir(parents=True, exist_ok=True)

    # The Werkzeug debugger permits arbitrary code execution, so it must not be
    # enabled by default on a server bound to all interfaces. Opt in with
    # FLASK_DEBUG=1 (or "true"/"yes") for local development.
    debug_mode = os.environ.get('FLASK_DEBUG', '').strip().lower() in ('1', 'true', 'yes')
    # 7860 remains the default port; PORT allows overriding it.
    port = int(os.environ.get('PORT', '7860'))
    app.run(host='0.0.0.0', port=port, debug=debug_mode)