# Vit-2 / main.py
# Author: zaman855
# Last commit: f349e38 ("Update main.py")
from io import BytesIO

import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from PIL import Image
from vit_keras import vit
app = FastAPI()

# CORS configuration.
# NOTE(review): a wildcard origin combined with allow_credentials=True is very
# permissive; restrict allow_origins to the real frontend origin(s) before
# exposing this service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Adjust this according to your security needs
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Square input resolution of the ViT-B/16 backbone; preprocessing resizes to it.
IMAGE_SIZE = 224

# Load the model: a ViT-B/16 backbone without its classification head
# (include_top=False) followed by a small dense head. load_weights() below
# requires this exact layer stack (the one the .h5 file was saved from), so do
# not reorder or resize these layers.
# NOTE(review): pretrained=True fetches ImageNet weights that load_weights()
# presumably overwrites entirely; pretrained=False might skip that download,
# but verify the .h5 file covers the backbone before changing it.
vit_model = vit.vit_b16(
    image_size=IMAGE_SIZE,
    activation='softmax',
    pretrained=True,
    include_top=False,
    pretrained_top=False,
    classes=7,
)
model1 = tf.keras.Sequential([
    vit_model,
    tf.keras.layers.Flatten(),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Dense(11, activation=tfa.activations.gelu),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Dense(7, activation='softmax'),
])
# NOTE(review): this relative path is resolved against the process's current
# working directory, not this file's directory.
model1.load_weights('vit_model_weights.h5')

# Class names: labels[i] is the class for output i of the final softmax layer.
labels = ['akiec', 'bcc', 'bkl', 'df', 'mel', 'nv', 'vasc']


def preprocess_image(image: Image.Image):
    """Turn a decoded PIL image into a model-ready batch containing one image.

    The image is converted to 3-channel RGB first. Grayscale, palette and RGBA
    uploads (e.g. PNGs with transparency) then yield the same shape as RGB
    inputs instead of a 1- or 4-channel array that the model rejects.

    Args:
        image: Decoded input image, in any PIL mode.

    Returns:
        A float32 array of shape ``(1, IMAGE_SIZE, IMAGE_SIZE, 3)`` with
        values in [0, 1].
    """
    rgb = image.convert("RGB").resize((IMAGE_SIZE, IMAGE_SIZE))
    # Scale to [0, 1]; this must match the rescaling used when training.
    pixels = np.asarray(rgb, dtype=np.float32) / 255.0
    return np.expand_dims(pixels, axis=0)


def _infer(image: Image.Image) -> np.ndarray:
    """Preprocess *image* and return the 7 class probabilities as a 1-D array."""
    batch = preprocess_image(image)
    # Call the model directly instead of predict(). Keras recommends
    # model(x, training=False) for inputs that fit in one batch: it skips
    # predict()'s per-call setup and progress-bar output. training=False keeps
    # the BatchNormalization layers in inference mode.
    return model1(batch, training=False).numpy()[0]


@app.post("/predict/")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image.

    Returns:
        ``{"predicted_class": <label>, "confidence": <probability>}`` for the
        most probable class.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image.
    """
    contents = await file.read()
    try:
        image = Image.open(BytesIO(contents))
        # Force a full decode now, so truncated or corrupt files are reported
        # as a client error instead of failing later as a server error.
        image.load()
    except (OSError, Image.DecompressionBombError) as err:
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid image."
        ) from err

    # Inference is blocking and compute-heavy. Run it in a worker thread so it
    # does not stall the event loop, and with it every other request.
    probabilities = await run_in_threadpool(_infer, image)
    best = int(np.argmax(probabilities))
    return {"predicted_class": labels[best], "confidence": float(probabilities[best])}


# Serve the frontend directory as static files. This catch-all mount at "/" is
# registered AFTER the API routes on purpose: routes are matched in
# registration order, and a root mount registered first would capture
# POST /predict/ and answer it with 405 Method Not Allowed.
# NOTE(review): directory="." exposes every file in the working directory
# (including this source file and the model weights, if they live there);
# consider serving a dedicated frontend directory instead.
app.mount("/", StaticFiles(directory=".", html=True), name="static")
if __name__ == "__main__":
    # Script entry point only: start a uvicorn server for `app` on all
    # interfaces, port 8000. The import is local, so uvicorn is only required
    # when this file is executed directly.
    from uvicorn import run as serve_app

    serve_app(app, host="0.0.0.0", port=8000)