Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, File, UploadFile, HTTPException | |
| from fastapi.responses import JSONResponse | |
| import io | |
| import joblib | |
| import torch | |
| import numpy as np | |
| import torchvision.transforms as transforms | |
| from PIL import Image | |
| import yaml | |
| import traceback | |
| import timm | |
| import logging | |
| from fastapi.logger import logger | |
# FastAPI application object.
# NOTE(review): the routes that would use it are commented out in this file,
# so as written this app exposes no endpoints.
app = FastAPI()

# Use a CUDA GPU when torch can see one, otherwise fall back to the CPU.
# NOTE(review): load_model does not move the model onto this device.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# Label -> output-index mapping for the classifier's logits.
class_mapping = {'tb': 0, 'healthy': 1, 'sick_but_no_tb': 2}

# Inverse lookup: output index -> label.
reverse_mapping = {v: k for k, v in class_mapping.items()}

# Labels ordered by output index, so labels[i] names logit i. Derived from
# reverse_mapping instead of class_mapping's insertion order so it stays
# correct even if the entries above are ever reordered.
labels = [reverse_mapping[i] for i in sorted(reverse_mapping)]
def load_model(weights_path='model.pth'):
    """Build the ConvNeXt classifier and load the fine-tuned weights into it.

    The architecture is timm's ``convnext_base.clip_laiona`` with a 3-way
    classification head. All of its parameters are then overwritten by the
    state dict stored at *weights_path*.

    Args:
        weights_path: path to a ``state_dict`` saved with ``torch.save``.
            Defaults to ``'model.pth'`` (relative to the working directory).

    Returns:
        The model in eval mode. Its parameters are on the CPU; this function
        does not move it to any other device.

    Raises:
        FileNotFoundError: if *weights_path* does not exist.
        RuntimeError: if the checkpoint's keys/shapes do not match the
            architecture (``load_state_dict`` is strict by default).
    """
    # pretrained=False: every weight is replaced by the strict
    # load_state_dict below, so fetching the hub checkpoint first would only
    # cost startup time, bandwidth and a network dependency.
    model = timm.create_model('convnext_base.clip_laiona', pretrained=False, num_classes=3)
    # The freshly built model's parameters are on the CPU, so load the
    # tensors there directly instead of detouring through another device.
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    state_dict = torch.load(weights_path, map_location='cpu')
    model.load_state_dict(state_dict)
    model.eval()
    return model
