import os from fastapi import FastAPI, File, UploadFile, HTTPException from fastapi.responses import JSONResponse, HTMLResponse from fastapi.staticfiles import StaticFiles from fastapi.openapi.docs import get_swagger_ui_html from typing import List, Optional import io from PIL import Image import numpy as np from model import SkinClassifier # Инициализация FastAPI app = FastAPI( title="Skin Classification API", description="API для классификации кожных заболеваний (clear, acne, ros, black)", version="1.0.0", docs_url="/docs", redoc_url="/redoc" ) # Инициализация модели try: # Путь к модели в Hugging Face Spaces MODEL_PATH = "model/stage1_skin_classifier.pth" classifier = SkinClassifier(MODEL_PATH) print("✅ Model loaded successfully!") except Exception as e: print(f"❌ Error loading model: {e}") classifier = None @app.get("/", response_class=HTMLResponse) async def root(): """Главная страница с информацией об API""" return "go to /docs" @app.post("/predict") async def predict( file: UploadFile = File(..., description="Изображение для классификации (JPG, PNG)"), return_image: Optional[bool] = False ): """ Классификация одного изображения - **file**: Изображение в формате JPG или PNG - **return_image**: Возвращать ли base64 изображение (по умолчанию False) """ if classifier is None: raise HTTPException(status_code=503, detail="Model not loaded") # Проверка формата файла allowed_extensions = {".jpg", ".jpeg", ".png", ".bmp"} file_ext = os.path.splitext(file.filename)[1].lower() if file_ext not in allowed_extensions: raise HTTPException( status_code=400, detail=f"Unsupported file format. Allowed: {allowed_extensions}" ) try: # Чтение файла contents = await file.read() # Предсказание result = classifier.predict(contents) response_data = { "filename": file.filename, "prediction": result["predicted_class"], "confidence": result["confidence"], "probabilities": result["all_probabilities"] } # Добавляем base64 изображение если нужно if return_image: from io import BytesIO import base64 img = Image.open(BytesIO(contents)) buffered = BytesIO() img.save(buffered, format="JPEG" if file_ext in {".jpg", ".jpeg"} else "PNG") img_str = base64.b64encode(buffered.getvalue()).decode() response_data["image_base64"] = f"data:image/jpeg;base64,{img_str}" return response_data except Exception as e: raise HTTPException(status_code=500, detail=f"Prediction error: {str(e)}") # Middleware для обработки CORS from fastapi.middleware.cors import CORSMiddleware app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], )