Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| from pathlib import Path | |
| from PIL import Image | |
| # Import your model classes (adjust import paths as needed) | |
| from app.src.vit_load import VITDocumentClassifier | |
| from app.src.vgg16_load import VGGDocumentClassifier | |
| from app.src.constant import vit_model_path, vit_mlb_path, vgg_model_path, vgg_mlb_path | |
# Build both classifiers a single time, at import, so every request handled by
# the UI callbacks reuses the same instances instead of re-creating them.
vit_model = VITDocumentClassifier(
    vit_model_path,
    vit_mlb_path,
)
vgg_model = VGGDocumentClassifier(
    vgg_model_path,
    vgg_mlb_path,
)
def predict_vit(image, cut_off) -> str:
    """Classify an uploaded image with the ViT model and format the result.

    Args:
        image: The uploaded picture as delivered by the Gradio image input
            (a PIL image, since the input is declared with ``type="pil"``),
            or ``None`` when nothing has been uploaded.
        cut_off: Value of the "ViT Cutoff Threshold" slider; forwarded
            unchanged as the second argument of ``vit_model.predict``.

    Returns:
        A prompt asking for an upload when ``image`` is ``None``, otherwise
        the string ``"ViT Prediction: <result>"``.
    """
    # Function-scope import keeps this block self-contained.
    import tempfile

    if image is None:
        return "Please upload an image."

    # The model API takes a file path, so the image has to be written to disk.
    # A private directory per call prevents concurrent requests from
    # overwriting each other's upload (the old fixed file name was shared by
    # all requests) and guarantees the file is removed afterwards.
    with tempfile.TemporaryDirectory() as tmp_dir:
        temp_path = Path(tmp_dir) / "temp_vit_image.png"
        image.save(temp_path)
        result = vit_model.predict(temp_path, cut_off)
        # Render the message while the file still exists, in case the result
        # refers back to it.
        message = f"ViT Prediction: {result}"
    return message
def predict_vgg(image) -> str:
    """Classify an uploaded image with the VGG16 model and format the result.

    Args:
        image: The uploaded picture as delivered by the Gradio image input
            (a PIL image, since the input is declared with ``type="pil"``),
            or ``None`` when nothing has been uploaded.

    Returns:
        A prompt asking for an upload when ``image`` is ``None``, otherwise
        the string ``"VGG16 Prediction: <result>"``.
    """
    # Function-scope import keeps this block self-contained.
    import tempfile

    if image is None:
        return "Please upload an image."

    # The model API takes a file path, so the image has to be written to disk.
    # A private directory per call prevents concurrent requests from
    # overwriting each other's upload (the old fixed file name was shared by
    # all requests) and guarantees the file is removed afterwards.
    with tempfile.TemporaryDirectory() as tmp_dir:
        temp_path = Path(tmp_dir) / "temp_vgg_image.png"
        image.save(temp_path)
        result = vgg_model.predict(temp_path)
        # Render the message while the file still exists, in case the result
        # refers back to it.
        message = f"VGG16 Prediction: {result}"
    return message
# UI definition: an upload/threshold column next to a read-only result box,
# followed by one button per model. Both buttons write to the same textbox.
with gr.Blocks() as demo:
    gr.Markdown(
        "# Document Classification Demo\nUpload an image and choose a model to classify it."
    )

    with gr.Row():
        with gr.Column():
            # type="pil" makes Gradio pass the callbacks a PIL image.
            image_input = gr.Image(type="pil", label="Upload Image")
            # Only the ViT callback receives this slider's value.
            cut_off = gr.Slider(
                minimum=0,
                maximum=1,
                value=0.5,
                label="ViT Cutoff Threshold",
            )
        with gr.Column():
            result_output = gr.Textbox(
                label="Prediction Result",
                interactive=False,
            )

    with gr.Row():
        vit_btn = gr.Button("Predict with ViT Model")
        vgg_btn = gr.Button("Predict with VGG16 Model")

    # Event listeners have to be registered while the Blocks context is open.
    vit_btn.click(
        fn=predict_vit,
        inputs=[image_input, cut_off],
        outputs=result_output,
    )
    vgg_btn.click(
        fn=predict_vgg,
        inputs=[image_input],
        outputs=result_output,
    )
# Start the web server only when this file is run as a script, not on import.
if __name__ == "__main__":
    demo.launch(
        share=True,
        ssr_mode=False,
    )