Spaces:
Runtime error
Runtime error
| import base64 | |
| import io | |
| from fastapi import FastAPI, HTTPException | |
| from pydantic import BaseModel | |
| from PIL import Image | |
| import torch # Or tensorflow, etc. - Import necessary libraries | |
| # Import your specific AnimeGAN model loading and processing functions | |
| # This is HIGHLY DEPENDENT on the model you choose. | |
| # Example Placeholder - Replace with actual model loading/inference | |
| # from your_animegan_module import load_animegan_model, run_inference | |
# --- Configuration ---
# Choose the compute device once, at import time: CUDA when torch can see a
# GPU (e.g. on GPU-backed Space hardware), otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print(f"Using device: {DEVICE}")
# --- Model Loading ---
# Load the model ONCE at import time so every request reuses the same instance.
# This is a placeholder - replace the assignment below with your actual loader,
# for example:
#   animegan_model = load_animegan_model("path/to/weights_or_hub_id").to(DEVICE)
# For an end-to-end smoke test without real weights, a stand-in that honours
# the endpoint's contract (PIL image in, PIL image out) works:
#   class DummyModel:
#       def __call__(self, img):
#           print("Dummy model processing...")
#           return img.convert("L").convert("RGB")  # simple grayscale effect
#   animegan_model = DummyModel()
try:
    animegan_model = None  # TODO: replace with the real model
    # PyTorch modules must be put in evaluation mode for inference (layers such
    # as Dropout/BatchNorm behave differently in training mode). Do it
    # automatically so plugging in a model cannot forget this step; objects
    # without an ``eval`` method (e.g. a plain callable) are left as-is.
    if animegan_model is not None and callable(getattr(animegan_model, "eval", None)):
        animegan_model.eval()
    if animegan_model is None:
        # Be explicit: nothing was loaded, so the processing endpoint will 503.
        print("No AnimeGAN model configured (placeholder) - replace with actual model!")
    else:
        print("AnimeGAN model loaded.")
except Exception as e:
    # Deliberately best-effort: keep the app importable so the health endpoint
    # can report model_loaded=False instead of the whole Space crashing.
    print(f"Error loading AnimeGAN model: {e}")
    animegan_model = None  # Ensure it's None if loading failed
# --- FastAPI App ---
# The ASGI application instance served by the Space.
# NOTE(review): no route decorators (@app.get / @app.post) are visible in this
# file, so the handler functions defined below are not registered as endpoints
# unless that happens elsewhere - confirm.
app = FastAPI()
# Request body schema for the image-processing handler: a single field holding
# the input image as a base64 string. The handler base64-decodes it and opens
# the bytes with Pillow, so it must be an image file format Pillow can read.
# (Comments rather than a docstring: pydantic would publish a class docstring
# as the schema description.)
class ImageData(BaseModel):
    image_base64: str  # base64-encoded image file bytes
# Response body schema: the stylized image as a base64 string.
# NOTE(review): not referenced by the handlers in this file - process_image
# returns a plain dict with the same key. Consider using it as the route's
# response_model so the OpenAPI docs describe the response.
class ResultData(BaseModel):
    result_base64: str  # base64-encoded PNG written by process_image
def read_root():
    """Return a static message confirming the backend process is up."""
    # NOTE(review): no route decorator is visible here; presumably meant to be
    # served at "/" - confirm it is registered with ``app``.
    status_message = "AnimeGAN FastAPI Backend is running."
    return dict(message=status_message)
async def process_image(data: ImageData):
    """Stylize a base64-encoded image with the loaded AnimeGAN model.

    Args:
        data: Request body whose ``image_base64`` field holds the encoded input
            image. A ``data:<mime>;base64,`` prefix (as produced by a browser's
            ``FileReader.readAsDataURL``) is accepted and stripped.

    Returns:
        ``{"result_base64": <base64-encoded PNG of the model output>}``.

    Raises:
        HTTPException: 503 if no model is loaded, 400 if the payload cannot be
            decoded into an image, 500 if inference or encoding fails.
    """
    # NOTE(review): no route decorator is visible on this function, so it is
    # only an endpoint if registered with ``app`` elsewhere - confirm.
    # NOTE(review): the model runs synchronously inside an ``async def``, which
    # blocks the event loop (and all other requests) for the whole inference.
    if animegan_model is None:
        raise HTTPException(status_code=503, detail="AnimeGAN model is not loaded or failed to load.")
    try:
        input_image = _decode_base64_image(data.image_base64)
        output_image_pil = _run_animegan(input_image)
        result_base64 = _encode_png_base64(output_image_pil)
        return {"result_base64": result_base64}
    except HTTPException:
        # Already carries the intended status code; propagate unchanged.
        raise
    except Exception as e:
        print(f"An unexpected error occurred: {e}")
        raise HTTPException(status_code=500, detail=f"An internal server error occurred: {e}") from e


def _decode_base64_image(image_base64: str) -> Image.Image:
    """Decode a base64 (optionally data-URL-prefixed) string into an RGB image.

    Raises:
        HTTPException: 400 if the string is not a decodable image.
    """
    payload = image_base64.strip()
    # b64decode (non-validating) silently drops non-alphabet characters, so a
    # "data:image/png;base64," prefix would be folded into the payload and
    # corrupt it. Plain base64 never contains ':' or ',', so this is safe.
    if payload.startswith("data:") and "," in payload:
        payload = payload.split(",", 1)[1]
    try:
        image_bytes = base64.b64decode(payload)
        with Image.open(io.BytesIO(image_bytes)) as source:
            # convert() returns a new image whose .format is None, so capture
            # the detected source format before converting.
            source_format = source.format
            input_image = source.convert("RGB")
    except Exception as e:
        print(f"Error decoding base64 image: {e}")
        raise HTTPException(status_code=400, detail=f"Invalid base64 image data: {e}") from e
    print(f"Received image: {input_image.size}, format: {source_format}")
    return input_image


def _run_animegan(input_image: Image.Image) -> Image.Image:
    """Run the loaded model on a PIL image and return its PIL output.

    Raises:
        HTTPException: 500 if the model raises or returns a non-image.
    """
    # Preprocessing is model-specific. A tensor-based model would need e.g.
    #   processed_input = transform(input_image).unsqueeze(0).to(DEVICE)
    # and its output converted back with
    #   transforms.ToPILImage()(output_tensor.squeeze().cpu())
    # For now the model is assumed to take and return PIL images directly.
    processed_input = input_image
    try:
        print("Running AnimeGAN inference...")
        # Inference never needs gradients; disabling autograd avoids keeping
        # activations alive for a backward pass that will never happen.
        with torch.no_grad():
            output_image_pil = animegan_model(processed_input)
        if not isinstance(output_image_pil, Image.Image):
            raise TypeError(f"Model output was not a PIL Image, but {type(output_image_pil)}")
        print(f"Inference complete. Output size: {output_image_pil.size}")
    except Exception as e:
        print(f"Error during model inference: {e}")
        raise HTTPException(status_code=500, detail=f"Model inference failed: {e}") from e
    return output_image_pil


def _encode_png_base64(image: Image.Image) -> str:
    """Serialize *image* as PNG and return the bytes base64-encoded as ``str``."""
    buffer = io.BytesIO()
    # PNG is lossless; switch to JPEG if response size matters more than quality.
    image.save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")
# Optional: Add more endpoints or health checks if needed
def health_check():
    """Report liveness plus whether a model object is currently available."""
    # NOTE(review): no route decorator is visible here - confirm registration.
    model_ready = animegan_model is not None
    return dict(status="ok", model_loaded=model_ready)