yusufgundogdu commited on
Commit
8cd6e29
·
verified ·
1 Parent(s): 08df5bc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +59 -61
app.py CHANGED
@@ -1,21 +1,20 @@
1
- from flask import Flask, request, send_file, jsonify
2
  import os
3
  import sys
4
  import io
5
  import logging
6
- import torch
7
  from datetime import datetime
 
 
 
 
8
  from database import init_db, close_db, get_db_path
9
 
10
- # Import yollarını kontrol et
11
- sys.path.append(os.path.dirname(os.path.abspath(__file__)))
12
-
13
  # Environment setup
14
  os.environ['NUMBA_DISABLE_JIT'] = '1'
15
  os.environ['TORCH_HOME'] = '/tmp/torch_cache'
16
  os.environ['U2NET_HOME'] = '/tmp/.u2net'
17
 
18
- # Create cache directories if they don't exist
19
  os.makedirs('/tmp/torch_cache', exist_ok=True)
20
  os.makedirs('/tmp/.u2net', exist_ok=True)
21
 
@@ -23,99 +22,98 @@ os.makedirs('/tmp/.u2net', exist_ok=True)
23
  logging.basicConfig(level=logging.INFO)
24
  logger = logging.getLogger(__name__)
25
 
26
- # Model Loading (Çalışan örnekten eklenen kısım)
 
 
 
 
27
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
28
  try:
29
- logger.info("Loading model...")
30
- model = torch.hub.load('bryandlee/animegan2-pytorch:main', 'generator', trust_repo=True).to(device)
31
- model.load_state_dict(torch.load('face_paint_512_v2.pt', map_location=device))
 
 
 
 
32
  model.eval()
33
  logger.info("Model loaded successfully")
34
  except Exception as e:
35
- logger.error(f"Error loading model: {str(e)}")
36
- raise
37
-
38
- # Import methods
39
- try:
40
- from get_methods import get_users, get_user
41
- from post_methods import add_user
42
- from consume_method import consume_user
43
- from rembg_method import remove_background
44
- from halftone_method import apply_halftone
45
- except ImportError as e:
46
- logger.error(f"Import hatası: {e}")
47
  raise
48
 
49
- app = Flask(__name__)
50
- init_db(app)
 
 
 
 
 
 
 
 
51
 
 
52
  @app.route('/')
53
  def home():
54
- return "StableDiffusionAPI Türkçe"
55
-
56
- # Existing routes
57
- app.route('/users', methods=['GET'])(get_users)
58
- app.route('/user/<string:udid>', methods=['GET'])(get_user)
59
- app.route('/add-user', methods=['POST'])(add_user)
60
- app.route('/consume/<string:udid>', methods=['POST'])(consume_user)
61
 
62
- # New image processing routes (Çalışan örnekten güncellenen kısım)
63
  @app.route('/generate', methods=['POST'])
64
  def generate():
65
  start_time = datetime.now()
66
- logger.info(f"/generate endpoint called - {start_time}")
67
 
68
  try:
69
  if 'image' not in request.files:
70
- logger.error("No file provided")
71
  return jsonify({'error': 'No image provided'}), 400
72
 
73
  file = request.files['image']
74
- if file.filename == '':
75
- logger.error("Empty file provided")
76
- return jsonify({'error': 'No selected file'}), 400
77
-
78
  try:
79
- logger.info(f"File received: {file.filename}")
80
  image = Image.open(io.BytesIO(file.read())).convert("RGB")
81
-
82
- logger.info("Processing image...")
83
- transform = transforms.Compose([
84
- transforms.Resize((512, 512)),
85
- transforms.ToTensor(),
86
- transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
87
- ])
88
- with torch.no_grad():
89
- output = model(transform(image).unsqueeze(0).to(device))
90
- processed_img = transforms.ToPILImage()((output * 0.5 + 0.5).squeeze().cpu())
91
 
92
  img_io = io.BytesIO()
93
  processed_img.save(img_io, 'PNG')
94
  img_io.seek(0)
95
 
96
  duration = (datetime.now() - start_time).total_seconds()
97
- logger.info(f"Successfully processed. Duration: {duration} seconds")
98
  return send_file(img_io, mimetype='image/png')
99
 
100
  except Exception as e:
101
- logger.error(f"Processing error: {str(e)}", exc_info=True)
102
  return jsonify({'error': str(e)}), 500
103
 
104
  except Exception as e:
105
- logger.error(f"Unexpected error: {str(e)}", exc_info=True)
106
  return jsonify({'error': 'Internal server error'}), 500
107
 
108
- # Diğer endpoint'ler (remove-bg ve halftone) aynı kalıyor...
109
- @app.route('/remove-bg', methods=['POST'])
110
- def remove_bg():
111
- # ... (mevcut kodunuz aynen kalacak)
112
- pass
 
 
 
 
 
