|
|
import logging
from io import BytesIO
from urllib.request import urlopen

import gradio as gr
import torch
from fastapi import FastAPI, File, Form, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse, Response, StreamingResponse
from PIL import Image
|
|
|
|
|
|
|
|
# Route INFO-and-above records from every logger through the root handler.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# FastAPI application object; the HTTP routes below are registered on it.
app = FastAPI()

# Memo of generator models that have already been loaded, keyed by style name.
model_cache = dict()
|
|
|
|
|
# Pretrained AnimeGANv2 weight names that load_model accepts. Per the upstream
# bryandlee/animegan2-pytorch hubconf, the ``generator`` entry point treats any
# other string as a URL to download weights from, so names arriving from the
# public API must be checked against this allowlist first.
SUPPORTED_STYLES = frozenset(
    {"face_paint_512_v1", "face_paint_512_v2", "paprika", "celeba_distill"}
)


def load_model(style_name):
    """Return the AnimeGANv2 generator for *style_name*, loading it on first use.

    The model is fetched with ``torch.hub.load(..., "generator", ...)``, switched
    to eval mode, and memoised in ``model_cache`` so later calls reuse it.

    Args:
        style_name: One of ``SUPPORTED_STYLES``.

    Returns:
        The cached generator (the object returned by ``torch.hub.load(...).eval()``).

    Raises:
        ValueError: If *style_name* is not in ``SUPPORTED_STYLES``.
        Exception: Anything raised by ``torch.hub.load`` is logged and re-raised.
    """
    # Reject unknown names before touching the hub: otherwise a caller-chosen
    # string could make the server download and deserialise arbitrary weights.
    if style_name not in SUPPORTED_STYLES:
        raise ValueError(f"Unsupported style: {style_name!r}")

    cached = model_cache.get(style_name)
    if cached is not None:
        return cached

    logger.info(f"Loading model for style: {style_name}")
    try:
        model = torch.hub.load(
            "bryandlee/animegan2-pytorch:main",
            "generator",
            pretrained=style_name,
            verbose=False,
        ).eval()
    except Exception as e:
        # logger.exception keeps the traceback, which logger.error dropped.
        logger.exception(f"Error loading model {style_name}: {str(e)}")
        raise

    model_cache[style_name] = model
    logger.info(f"Model {style_name} loaded successfully")
    return model
|
|
|
|
|
# Side length (pixels) passed to the hub ``face2paint`` helper.
_FACE2PAINT_SIZE = 760

# Seconds to wait when downloading an image from a URL input.
_URL_TIMEOUT_SECONDS = 30

# Hub ``face2paint`` helpers keyed by size. torch.hub.load was previously
# re-run on every transformation; this memoises it like ``model_cache`` does
# for the generators.
_face2paint_cache = {}


def _get_face2paint(size):
    """Return the AnimeGANv2 hub ``face2paint`` helper for *size*, loading it once."""
    if size not in _face2paint_cache:
        _face2paint_cache[size] = torch.hub.load(
            "bryandlee/animegan2-pytorch:main",
            "face2paint",
            size=size,
            verbose=False,
        )
    return _face2paint_cache[size]


def _to_rgb_image(input_img):
    """Normalise *input_img* (PIL image or http(s) URL string) to an RGB PIL image.

    Raises:
        ValueError: If *input_img* is neither a PIL image nor an http(s) URL.
    """
    if isinstance(input_img, Image.Image):
        return input_img.convert("RGB")
    if isinstance(input_img, str):
        # Only fetch network URLs; urlopen would otherwise also read file:// paths.
        if not input_img.lower().startswith(("http://", "https://")):
            raise ValueError("فرمت تصویر ورودی صحیح نیست!")
        with urlopen(input_img, timeout=_URL_TIMEOUT_SECONDS) as response:
            return Image.open(BytesIO(response.read())).convert("RGB")
    raise ValueError("فرمت تصویر ورودی صحیح نیست!")


