younginpiniti commited on
Commit
8adf75f
·
1 Parent(s): fa0001a

feat: 의존성 목록을 갱신하고 애플리케이션 로직을 수정했습니다.

Browse files
Files changed (2) hide show
  1. app.py +142 -9
  2. requirements.txt +5 -0
app.py CHANGED
@@ -1,6 +1,6 @@
1
  """
2
  스테이블 디퓨전 WebUI - 허깅페이스 스페이스용
3
- Gradio 인터페이스를 통한 이미지 생성 (애니메이션 모델 지원)
4
  """
5
 
6
  import gradio as gr
@@ -9,6 +9,11 @@ from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
9
  from PIL import Image
10
  import os
11
  import gc
 
 
 
 
 
12
 
13
  # 사용 가능한 모델 목록
14
  MODELS = {
@@ -374,16 +379,144 @@ def create_interface():
374
 
375
  return demo
376
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
377
  # 메인 실행
 
378
  if __name__ == "__main__":
379
- print("🌸 Anime Diffusion WebUI 시작...")
380
 
381
- # Gradio 앱 생성 및 실행
382
  demo = create_interface()
383
 
384
- # 허깅페이스 스페이스서는 포트 7860 사용
385
- demo.launch(
386
- server_name="0.0.0.0",
387
- server_port=7860,
388
- share=False
389
- )
 
 
 
1
  """
2
  스테이블 디퓨전 WebUI - 허깅페이스 스페이스용
3
+ Gradio 인터페이스 + REST API를 통한 이미지 생성 (애니메이션 모델 지원)
4
  """
5
 
6
  import gradio as gr
 
9
  from PIL import Image
10
  import os
11
  import gc
12
+ import io
13
+ import base64
14
+ from typing import Optional
15
+ from fastapi import FastAPI, HTTPException
16
+ from pydantic import BaseModel, Field
17
 
18
  # 사용 가능한 모델 목록
19
  MODELS = {
 
379
 
380
  return demo
381
 
382
+ # ================================
383
+ # REST API 엔드포인트 정의
384
+ # ================================
385
+
386
+ # API 요청/응답 모델 정의
387
+ class GenerateRequest(BaseModel):
388
+ """이미지 생성 요청 모델"""
389
+ prompt: str = Field(..., description="이미지 생성 프롬프트")
390
+ model_name: str = Field(
391
+ default="🎨 Mistoon Anime V3 (카툰풍 애니메이션)",
392
+ description="사용할 모델 이름"
393
+ )
394
+ negative_prompt: str = Field(default="", description="네거티브 프롬프트")
395
+ num_inference_steps: int = Field(default=25, ge=10, le=50, description="추론 스텝 수")
396
+ guidance_scale: float = Field(default=7.5, ge=1.0, le=15.0, description="CFG 스케일")
397
+ width: int = Field(default=512, ge=256, le=768, description="이미지 너비")
398
+ height: int = Field(default=512, ge=256, le=768, description="이미지 높이")
399
+ seed: int = Field(default=-1, description="시드 값 (-1이면 랜덤)")
400
+
401
+ class GenerateResponse(BaseModel):
402
+ """이미지 생성 응답 모델"""
403
+ success: bool
404
+ message: str
405
+ image_base64: Optional[str] = None
406
+ seed: Optional[int] = None
407
+
408
+ class ModelsResponse(BaseModel):
409
+ """모델 목록 응답"""
410
+ models: list[str]
411
+
412
+ # FastAPI 앱 생성
413
+ api_app = FastAPI(
414
+ title="Anime Diffusion API",
415
+ description="애니메이션 스타일 이미지 생성 REST API",
416
+ version="1.0.0"
417
+ )
418
+
419
+ @api_app.get("/api/models", response_model=ModelsResponse)
420
+ async def get_models():
421
+ """사용 가능한 모델 목록 조회"""
422
+ return ModelsResponse(models=list(MODELS.keys()))
423
+
424
+ @api_app.post("/api/generate", response_model=GenerateResponse)
425
+ async def api_generate_image(request: GenerateRequest):
426
+ """
427
+ 이미지 생성 API
428
+
429
+ 프롬프트를 전달하면 Base64로 인코딩된 이미지를 반환합니다.
430
+ """
431
+ global pipe
432
+
433
+ if not request.prompt.strip():
434
+ raise HTTPException(status_code=400, detail="프롬프트를 입력해주세요")
435
+
436
+ if request.model_name not in MODELS:
437
+ raise HTTPException(
438
+ status_code=400,
439
+ detail=f"알 수 없는 모델입니다. 사용 가능한 모델: {list(MODELS.keys())}"
440
+ )
441
+
442
+ # 모델 로드
443
+ pipe, status = load_model(request.model_name)
444
+ if pipe is None:
445
+ raise HTTPException(status_code=500, detail=status)
446
+
447
+ # 네거티브 프롬프트 설정
448
+ default_negative = "lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry"
449
+
450
+ if request.negative_prompt.strip():
451
+ full_negative = f"{request.negative_prompt}, {default_negative}"
452
+ else:
453
+ full_negative = default_negative
454
+
455
+ # 시드 설정
456
+ seed = request.seed
457
+ if seed == -1:
458
+ seed = torch.randint(0, 2**32 - 1, (1,)).item()
459
+
460
+ generator = torch.Generator(device=DEVICE).manual_seed(int(seed))
461
+
462
+ try:
463
+ print(f"🎨 [API] 이미지 생성 중... 프롬프트: {request.prompt[:50]}...")
464
+
465
+ # 이미지 생성
466
+ result = pipe(
467
+ prompt=request.prompt,
468
+ negative_prompt=full_negative,
469
+ num_inference_steps=request.num_inference_steps,
470
+ guidance_scale=request.guidance_scale,
471
+ width=request.width,
472
+ height=request.height,
473
+ generator=generator
474
+ )
475
+
476
+ image = result.images[0]
477
+
478
+ # 이미지를 Base64로 인코딩
479
+ buffer = io.BytesIO()
480
+ image.save(buffer, format="PNG")
481
+ buffer.seek(0)
482
+ image_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
483
+
484
+ print(f"✅ [API] 이미지 생성 완료! (시드: {seed})")
485
+
486
+ return GenerateResponse(
487
+ success=True,
488
+ message="이미지 생성 완료",
489
+ image_base64=image_base64,
490
+ seed=seed
491
+ )
492
+
493
+ except Exception as e:
494
+ print(f"❌ [API] 이미지 생성 실패: {e}")
495
+ raise HTTPException(status_code=500, detail=str(e))
496
+
497
+ @api_app.get("/api/health")
498
+ async def health_check():
499
+ """서버 상태 확인"""
500
+ return {
501
+ "status": "healthy",
502
+ "device": DEVICE,
503
+ "model_loaded": current_model_id is not None
504
+ }
505
+
506
+ # ================================
507
  # 메인 실행
508
+ # ================================
509
  if __name__ == "__main__":
510
+ print("🌸 Anime Diffusion WebUI + API 시작...")
511
 
512
+ # Gradio 앱 생성
513
  demo = create_interface()
514
 
515
+ # FastAPIGradio 마운트
516
+ app = gr.mount_gradio_app(api_app, demo, path="/")
517
+
518
+ # uvicorn으로 통합 서버 실행
519
+ import uvicorn
520
+ print("📡 API 문서: http://localhost:7860/docs")
521
+ print("🌐 웹 UI: http://localhost:7860/")
522
+ uvicorn.run(app, host="0.0.0.0", port=7860)
requirements.txt CHANGED
@@ -15,3 +15,8 @@ numpy
15
 
16
  # 웹 인터페이스
17
  gradio
 
 
 
 
 
 
15
 
16
  # 웹 인터페이스
17
  gradio
18
+
19
+ # REST API
20
+ fastapi
21
+ uvicorn
22
+ python-multipart