Spaces:
Sleeping
Sleeping
import os
from functools import lru_cache
from io import BytesIO

import torch
from fastapi import FastAPI, HTTPException, File, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
from PIL import Image
from torchvision import transforms

from .model import MalwareNet, malware_classes
# FastAPI application instance; the __main__ entry point at the bottom of this
# module serves it as "src.serve:app".
app = FastAPI()
# Built once at import time: Resize/ToTensor/Normalize hold no per-call state,
# so there is no reason to rebuild the pipeline for every request.
# Mean/std are the commonly used ImageNet channel statistics.
_PREPROCESS = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def preprocess_image(image_data):
    """Decode raw image bytes into a normalized 1x3x224x224 float tensor.

    Args:
        image_data: The encoded image file contents (any format PIL can read).

    Returns:
        A tensor with a leading batch dimension of 1, ready to feed the model.

    Raises:
        OSError: If PIL cannot identify or fully decode the bytes
            (``PIL.UnidentifiedImageError`` is a subclass of ``OSError``).
    """
    # Close the lazily-opened source image; convert() returns an independent
    # image, so it stays usable after the with-block exits.
    with Image.open(BytesIO(image_data)) as img:
        image = img.convert("RGB")
    return _PREPROCESS(image).unsqueeze(0)
@lru_cache(maxsize=1)
def load_model():
    """Load the MalwareNet weights onto the CPU and return the model in eval mode.

    The checkpoint is ``../model/malwareNet.pt`` relative to this module's
    directory.

    The result is cached: the file is read from disk once per process, and
    every later call returns the same model instance. Callers must therefore
    treat the model as read-only. A failed load is not cached, so the next call
    retries it.

    Returns:
        The ``MalwareNet`` instance with the checkpoint loaded and ``eval()`` set.
    """
    model = MalwareNet()
    base_dir = os.path.dirname(os.path.abspath(__file__))
    model_location = os.path.normpath(
        os.path.join(base_dir, "..", "model", "malwareNet.pt")
    )
    # weights_only=True restricts unpickling to tensors and plain containers.
    state_dict = torch.load(
        model_location, map_location=torch.device("cpu"), weights_only=True
    )
    model.load_state_dict(state_dict)
    model.eval()  # inference mode for layers such as dropout/batch-norm
    return model
def _classify(img_tensor):
    """Load the model, run it on a preprocessed batch, and return the top-1 label.

    This is synchronous and CPU-bound, so ``predict`` runs it in a worker
    thread rather than on the event loop.
    """
    model = load_model()
    with torch.no_grad():
        logits = model(img_tensor)
    # argmax over the flattened output. With a batch of one this is the class
    # index, assuming the model emits one score per class (TODO confirm shape).
    return malware_classes[torch.argmax(logits).item()]


# NOTE(review): the route path is assumed. The handler had no route decorator,
# so it was never registered on `app`.
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image and return ``{"prediction": <class label>}``.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image;
            500 for any other failure (reading, model loading, inference).
    """
    try:
        image_data = await file.read()
        img_tensor = preprocess_image(image_data)
    except (OSError, ValueError) as e:
        # PIL reports unidentifiable or truncated data as OSError
        # (UnidentifiedImageError). That is a bad upload, not a server fault.
        raise HTTPException(status_code=400, detail=f"Error processing the image: {e}") from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing the image: {e}") from e

    try:
        # Offload CPU-bound inference so concurrent requests aren't serialized
        # behind a blocked event loop.
        predicted_class = await run_in_threadpool(_classify, img_tensor)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing the image: {e}") from e

    return JSONResponse(content={"prediction": predicted_class})
if __name__ == "__main__":
    # Local development entry point: serve the app with auto-reload.
    # HOST and PORT may be overridden via environment variables.
    import uvicorn

    bind_host = os.environ.get("HOST", "localhost")
    bind_port = int(os.environ.get("PORT", 5000))
    uvicorn.run("src.serve:app", host=bind_host, port=bind_port, reload=True)