# DemoTest / app.py
# Last updated by Kingoteam: "Update app.py" (commit 3f4a0b9, verified)
from fastapi import FastAPI, File, UploadFile, Form
from fastapi.responses import StreamingResponse
from PIL import Image
import torch
from io import BytesIO
import gradio as gr
import logging
# Log at INFO level so request and model-loading progress shows up while debugging.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)

app = FastAPI()

# Generator models already loaded through torch.hub, keyed by style name, so
# load_model only fetches each style's weights once per process.
model_cache = dict()
# Pretrained AnimeGANv2 weight names accepted by load_model (the same four names
# the Gradio style dropdown offers).
# NOTE: the upstream animegan2-pytorch hub `generator` entry treats any other
# string passed as `pretrained` as a checkpoint URL to download and deserialize,
# so unvalidated user input must never reach torch.hub.load.
SUPPORTED_STYLES = (
    "face_paint_512_v1",
    "face_paint_512_v2",
    "paprika",
    "celeba_distill",
)


def load_model(style_name):
    """Return the AnimeGANv2 generator for ``style_name``, loading it on first use.

    Models are fetched from the ``bryandlee/animegan2-pytorch`` torch.hub repo,
    switched to eval mode, and memoized in ``model_cache`` so each style is
    loaded at most once per process.

    Args:
        style_name: one of ``SUPPORTED_STYLES``. The value comes straight from
            the public API form field / UI, so it is validated before use.

    Returns:
        The (cached) generator module for the style.

    Raises:
        ValueError: if ``style_name`` is not in ``SUPPORTED_STYLES``.
        Exception: any error raised by ``torch.hub.load`` is logged and re-raised.
    """
    # Reject unknown names before touching torch.hub (see SUPPORTED_STYLES note).
    if style_name not in SUPPORTED_STYLES:
        raise ValueError(
            f"Unsupported style {style_name!r}; "
            f"expected one of: {', '.join(SUPPORTED_STYLES)}"
        )

    cached = model_cache.get(style_name)
    if cached is not None:
        return cached

    # Only log "Loading" when a load actually happens, not on cache hits.
    logger.info(f"Loading model for style: {style_name}")
    try:
        model = torch.hub.load(
            "bryandlee/animegan2-pytorch:main",
            "generator",
            pretrained=style_name,
            verbose=False,
        ).eval()
    except Exception as e:
        logger.exception(f"Error loading model {style_name}: {e}")
        raise
    model_cache[style_name] = model
    logger.info(f"Model {style_name} loaded successfully")
    return model
# Side length in pixels of the square image produced by face2paint.
_FACE2PAINT_SIZE = 760
# face2paint helpers from torch.hub keyed by output size, so the hub entry point
# is resolved once rather than on every transformation.
_face2paint_cache = {}


def _fetch_image(url):
    """Download an image from an http(s) URL and return it as an RGB PIL image."""
    from urllib.parse import urlparse
    from urllib.request import urlopen

    # Only allow web URLs: urlopen would otherwise also read file:// paths from
    # the server's local filesystem.
    if urlparse(url).scheme not in ("http", "https"):
        raise ValueError(f"Unsupported image URL: {url!r}")
    with urlopen(url, timeout=30) as response:
        data = response.read()
    return Image.open(BytesIO(data)).convert("RGB")


def _get_face2paint(size):
    """Return the (cached) animegan2-pytorch face2paint helper for ``size``."""
    if size not in _face2paint_cache:
        _face2paint_cache[size] = torch.hub.load(
            "bryandlee/animegan2-pytorch:main",
            "face2paint",
            size=size,
            verbose=False,
        )
    return _face2paint_cache[size]


