# File size: 1,774 Bytes
# 29fe14e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import tempfile
from pathlib import Path

import gradio as gr
from PIL import Image

# Import your model classes (adjust import paths as needed)
from app.src.vit_load import VITDocumentClassifier
from app.src.vgg16_load import VGGDocumentClassifier
from app.src.constant import vit_model_path, vit_mlb_path, vgg_model_path, vgg_mlb_path

# Build both classifiers a single time, when this module is imported, so every
# request handled by the UI reuses the same instances instead of constructing
# new ones per call.
vit_model, vgg_model = (
    VITDocumentClassifier(vit_model_path, vit_mlb_path),
    VGGDocumentClassifier(vgg_model_path, vgg_mlb_path),
)

def predict_vit(image, cut_off):
    if image is None:
        return "Please upload an image."
    temp_path = "temp_vit_image.png"
    image.save(temp_path)
    result = vit_model.predict(Path(temp_path), cut_off)
    return f"ViT Prediction: {result}"

def predict_vgg(image):
    if image is None:
        return "Please upload an image."
    temp_path = "temp_vgg_image.png"
    image.save(temp_path)
    result = vgg_model.predict(Path(temp_path))
    return f"VGG16 Prediction: {result}"

# UI layout: a left column holding the upload widget and the ViT threshold,
# a right column holding the shared result box, and one button per model below.
with gr.Blocks() as demo:
    gr.Markdown(
        "# Document Classification Demo\n"
        "Upload an image and choose a model to classify it."
    )
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Image")
            # Only predict_vit receives this value; the VGG16 handler ignores it.
            cut_off = gr.Slider(
                minimum=0, maximum=1, value=0.5, label="ViT Cutoff Threshold"
            )
        with gr.Column():
            result_output = gr.Textbox(
                label="Prediction Result", interactive=False
            )
    with gr.Row():
        vit_btn = gr.Button(value="Predict with ViT Model")
        vgg_btn = gr.Button(value="Predict with VGG16 Model")

    # Both handlers write into the same read-only textbox.
    vit_btn.click(
        fn=predict_vit, inputs=[image_input, cut_off], outputs=result_output
    )
    vgg_btn.click(fn=predict_vgg, inputs=[image_input], outputs=result_output)


if __name__ == "__main__":
    # share=True requests a public Gradio share link; ssr_mode=False turns off
    # server-side rendering.
    demo.launch(share=True, ssr_mode=False)