File size: 3,621 Bytes
aa3fb8d
a6ec3f5
fdfe46e
07d79b9
fdfe46e
aa3fb8d
65e90c2
 
 
 
 
5c1653b
aa3fb8d
fdfe46e
65e90c2
 
 
8421d10
65e90c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fdfe46e
2209182
65e90c2
 
 
 
 
 
 
 
 
3f4a0b9
 
65e90c2
 
 
 
 
 
3f4a0b9
65e90c2
 
 
 
 
 
 
 
2209182
a6ec3f5
 
 
 
aa3fb8d
 
65e90c2
 
 
 
 
 
 
 
a6ec3f5
 
 
 
 
65e90c2
 
 
 
 
aa3fb8d
 
 
 
 
 
 
 
 
 
65e90c2
aa3fb8d
 
3b12278
2209182
aa3fb8d
65e90c2
a6ec3f5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
from fastapi import FastAPI, File, UploadFile, Form
from fastapi.responses import StreamingResponse
from PIL import Image
import torch
from io import BytesIO
import gradio as gr
import logging

# Configure logging for debugging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# FastAPI application that the route decorators below register on
app = FastAPI()

# Cache of loaded models, keyed by style name
model_cache = {}

def load_model(style_name):
    """Return the generator model for ``style_name``, loading it on first use.

    The model is fetched through ``torch.hub.load`` from
    ``bryandlee/animegan2-pytorch:main`` with ``pretrained=style_name``,
    switched to eval mode and stored in ``model_cache``, so later calls for
    the same style reuse the stored object.

    Any exception raised while loading is logged and re-raised unchanged.
    """
    logger.info(f"Loading model for style: {style_name}")

    # Cache hit: nothing to load.
    if style_name in model_cache:
        return model_cache[style_name]

    try:
        generator = torch.hub.load(
            "bryandlee/animegan2-pytorch:main",
            "generator",
            pretrained=style_name,
            verbose=False,
        )
        model_cache[style_name] = generator.eval()
        logger.info(f"Model {style_name} loaded successfully")
    except Exception as e:
        logger.error(f"Error loading model {style_name}: {str(e)}")
        raise
    return model_cache[style_name]

# Side length, in pixels, used for both the resized input and face2paint.
_OUTPUT_SIZE = 760

# Upper bound, in seconds, on how long a remote image download may take.
_DOWNLOAD_TIMEOUT = 30

# face2paint callables already fetched from torch.hub, keyed by size.
_face2paint_cache = {}


def _fetch_image_bytes(url):
    """Download ``url`` (http/https only) and return the raw response body."""
    # Imported here so this block does not depend on the file's import header.
    from urllib.parse import urlparse
    from urllib.request import urlopen

    # urlopen also understands schemes such as file:// and ftp://; refuse
    # them so a caller cannot make the server read local files.
    if urlparse(url).scheme not in ("http", "https"):
        raise ValueError(f"Unsupported image URL: {url!r}")
    with urlopen(url, timeout=_DOWNLOAD_TIMEOUT) as response:
        return response.read()


def _get_face2paint(size):
    """Return the hub ``face2paint`` helper for ``size``, fetching it once."""
    if size not in _face2paint_cache:
        _face2paint_cache[size] = torch.hub.load(
            "bryandlee/animegan2-pytorch:main",
            "face2paint",
            size=size,
            verbose=False,
        )
    return _face2paint_cache[size]


def animegan2_transform(input_img, style_name):
    """Convert an image to anime style with the AnimeGANv2 model for a style.

    Args:
        input_img: Either a ``PIL.Image.Image`` or a string holding an
            http/https URL from which the image is downloaded.
        style_name: Name of the pretrained style, passed to ``load_model``.

    Returns:
        The image produced by the hub ``face2paint`` helper.

    Raises:
        ValueError: If ``input_img`` is neither a string nor a PIL image, or
            if a string input is not an http/https URL.
        Exception: Any error from downloading, decoding, model loading or
            inference is logged and re-raised unchanged.
    """
    logger.info("Starting image transformation")
    try:
        if isinstance(input_img, str):
            # The original called ``requests.get`` here, but ``requests`` is
            # never imported in this file, so this branch raised NameError.
            # The standard library is used instead.
            raw = _fetch_image_bytes(input_img)
            input_img = Image.open(BytesIO(raw)).convert("RGB")
        elif isinstance(input_img, Image.Image):
            input_img = input_img.convert("RGB")
        else:
            raise ValueError("فرمت تصویر ورودی صحیح نیست!")

        # Scale to a fixed square; note this does not preserve aspect ratio.
        input_img = input_img.resize((_OUTPUT_SIZE, _OUTPUT_SIZE))
        logger.info("Image resized")

        model = load_model(style_name)
        # Reuse the helper across calls instead of going through
        # torch.hub.load on every request.
        face2paint_func = _get_face2paint(_OUTPUT_SIZE)
        output_img = face2paint_func(model, input_img)
        logger.info("Image transformation completed")
        return output_img
    except Exception as e:
        logger.error(f"Error in transformation: {str(e)}")
        raise