def animegan2_transform(input_img, style_name):
    """Convert a photo to anime style using the AnimeGANv2 generator for ``style_name``.

    Args:
        input_img: a PIL image, or a string treated as an http(s) URL to download.
        style_name: pretrained style name, forwarded to ``load_model``.

    Returns:
        The PIL image produced by face2paint at ``_FACE2PAINT_SIZE`` resolution.

    Raises:
        ValueError: if ``input_img`` is neither a str nor a PIL image, or is a
            non-http(s) URL.
        Exception: any download, model-loading, or inference error is logged
            and re-raised.
    """
    logger.info("Starting image transformation")
    try:
        if isinstance(input_img, str):
            input_img = _fetch_image(input_img)
        elif isinstance(input_img, Image.Image):
            input_img = input_img.convert("RGB")
        else:
            raise ValueError("فرمت تصویر ورودی صحیح نیست!")
        # No pre-resize: the upstream face2paint helper center-crops to a square
        # and resizes to its own `size`. The old direct resize to 760x760
        # stretched non-square photos before that crop.
        model = load_model(style_name)
        face2paint_func = _get_face2paint(_FACE2PAINT_SIZE)
        output_img = face2paint_func(model, input_img)
        logger.info("Image transformation completed")
        return output_img
    except Exception as e:
        logger.exception(f"Error in transformation: {e}")
        raise
@app.get("/")
async def root():
    # Landing endpoint: tells clients where the actual image-processing route lives.
    greeting = "Welcome to AnimeGANv2 API. Use POST /animegan for image processing."
    return dict(message=greeting)
@app.post("/animegan")
async def process_image(file: UploadFile = File(...), style: str = Form(...)):
    # Convert an uploaded image to anime style and stream the result back as PNG.
    #
    # Form fields:
    #   file  -- the image to convert (any format Pillow can decode).
    #   style -- pretrained style name forwarded to animegan2_transform.
    # Responses:
    #   200 image/png on success;
    #   400 {"error": ...} when the upload cannot be decoded or the input is
    #       rejected with a ValueError (e.g. an unsupported style);
    #   500 {"error": ...} for any other failure.
    # Errors used to come back with status 200, which clients could not tell
    # apart from success without inspecting the body.
    from fastapi.concurrency import run_in_threadpool
    from fastapi.responses import JSONResponse
    from PIL import UnidentifiedImageError

    logger.info("Received API request")
    try:
        image = Image.open(BytesIO(await file.read())).convert("RGB")
        # Model loading and inference are synchronous and CPU-bound; run them in
        # a worker thread so they do not stall the event loop for other requests.
        output_img = await run_in_threadpool(animegan2_transform, image, style)
        output_buffer = BytesIO()
        output_img.save(output_buffer, format="PNG")
        output_buffer.seek(0)
    except (UnidentifiedImageError, ValueError) as e:
        logger.warning(f"Rejected API request: {e}")
        return JSONResponse(status_code=400, content={"error": str(e)})
    except Exception as e:
        logger.exception(f"API error: {e}")
        return JSONResponse(status_code=500, content={"error": str(e)})

    logger.info("Image processed successfully")
    return StreamingResponse(
        output_buffer,
        media_type="image/png",
        headers={"Content-Disposition": "inline; filename=output.png"},
    )
# Optional Gradio web UI around animegan2_transform: an uploaded PIL image and a
# dropdown of pretrained style names in, the converted PIL image out.
# NOTE(review): the image label mentions entering a URL, but a gr.Image input
# with type="pil" hands the function a PIL image - confirm the intended UX.
iface = gr.Interface(
    fn=animegan2_transform,
    inputs=[
        gr.Image(type="pil", label="آپلود تصویر یا وارد کردن URL"),
        gr.Dropdown(
            ["face_paint_512_v1", "face_paint_512_v2", "paprika", "celeba_distill"],
            value="face_paint_512_v2",
            label="انتخاب استایل",
        ),
    ],
    outputs=gr.Image(type="pil", label="تصویر انیمه"),
    title="AnimeGANv2 - تبدیل تصویر به انیمه با انتخاب استایل",
)
# Serve the UI from the FastAPI app under /gradio. Without this the interface is
# built but never reachable, because only `app` is handed to uvicorn.
app = gr.mount_gradio_app(app, iface, path="/gradio")
if __name__ == "__main__":
    # Script entry point: serve the FastAPI app on all interfaces, port 7860.
    from uvicorn import run as serve_app

    logger.info("Starting Uvicorn server")
    serve_app(app, host="0.0.0.0", port=7860)