# Preprocessing used by get_prediction (the FastAPI path):
#   1. resize the shorter side to 255 px;
#   2. center-crop to 224x224;
#   3. convert to a float tensor;
#   4. normalize with the ImageNet channel mean/std.
# Built once at import time instead of on every request.
_PREDICT_TRANSFORMS = transforms.Compose([
    transforms.Resize(255),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


def transform_image(image_bytes):
    """Decode encoded image bytes into a one-image model input batch.

    Args:
        image_bytes: raw bytes of an image file in any format PIL can read.

    Returns:
        A float tensor of shape (1, 3, 224, 224).

    Raises:
        PIL.UnidentifiedImageError: if the bytes are not a readable image.
    """
    # convert() loads the pixel data, so the source image can be closed
    # immediately afterwards.
    with Image.open(io.BytesIO(image_bytes)) as image:
        rgb_image = image.convert('RGB')
    return _PREDICT_TRANSFORMS(rgb_image).unsqueeze(0)
def get_prediction(data):
    """Classify one encoded image with the module-level ``model``.

    Args:
        data: raw image-file bytes, as accepted by ``transform_image``.

    Returns:
        The label (a key of ``class_mapping``) whose logit is largest.
    """
    batch = transform_image(data)
    # Run on whatever device the model's weights are on, so this keeps
    # working if the model is ever moved off the CPU.
    batch = batch.to(next(model.parameters()).device)
    with torch.no_grad():
        logits = model(batch)
    # argmax over the class dimension of the single-image batch, rather than
    # over the flattened output tensor.
    class_index = logits.argmax(dim=1)[0].item()
    return reverse_mapping[class_index]
# Lower-case file extensions accepted for upload.
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg'}


def allowed_file(filename):
    """Return True if *filename* has one of ALLOWED_EXTENSIONS.

    Only the text after the last dot is checked, case-insensitively. A
    missing or empty filename is rejected instead of raising ``TypeError``.
    This matters because an upload's filename may be ``None``.
    """
    if not filename:
        return False
    _, dot, extension = filename.rpartition('.')
    return bool(dot) and extension.lower() in ALLOWED_EXTENSIONS
| # @app.get("/predict") | |
| # async def predict(file: UploadFile = File(...)): | |
| # """ | |
| # Perform prediction on the uploaded image | |
| # """ | |
| # logger.info('API predict called') | |
| # if not allowed_file(file.filename): | |
| # raise HTTPException(status_code=400, detail="Format not supported") | |
| # try: | |
| # img_bytes = await file.read() | |
| # class_name = get_prediction(img_bytes) | |
| # logger.info(f'Prediction: {class_name}') | |
| # return JSONResponse(content={"class_name": class_name}) | |
| # except Exception as e: | |
| # logger.error(f'Error: {str(e)}') | |
| # return JSONResponse(content={"error": str(e), "trace": traceback.format_exc()}, status_code=500) | |
| # # @app.get("/") | |
| # # def greet_json(): | |
| # # return {"Hello": "World!"} | |
| import torch | |
| import requests | |
| from PIL import Image | |
| from torchvision import transforms | |
| # model = torch.hub.load('pytorch/vision:v0.6.0', 'resnet18', pretrained=True).eval() | |
| # Download human-readable labels for ImageNet. | |
| # response = requests.get("https://git.io/JJkYN") | |
| # labels = response.text.split("\n") | |
# The classifier is loaded once, at import time, and shared by the
# prediction functions in this module.
model = load_model()

# Preprocessing for the Gradio path: stretch the image to exactly 224x224
# (no aspect-preserving crop), then convert it to a float tensor.
# NOTE(review): unlike transform_image, this applies no mean/std
# normalization and no center crop. At most one of the two pipelines can
# match how model.pth was trained -- confirm which one is correct.
augs = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]
)
def predict(inp):
    """Return per-class probabilities for one image (the Gradio callback).

    Args:
        inp: a PIL image, as supplied by ``gr.Image(type="pil")``.

    Returns:
        A dict mapping each class label to its softmax probability as a
        Python float, which is the format ``gr.Label`` renders.

    Raises:
        ValueError: if no image was supplied (``inp`` is None).
    """
    if inp is None:
        raise ValueError("No image provided")
    # ToTensor keeps the source channel count, so force 3-channel RGB first
    # (as transform_image does). Grayscale or RGBA uploads would otherwise
    # not fit the model's 3-channel input.
    batch = augs(inp.convert('RGB')).unsqueeze(0)
    # Run on whatever device the model's weights are on.
    batch = batch.to(next(model.parameters()).device)
    with torch.no_grad():
        probs = torch.nn.functional.softmax(model(batch)[0], dim=0)
    # Pair each probability with the label of its own output index via
    # reverse_mapping, instead of a hard-coded count and list order.
    return {reverse_mapping[i]: float(p) for i, p in enumerate(probs.tolist())}
| import gradio as gr | |
# Gradio UI: upload one image and show the probabilities of all 3 classes.
demo = gr.Interface(fn=predict,
                    inputs=gr.Image(type="pil"),
                    outputs=gr.Label(num_top_classes=3))

# Launch only when run as a script. Importing this module (e.g. to reach the
# FastAPI ``app``) then no longer starts a blocking Gradio server.
if __name__ == "__main__":
    # share=True requests a public share link when run locally.
    demo.launch(share=True)