| import timm |
| from fastapi import FastAPI, File, UploadFile |
| from fastapi.responses import JSONResponse |
| from typing import Tuple |
| from io import BytesIO |
| from PIL import Image |
| import torch |
| import torchvision.transforms as transforms |
| import torch.nn as nn |
| import json |
| import numpy as np |
| import os |
| import requests |
|
|
| class MultiTaskClassifier(nn.Module): |
| def __init__(self, model_name="convnext_base", num_genres=10, num_styles=27, drop_rate=0.1, drop_path_rate=0.1): |
| super(MultiTaskClassifier, self).__init__() |
| self.backbone = timm.create_model(model_name, pretrained=True, num_classes=0, global_pool="avg") |
| in_features = self.backbone.num_features |
| self.genre_head = nn.Sequential( |
| nn.Dropout(drop_rate), |
| nn.Linear(in_features, num_genres) |
| ) |
| self.style_head = nn.Sequential( |
| nn.Dropout(drop_rate), |
| nn.Linear(in_features, num_styles) |
| ) |
|
|
| def forward(self, x): |
| features = self.backbone(x) |
| genre_logits = self.genre_head(features) |
| style_logits = self.style_head(features) |
| return genre_logits, style_logits |
|
|
| |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
| model = MultiTaskClassifier().to(device) |
|
|
| model_path = { |
| "genres_styles": "https://huggingface.co/Bambii-03/wikiart-genre-style-model/resolve/main/genre_style_model_weights.pth" |
| } |
|
|
| def download_from_hf(url, dest_path): |
| response = requests.get(url, stream=True) |
| response.raise_for_status() |
| with open(dest_path, "wb") as f: |
| for chunk in response.iter_content(chunk_size=8192): |
| f.write(chunk) |
|
|
| def load_model(model_path): |
| for attr, url in model_path.items(): |
| path = f"{attr}.pth" |
| if not os.path.exists(path): |
| print(f"Downloading {attr} model...") |
| download_from_hf(url, path) |
| state_dict = torch.load(path, map_location=device, weights_only=False) |
| if list(state_dict.keys())[0].startswith("module."): |
| state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()} |
| model.load_state_dict(state_dict, strict=False) |
| model.eval() |
| return model |
|
|
| model = None |
|
|
| def get_model(): |
| global model |
| if model is None: |
| model = MultiTaskClassifier().to(device) |
| model = load_model(model_path) |
| return model |
|
|
| |
| def preprocess_image(image: Image.Image): |
| transform = transforms.Compose([ |
| transforms.Resize((224, 224)), |
| transforms.ToTensor(), |
| transforms.Normalize(mean=[0.485, 0.456, 0.406], |
| std=[0.229, 0.224, 0.225]) |
| ]) |
| return transform(image).unsqueeze(0) |
|
|
| |
| def load_classes(filename): |
| try: |
| with open(filename, "r") as f: |
| classes = json.load(f) |
| return {int(k): v for k, v in classes.items()} |
| except Exception as e: |
| print(f"Error loading classes from {filename}: {e}") |
| return {} |
|
|
| style_classes = load_classes("style_classes.json") |
| genre_classes = load_classes("genre_classes.json") |
|
|
| app = FastAPI(title="WikiArt Classification API", description="Predicts the style and genre of an artwork.") |
|
|
| @app.get("/") |
| async def health_check(): |
| return "FastAPI is running!" |
|
|
| @app.post("/predict-style-genre/") |
| async def predict_genre_style(file: UploadFile = File(...)): |
| try: |
| contents = await file.read() |
| image = Image.open(BytesIO(contents)).convert("RGB") |
| input_tensor = preprocess_image(image).to(device) |
| model = get_model() |
| with torch.no_grad(): |
| genre_logits, style_logits = model(input_tensor) |
| genre = torch.argmax(genre_logits, dim=1).item() |
| style = torch.argmax(style_logits, dim=1).item() |
| return { |
| "genre": genre_classes.get(genre, "Unknown"), |
| "style": style_classes.get(style, "Unknown") |
| } |
| except Exception as e: |
| return JSONResponse({"error": str(e)}, status_code=500) |
|
|
| |
| |