import functools
import io
import logging
import time
from datetime import datetime

import torch
import torchvision.transforms as transforms
from PIL import Image
|
|
# Module-level logger, named after this module via ``__name__``.
logger: logging.Logger = logging.getLogger(name=__name__)
|
|
class AnimeGANProcessor:
    """Wrap the AnimeGANv2 generator and stylize PIL images with it.

    The generator is obtained from ``torch.hub`` and its weights are then
    loaded from a local state-dict file.
    """

    def __init__(self, device, weights_path="face_paint_512_v2.pt"):
        """Create the processor and load the model immediately.

        Args:
            device: ``torch.device`` (or device string) that the model and
                input tensors are moved to.
            weights_path: Path of the state-dict file passed to
                ``torch.load``. Defaults to ``"face_paint_512_v2.pt"``
                (relative to the working directory), as before.

        Raises:
            Exception: Anything raised by ``load_model`` (logged, re-raised).
        """
        self.device = device
        self.weights_path = weights_path
        self.model = None
        # Build the preprocessing pipeline once rather than on every call.
        self.transform = transforms.Compose([
            transforms.Resize((512, 512)),
            transforms.ToTensor(),
            # Map [0, 1] pixel values into the [-1, 1] range.
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])
        self.load_model()

    def load_model(self):
        """Load the generator from torch.hub and apply the local weights.

        On success ``self.model`` is the generator on ``self.device`` in eval
        mode. On failure ``self.model`` is left unchanged, so it never holds a
        model with missing weights.

        Raises:
            Exception: Any error from ``torch.hub.load``, ``torch.load`` or
                ``load_state_dict``; it is logged with traceback and re-raised.
        """
        try:
            logger.info("Loading AnimeGAN model...")
            model = torch.hub.load(
                'bryandlee/animegan2-pytorch:main', 'generator', trust_repo=True
            ).to(self.device)
            # NOTE(review): the hub entry point may already download pretrained
            # weights; this local state dict then overrides them. Confirm the
            # local file is actually required.
            model.load_state_dict(
                torch.load(self.weights_path, map_location=self.device)
            )
            model.eval()
            self.model = model
            logger.info("Model loaded successfully")
        except Exception as e:
            logger.exception("Error loading model: %s", e)
            raise

    def process_image(self, image):
        """Stylize a PIL image with the loaded generator.

        Args:
            image: PIL image. It is assumed to have 3 channels (the Normalize
                step uses 3-channel statistics); ``generate_anime`` converts
                to RGB before calling this.

        Returns:
            A new PIL image. The input is resized to 512x512 first, so the
            original size and aspect ratio are not preserved.

        Raises:
            Exception: Any preprocessing/inference error (logged, re-raised).
        """
        try:
            batch = self.transform(image).unsqueeze(0).to(self.device)
            with torch.no_grad():
                output = self.model(batch)
            return transforms.ToPILImage()(self._denormalize(output).cpu())
        except Exception as e:
            logger.error("Error processing image: %s", e)
            raise

    @staticmethod
    def _denormalize(output):
        """Drop the batch dim and map generator output from [-1, 1] to [0, 1].

        ``output`` is presumably shaped (1, C, H, W), mirroring the batch of
        one built in ``process_image``; confirm against the generator.
        """
        # Clamp first: ToPILImage scales floats by 255 and casts to uint8,
        # which does not saturate values outside [0, 1].
        return output.squeeze(0).clamp(-1, 1) * 0.5 + 0.5
|
|
@functools.lru_cache(maxsize=None)
def _get_processor(device):
    """Return a shared AnimeGANProcessor for *device*, creating it on first use.

    Cached so the torch.hub / state-dict load in ``AnimeGANProcessor`` runs
    once per device per process instead of once per request. A failed load
    raises and is not cached, so a later call retries.
    """
    return AnimeGANProcessor(device)


def generate_anime(image_data):
    """Convert encoded image bytes into an anime-styled PNG.

    Args:
        image_data: Raw bytes of an image in any format PIL can open.

    Returns:
        ``io.BytesIO`` rewound to offset 0, containing the PNG-encoded result.

    Raises:
        Exception: Model-loading errors propagate from ``AnimeGANProcessor``;
            decoding/inference/encoding errors are logged here with
            traceback and re-raised.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    processor = _get_processor(device)

    start_time = datetime.now()
    # Monotonic clock for the duration; wall-clock time is only logged.
    started = time.perf_counter()
    logger.info("Generating anime image - %s", start_time)

    try:
        # Close the decoder's handle once the RGB copy has been made.
        with Image.open(io.BytesIO(image_data)) as source:
            image = source.convert("RGB")
        processed_img = processor.process_image(image)

        img_io = io.BytesIO()
        processed_img.save(img_io, 'PNG')
        img_io.seek(0)

        duration = time.perf_counter() - started
        logger.info("Successfully processed. Duration: %s seconds", duration)
        return img_io
    except Exception as e:
        logger.error("Processing error: %s", e, exc_info=True)
        raise