Spaces:
Runtime error
Runtime error
| from flask import Flask, request, jsonify | |
| from PIL import Image | |
| import torch | |
| from torchvision import transforms | |
| import io | |
| import base64 | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import json | |
| from flask_cors import CORS | |
| from huggingface_hub import hf_hub_download | |
# Create the Flask application object and enable cross-origin requests
# on it via flask_cors.
app = Flask(__name__)
CORS(app)
# Conditional variational autoencoder (CVAE) model definition.
class CVAE(nn.Module):
    """Convolutional VAE whose decoder is conditioned on a feature vector.

    The encoder maps an image to the mean and log-variance of a Gaussian
    over the latent space. The decoder maps a latent sample concatenated
    with the feature vector back to an image.

    The bottleneck is hard-coded to 128 x 27 x 22, which is what three
    kernel-4 / stride-2 / padding-1 convolutions produce for a 218 x 178
    input. The three transposed convolutions each double the spatial size,
    so the decoder emits 216 x 176 images.

    NOTE: the layer attribute names (conv1, fc1, deconv1, ...) are used by
    every method below; keep them stable if instances of this class are
    serialized as whole objects.
    """

    def __init__(self, input_shape, num_features, latent_dim):
        """Build the encoder and decoder layers.

        Args:
            input_shape: Shape of one input image; only ``input_shape[0]``
                (the channel count) is read when building layers.
            num_features: Length of the conditioning feature vector.
            latent_dim: Size of the latent space.
        """
        super().__init__()
        self.input_shape = input_shape
        self.num_features = num_features
        self.latent_dim = latent_dim

        channels = input_shape[0]
        bottleneck = 128 * 27 * 22

        # Encoder: three stride-2 convolutions, then a dense layer feeding
        # the two latent heads.
        self.conv1 = nn.Conv2d(channels, 32, kernel_size=4, stride=2, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1)
        self.fc1 = nn.Linear(bottleneck, 512)
        self.fc2_mu = nn.Linear(512, latent_dim)
        self.fc2_logvar = nn.Linear(512, latent_dim)

        # Decoder: mirror of the encoder; its input is the latent vector
        # concatenated with the conditioning features.
        self.fc3 = nn.Linear(latent_dim + num_features, 512)
        self.fc4 = nn.Linear(512, bottleneck)
        self.deconv1 = nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1)
        self.deconv2 = nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1)
        self.deconv3 = nn.ConvTranspose2d(32, channels, kernel_size=4, stride=2, padding=1)

    def encode(self, x, features):
        """Return ``(mu, logvar)`` for the image batch ``x``.

        ``features`` is accepted for signature symmetry with ``decode`` but
        is not used by the encoder.
        """
        hidden = x
        for conv in (self.conv1, self.conv2, self.conv3):
            hidden = F.relu(conv(hidden))
        # Flatten everything except the batch dimension.
        hidden = hidden.view(hidden.size(0), -1)
        hidden = F.relu(self.fc1(hidden))
        return self.fc2_mu(hidden), self.fc2_logvar(hidden)

    def reparameterize(self, mu, logvar):
        """Draw ``mu + eps * std`` with ``eps ~ N(0, I)`` and ``std = exp(logvar / 2)``."""
        sigma = torch.exp(0.5 * logvar)
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def decode(self, z, features):
        """Decode latent batch ``z`` conditioned on ``features`` into images in [0, 1]."""
        hidden = torch.cat([z, features], dim=1)
        for dense in (self.fc3, self.fc4):
            hidden = F.relu(dense(hidden))
        # Restore the (batch, 128, 27, 22) feature map expected by deconv1.
        hidden = hidden.view(hidden.size(0), 128, 27, 22)
        for deconv in (self.deconv1, self.deconv2):
            hidden = F.relu(deconv(hidden))
        # Sigmoid keeps the output pixel values inside [0, 1].
        return torch.sigmoid(self.deconv3(hidden))

    def forward(self, x, features):
        """Encode, sample, decode; return ``(reconstruction, mu, logvar)``."""
        mu, logvar = self.encode(x, features)
        latent = self.reparameterize(mu, logvar)
        return self.decode(latent, features), mu, logvar