def animegan2_transform(input_img, style_name):
    """Convert an image to anime style with the AnimeGANv2 generator for *style_name*.

    Args:
        input_img: A PIL image, or an http(s) URL string pointing at an image.
        style_name: Pretrained style name, forwarded to ``load_model``.

    Returns:
        The image produced by the hub ``face2paint`` helper.

    Raises:
        ValueError: If *input_img* is not a PIL image or an http(s) URL.
        Exception: Download, model-loading, or inference errors are logged and
            re-raised unchanged.
    """
    logger.info("Starting image transformation")
    try:
        rgb_img = _to_rgb_image(input_img)
        # No manual resize: the previous resize((760, 760)) stretched non-square
        # inputs. face2paint is configured with size=_FACE2PAINT_SIZE and (per
        # the upstream hubconf) square-crops and resizes the image itself.
        model = load_model(style_name)
        face2paint_func = _get_face2paint(_FACE2PAINT_SIZE)
        output_img = face2paint_func(model, rgb_img)
        logger.info("Image transformation completed")
        return output_img
    except Exception as e:
        logger.exception(f"Error in transformation: {str(e)}")
        raise
|
|
|
|
|
@app.get("/")
async def root():
    """Describe the service and direct clients to the image-processing endpoint."""
    welcome = "Welcome to AnimeGANv2 API. Use POST /animegan for image processing."
    return dict(message=welcome)
|
|
|
|
|
def _stylize_to_png(image, style):
    """Run ``animegan2_transform`` on *image* and return the result as PNG bytes."""
    output_img = animegan2_transform(image, style)
    output_buffer = BytesIO()
    output_img.save(output_buffer, format="PNG")
    return output_buffer.getvalue()


def _error_response(status_code, exc):
    """Build the ``{"error": str(exc)}`` JSON body with the given HTTP status."""
    return JSONResponse(status_code=status_code, content={"error": str(exc)})


@app.post("/animegan")
async def process_image(file: UploadFile = File(...), style: str = Form(...)):
    """Stylise an uploaded image and return the result as an inline PNG.

    Form fields:
        file: The uploaded image.
        style: Style name forwarded to ``animegan2_transform``.

    Returns:
        On success, status 200 with ``image/png`` content. On failure, a JSON body
        ``{"error": "<message>"}``: status 400 if the upload cannot be decoded
        or the transform rejects the input with ValueError, otherwise 500.
        (Previously every failure was reported with status 200.)
    """
    logger.info("Received API request")

    try:
        image = Image.open(BytesIO(await file.read())).convert("RGB")
    except Exception as e:
        # The upload is not a decodable image: a client error.
        logger.warning(f"API error: {str(e)}")
        return _error_response(400, e)

    try:
        # Inference (and a possible first-time weight download) is blocking;
        # run it in a worker thread so the event loop keeps serving requests.
        png_bytes = await run_in_threadpool(_stylize_to_png, image, style)
    except ValueError as e:
        logger.warning(f"API error: {str(e)}")
        return _error_response(400, e)
    except Exception as e:
        logger.exception(f"API error: {str(e)}")
        return _error_response(500, e)

    logger.info("Image processed successfully")
    # A plain Response sends the bytes in one piece with Content-Length. A
    # StreamingResponse over BytesIO would split the binary PNG at newline bytes.
    return Response(
        content=png_bytes,
        media_type="image/png",
        headers={"Content-Disposition": "inline; filename=output.png"},
    )
|
|
|
|
|
|
|
|
# Gradio UI that exposes the same transformation as POST /animegan:
# an image input and a style dropdown, with the stylised image as output.
iface = gr.Interface(
    fn=animegan2_transform,
    inputs=[
        gr.Image(type="pil", label="آپلود تصویر یا وارد کردن URL"),
        gr.Dropdown(
            ["face_paint_512_v1", "face_paint_512_v2", "paprika", "celeba_distill"],
            value="face_paint_512_v2",
            label="انتخاب استایل",
        ),
    ],
    outputs=gr.Image(type="pil", label="تصویر انیمه"),
    title="AnimeGANv2 - تبدیل تصویر به انیمه با انتخاب استایل",
)

# The interface was previously built but never served: the script entry point
# runs only the FastAPI ``app`` under uvicorn. Mounting it on that app makes
# the UI reachable at /gradio while leaving the JSON routes ("/", "/animegan")
# unchanged.
app = gr.mount_gradio_app(app, iface, path="/gradio")
|
|
|
|
|
def _run_server():
    """Serve ``app`` with uvicorn on all interfaces, port 7860."""
    import uvicorn  # deferred: only needed when the file is run as a script

    logger.info("Starting Uvicorn server")
    uvicorn.run(app, host="0.0.0.0", port=7860)


if __name__ == "__main__":
    _run_server()