@app.get("/")
async def root():
    """Landing endpoint: return a JSON message pointing at POST /animegan."""
    welcome = "Welcome to AnimeGANv2 API. Use POST /animegan for image processing."
    return {"message": welcome}

# Styles this service offers; mirrors the choices of the Gradio dropdown.
_ALLOWED_STYLES = frozenset(
    {"face_paint_512_v1", "face_paint_512_v2", "paprika", "celeba_distill"}
)


def _error_response(status_code, message):
    """Build the ``{"error": ...}`` JSON body with a matching HTTP status."""
    # Imported here so this block does not depend on the file's import header.
    from fastapi.responses import JSONResponse

    return JSONResponse(status_code=status_code, content={"error": message})


@app.post("/animegan")
async def process_image(file: UploadFile = File(...), style: str = Form(...)):
    """Convert an uploaded image to anime style and return it as a PNG.

    Args:
        file: The uploaded image file.
        style: Name of the style to apply; must be one of ``_ALLOWED_STYLES``.

    Returns:
        A ``StreamingResponse`` with media type ``image/png`` on success.
        On failure, a JSON body ``{"error": <message>}`` with status 400
        (unsupported style or undecodable upload) or 500 (processing error).
        Previously these failures were returned with status 200.
    """
    logger.info("Received API request")

    # ``style`` is untrusted form input that ends up as the ``pretrained``
    # argument of torch.hub.load, so only accept the styles this app offers.
    # NOTE(review): the hub entry point may treat unknown strings as a
    # checkpoint location - confirm against the animegan2-pytorch hubconf.
    if style not in _ALLOWED_STYLES:
        logger.error(f"API error: unsupported style {style!r}")
        return _error_response(400, f"Unsupported style: {style}")

    # A failure while decoding the upload is the client's problem: 400.
    try:
        image = Image.open(BytesIO(await file.read())).convert("RGB")
    except Exception as e:
        logger.error(f"API error: {str(e)}")
        return _error_response(400, str(e))

    # Anything that goes wrong from here on is a server-side failure: 500.
    try:
        output_img = animegan2_transform(image, style)
        output_buffer = BytesIO()
        output_img.save(output_buffer, format="PNG")
        output_buffer.seek(0)
    except Exception as e:
        logger.error(f"API error: {str(e)}")
        return _error_response(500, str(e))

    logger.info("Image processed successfully")
    return StreamingResponse(
        output_buffer,
        media_type="image/png",
        headers={"Content-Disposition": "inline; filename=output.png"},
    )

# Gradio interface (optional)
# NOTE(review): nothing in the visible code launches `iface` or mounts it on
# `app` - confirm it is used elsewhere.
# NOTE(review): the image label text mentions entering a URL, but the input
# is configured with type="pil" - confirm the label is intended.
iface = gr.Interface(
    fn=animegan2_transform,
    inputs=[
        gr.Image(type="pil", label="آپلود تصویر یا وارد کردن URL"),
        gr.Dropdown(
            ["face_paint_512_v1", "face_paint_512_v2", "paprika", "celeba_distill"],
            value="face_paint_512_v2",
            label="انتخاب استایل"
        )
    ],
    outputs=gr.Image(type="pil", label="تصویر انیمه"),
    title="AnimeGANv2 - تبدیل تصویر به انیمه با انتخاب استایل"
)

if __name__ == "__main__":
    # Script entry point: serve `app` with Uvicorn on all interfaces, port 7860.
    from uvicorn import run as serve

    logger.info("Starting Uvicorn server")
    serve(app, host="0.0.0.0", port=7860)