# Preprocessing applied to every uploaded image: resize to 218 x 178
# (height, width), then convert the PIL image to a tensor.
transform = transforms.Compose(
    [transforms.Resize((218, 178)), transforms.ToTensor()]
)
# Load the trained model once at import time.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_path = hf_hub_download(repo_id="tejovk311/cvae_face_editor", filename="cvae_model_entir.pth")
# torch.load is given no state_dict to populate here; its result is used
# directly as the model, so the checkpoint must be a whole pickled module.
# PyTorch >= 2.6 defaults to weights_only=True, which refuses to unpickle
# arbitrary classes and makes this load fail, so opt out explicitly.
# SECURITY NOTE: weights_only=False runs the pickle machinery on the
# downloaded file and can execute arbitrary code. This is acceptable only
# because the repository above is trusted; prefer shipping a state_dict and
# calling load_state_dict on a freshly constructed CVAE.
model = torch.load(model_path, map_location=device, weights_only=False)
model.eval()  # Inference only: switch the module to evaluation mode.
model = model.to(device)  # Make sure every parameter lives on the chosen device.
def process_image(image_bytes, conditions):
    """Run an uploaded image through the model and return the result as base64 PNG.

    Args:
        image_bytes: Raw bytes of an image file that PIL can open.
        conditions: Numeric values handed to the model as its feature
            vector; converted to a float32 tensor with a batch dimension.

    Returns:
        The reconstructed image, PNG-encoded and base64-encoded, as a
        UTF-8 string.
    """
    # Decode, force three channels, preprocess, and add the batch dimension.
    source = Image.open(io.BytesIO(image_bytes)).convert('RGB')
    batch = transform(source).unsqueeze(0).to(device)

    # Conditioning vector, shaped (1, len(conditions)).
    features = torch.tensor(conditions, dtype=torch.float32).unsqueeze(0).to(device)

    # No gradients are needed for inference.
    with torch.no_grad():
        reconstruction, _mu, _logvar = model(batch, features)

    # Drop the batch dimension and bring the tensor back to the CPU before
    # turning it into a PIL image.
    to_pil = transforms.ToPILImage()
    picture = to_pil(reconstruction.squeeze().cpu())

    # Serialize as PNG in memory, then base64-encode for JSON transport.
    png_buffer = io.BytesIO()
    picture.save(png_buffer, format="PNG")
    encoded = base64.b64encode(png_buffer.getvalue())
    return encoded.decode("utf-8")
# NOTE(review): no route decorator was visible on this view, so Flask never
# served it. The root path is assumed here; confirm against the deployment.
@app.route('/')
def home():
    """Liveness endpoint: return a fixed plain-text banner."""
    return "Flask app is running!"
# NOTE(review): no route decorator was visible on this view, so Flask never
# served it. The '/process' path is an assumption; confirm it against the
# client. POST is required because the handler reads multipart form data.
@app.route('/process', methods=['POST'])
def process():
    """Reconstruct an uploaded image under the supplied conditions.

    Expects a multipart form with:
        image: the image file to process.
        conditions: a JSON-encoded list of numbers.

    Returns:
        200 with ``{'reconstructed_image': <base64 PNG>}`` on success.
        400 with ``{'error': ...}`` when a field is missing or ``conditions``
        is not a JSON list of numbers.
        500 with ``{'error': ...}`` when processing itself fails.
    """
    try:
        # Get the image and conditions from the request.
        image = request.files.get('image')
        conditions = request.form.get('conditions')

        # Check that both the image and the conditions are provided.
        if image is None or conditions is None:
            return jsonify({'error': 'Image or conditions not provided'}), 400

        # Malformed client input is a client error (400), not a server
        # failure, so handle it before entering the generic 500 path.
        try:
            conditions = json.loads(conditions)
        except ValueError:  # json.JSONDecodeError is a ValueError subclass
            return jsonify({'error': 'Conditions must be valid JSON'}), 400

        # bool is a subclass of int, so JSON true/false flags are accepted.
        if not isinstance(conditions, list) or not all(
            isinstance(value, (int, float)) for value in conditions
        ):
            return jsonify({'error': 'Conditions must be a JSON list of numbers'}), 400

        # Process the image.
        result_image_str = process_image(image.read(), conditions)

        # Return the processed image as a base64 string.
        return jsonify({'reconstructed_image': result_image_str})
    except Exception as e:
        # Top-level boundary: keep the JSON error contract for the client
        # but record the traceback server-side so failures are diagnosable.
        app.logger.exception("Image processing failed")
        return jsonify({'error': str(e)}), 500
if __name__ == '__main__':
    import os

    # The Werkzeug debugger lets anyone who can reach the port execute
    # arbitrary Python, and this server binds to every interface
    # (0.0.0.0). Debug mode is therefore off unless explicitly requested
    # with FLASK_DEBUG=1 (or true/yes) for local development.
    debug_enabled = os.environ.get('FLASK_DEBUG', '').strip().lower() in ('1', 'true', 'yes')
    app.run(host='0.0.0.0', port=7860, debug=debug_enabled)