113
 
114
- @app.route('/halftone', methods=['POST'])
115
- def halftone():
116
- # ... (mevcut kodunuz aynen kalacak)
117
- pass
 
 
 
 
 
118
 
 
119
  @app.teardown_appcontext
120
  def shutdown_session(exception=None):
121
  close_db()
 
 
1
  import os
2
  import sys
3
  import io
4
  import logging
 
5
  from datetime import datetime
6
+ from flask import Flask, request, send_file, jsonify
7
+ from PIL import Image
8
+ import torch
9
+ import torchvision.transforms as transforms
10
  from database import init_db, close_db, get_db_path
11
 
 
 
 
12
  # Environment setup
13
  os.environ['NUMBA_DISABLE_JIT'] = '1'
14
  os.environ['TORCH_HOME'] = '/tmp/torch_cache'
15
  os.environ['U2NET_HOME'] = '/tmp/.u2net'
16
 
17
+ # Create cache directories
18
  os.makedirs('/tmp/torch_cache', exist_ok=True)
19
  os.makedirs('/tmp/.u2net', exist_ok=True)
20
 
 
22
  logging.basicConfig(level=logging.INFO)
23
  logger = logging.getLogger(__name__)
24
 
25
+ # Initialize Flask
26
+ app = Flask(__name__)
27
+ init_db(app)
28
+
29
+ # --- AnimeGAN Model Initialization ---
30
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
31
+
32
  try:
33
+ logger.info("Loading AnimeGAN model...")
34
+ model = torch.hub.load(
35
+ 'bryandlee/animegan2-pytorch:main',
36
+ 'generator',
37
+ pretrained='face_paint_512_v2', # Critical parameter
38
+ trust_repo=True
39
+ ).to(device)
40
  model.eval()
41
  logger.info("Model loaded successfully")
42
  except Exception as e:
43
+ logger.error(f"Model loading failed: {str(e)}")
 
 
 
 
 
 
 
 
 
 
 
44
  raise
45
 
46
+ # --- Image Processing Functions ---
47
+ def process_image(image):
48
+ transform = transforms.Compose([
49
+ transforms.Resize((512, 512)),
50
+ transforms.ToTensor(),
51
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
52
+ ])
53
+ with torch.no_grad():
54
+ output = model(transform(image).unsqueeze(0).to(device))
55
+ return transforms.ToPILImage()((output * 0.5 + 0.5).squeeze().cpu())
56
 
57
+ # --- Flask Routes ---
58
  @app.route('/')
59
  def home():
60
+ return "AnimeGAN API - Türkçe"
 
 
 
 
 
 
61
 
 
62
  @app.route('/generate', methods=['POST'])
63
  def generate():
64
  start_time = datetime.now()
65
+ logger.info(f"Generate request started at {start_time}")
66
 
67
  try:
68
  if 'image' not in request.files:
 
69
  return jsonify({'error': 'No image provided'}), 400
70
 
71
  file = request.files['image']
72
+ if not file.filename:
73
+ return jsonify({'error': 'Empty filename'}), 400
74
+
 
75
  try:
 
76
  image = Image.open(io.BytesIO(file.read())).convert("RGB")
77
+ processed_img = process_image(image)
 
 
 
 
 
 
 
 
 
78
 
79
  img_io = io.BytesIO()
80
  processed_img.save(img_io, 'PNG')
81
  img_io.seek(0)
82
 
83
  duration = (datetime.now() - start_time).total_seconds()
84
+ logger.info(f"Processed in {duration:.2f} seconds")
85
  return send_file(img_io, mimetype='image/png')
86
 
87
  except Exception as e:
88
+ logger.error(f"Processing error: {str(e)}")
89
  return jsonify({'error': str(e)}), 500
90
 
91
  except Exception as e:
92
+ logger.error(f"Unexpected error: {str(e)}")
93
  return jsonify({'error': 'Internal server error'}), 500
94
 
95
+ # --- Database Routes (Keep your existing routes) ---
96
+ @app.route('/users', methods=['GET'])
97
+ def get_users():
98
+ from get_methods import get_users
99
+ return get_users()
100
+
101
+ @app.route('/user/<string:udid>', methods=['GET'])
102
+ def get_user(udid):
103
+ from get_methods import get_user
104
+ return get_user(udid)
105
 
106
+ @app.route('/add-user', methods=['POST'])
107
+ def add_user():
108
+ from post_methods import add_user
109
+ return add_user()
110
+
111
+ @app.route('/consume/<string:udid>', methods=['POST'])
112
+ def consume_user(udid):
113
+ from consume_method import consume_user
114
+ return consume_user(udid)
115
 
116
+ # --- Teardown ---
117
  @app.teardown_appcontext
118
  def shutdown_session(exception=None):
119
  close_db()