# VIT-css / main.py
# zaman855's picture
# Update main.py
# 5951b0c verified
import os
from io import BytesIO

import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
from PIL import Image
from vit_keras import vit
app = FastAPI()

# CORS
# Any origin, method and header is accepted. Credentialed requests are
# disabled: a wildcard origin must not be combined with
# allow_credentials=True (browsers reject "*" on credentialed requests and
# the middleware would otherwise have to reflect every caller's origin).
# No endpoint in this file relies on cookies or HTTP auth.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Serve the HTML page
@app.get("/", response_class=HTMLResponse)
async def serve_html() -> HTMLResponse:
    """Return the contents of ``index.html`` as the landing page.

    The file is opened by its relative path (so it is resolved against the
    process's current working directory) and is re-read on every request.

    Returns:
        An ``HTMLResponse`` with status 200 whose body is the file's text.
    """
    # Decode explicitly as UTF-8 instead of relying on the platform's
    # locale-dependent default encoding.
    with open("index.html", "r", encoding="utf-8") as file:
        html_content = file.read()
    return HTMLResponse(content=html_content, status_code=200)
# Model setup
# ViT-B/16 backbone from vit_keras, created without its classification top.
vit_model = vit.vit_b16(
    image_size=224,
    activation='softmax',
    pretrained=True,
    include_top=False,
    pretrained_top=False,
    classes=7,
)

# Classification head stacked on the backbone. The layers are added in a
# fixed order so the stack matches the one the weights file is loaded into.
model1 = tf.keras.Sequential()
model1.add(vit_model)
model1.add(tf.keras.layers.Flatten())
model1.add(tf.keras.layers.BatchNormalization())
model1.add(tf.keras.layers.Dense(11, activation=tfa.activations.gelu))
model1.add(tf.keras.layers.BatchNormalization())
model1.add(tf.keras.layers.Dense(7, activation='softmax'))

# Weights are read from a path relative to the current working directory.
model1.load_weights('vit_model_weights.h5')

# Class names indexed by the position of the winning output unit.
# NOTE(review): assumes this order matches the label order used when the
# weights were trained -- confirm.
labels = ['akiec', 'bcc', 'bkl', 'df', 'mel', 'nv', 'vasc']
def preprocess_image(image: Image.Image) -> np.ndarray:
    """Turn a PIL image into a single-item batch for the model.

    Args:
        image: The image to prepare; any PIL mode is accepted.

    Returns:
        A float array of shape ``(1, 224, 224, 3)`` with values scaled by
        ``1 / 255``.
    """
    # Uploads may be grayscale, palette-based or carry an alpha channel;
    # normalise to three-channel RGB so the array always has 3 channels.
    if image.mode != "RGB":
        image = image.convert("RGB")
    image = image.resize((224, 224))
    image = np.array(image) / 255.0
    # Add the leading batch axis.
    image = np.expand_dims(image, axis=0)
    return image
@app.post("/predict/")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image and report the highest-scoring class.

    Args:
        file: Multipart upload containing the image to classify.

    Returns:
        A dict with ``predicted_class`` (an entry of ``labels``) and
        ``confidence`` (the largest value in the model output, as a float).

    Raises:
        HTTPException: With status 400 when the upload cannot be decoded
            as an image.
    """
    payload = await file.read()
    try:
        image = Image.open(BytesIO(payload))
        # Image.open only parses the header; force full decoding here so
        # truncated or corrupt data is reported as a client error rather
        # than failing later as an unhandled server error.
        image.load()
    except OSError as exc:
        # Pillow's UnidentifiedImageError is an OSError subclass.
        raise HTTPException(
            status_code=400,
            detail="Uploaded file is not a valid image.",
        ) from exc

    processed_image = preprocess_image(image)
    predictions = model1.predict(processed_image)
    best_index = int(np.argmax(predictions))
    return {
        "predicted_class": labels[best_index],
        "confidence": float(np.max(predictions)),
    }