Spaces:
Runtime error
Runtime error
File size: 1,774 Bytes
29fe14e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 |
import gradio as gr
from pathlib import Path
from PIL import Image
# Import your model classes (adjust import paths as needed)
from app.src.vit_load import VITDocumentClassifier
from app.src.vgg16_load import VGGDocumentClassifier
from app.src.constant import vit_model_path, vit_mlb_path, vgg_model_path, vgg_mlb_path
# Load models once at startup
vit_model = VITDocumentClassifier(vit_model_path, vit_mlb_path)
vgg_model = VGGDocumentClassifier(vgg_model_path, vgg_mlb_path)
def predict_vit(image, cut_off):
if image is None:
return "Please upload an image."
temp_path = "temp_vit_image.png"
image.save(temp_path)
result = vit_model.predict(Path(temp_path), cut_off)
return f"ViT Prediction: {result}"
def predict_vgg(image):
if image is None:
return "Please upload an image."
temp_path = "temp_vgg_image.png"
image.save(temp_path)
result = vgg_model.predict(Path(temp_path))
return f"VGG16 Prediction: {result}"
with gr.Blocks() as demo:
gr.Markdown("# Document Classification Demo\nUpload an image and choose a model to classify it.")
with gr.Row():
with gr.Column():
image_input = gr.Image(type="pil", label="Upload Image")
cut_off = gr.Slider(0, 1, value=0.5, label="ViT Cutoff Threshold")
with gr.Column():
result_output = gr.Textbox(label="Prediction Result", interactive=False)
with gr.Row():
vit_btn = gr.Button("Predict with ViT Model")
vgg_btn = gr.Button("Predict with VGG16 Model")
vit_btn.click(fn=predict_vit, inputs=[image_input, cut_off], outputs=result_output)
vgg_btn.click(fn=predict_vgg, inputs=image_input, outputs=result_output)
if __name__ == "__main__":
demo.launch(ssr_mode=False, share=True) |