yusufgundogdu commited on
Commit
5bcb990
·
verified ·
1 Parent(s): 2b22f90

Create animegan_method

Browse files
Files changed (1) hide show
  1. animegan_method +60 -0
animegan_method ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from PIL import Image
3
+ import torchvision.transforms as transforms
4
+ import logging
5
+ from datetime import datetime
6
+
7
+ logger = logging.getLogger(__name__)
8
+
9
+ class AnimeGANProcessor:
10
+ def __init__(self, device):
11
+ self.device = device
12
+ self.model = None
13
+ self.load_model()
14
+
15
+ def load_model(self):
16
+ try:
17
+ logger.info("Loading AnimeGAN model...")
18
+ self.model = torch.hub.load('bryandlee/animegan2-pytorch:main', 'generator', trust_repo=True).to(self.device)
19
+ self.model.load_state_dict(torch.load('face_paint_512_v2.pt', map_location=self.device))
20
+ self.model.eval()
21
+ logger.info("Model loaded successfully")
22
+ except Exception as e:
23
+ logger.error(f"Error loading model: {str(e)}")
24
+ raise
25
+
26
+ def process_image(self, image):
27
+ try:
28
+ transform = transforms.Compose([
29
+ transforms.Resize((512, 512)),
30
+ transforms.ToTensor(),
31
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
32
+ ])
33
+ with torch.no_grad():
34
+ output = self.model(transform(image).unsqueeze(0).to(self.device))
35
+ return transforms.ToPILImage()((output * 0.5 + 0.5).squeeze().cpu())
36
+ except Exception as e:
37
+ logger.error(f"Error processing image: {str(e)}")
38
+ raise
39
+
40
+ def generate_anime(image_data):
41
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
42
+ processor = AnimeGANProcessor(device)
43
+
44
+ start_time = datetime.now()
45
+ logger.info(f"Generating anime image - {start_time}")
46
+
47
+ try:
48
+ image = Image.open(io.BytesIO(image_data)).convert("RGB")
49
+ processed_img = processor.process_image(image)
50
+
51
+ img_io = io.BytesIO()
52
+ processed_img.save(img_io, 'PNG')
53
+ img_io.seek(0)
54
+
55
+ duration = (datetime.now() - start_time).total_seconds()
56
+ logger.info(f"Successfully processed. Duration: {duration} seconds")
57
+ return img_io
58
+ except Exception as e:
59
+ logger.error(f"Processing error: {str(e)}", exc_info=True)
60
+ raise