KaushiGihan commited on
Commit
29fe14e
·
verified ·
1 Parent(s): a517aad

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -46
app.py CHANGED
@@ -1,46 +1,46 @@
1
- import gradio as gr
2
- from pathlib import Path
3
- from PIL import Image
4
-
5
- # Import your model classes (adjust import paths as needed)
6
- from app.src.vit_load import VITDocumentClassifier
7
- from app.src.vgg16_load import VGGDocumentClassifier
8
- from app.src.constant import vit_model_path, vit_mlb_path, vgg_model_path, vgg_mlb_path
9
-
10
- # Load models once at startup
11
- vit_model = VITDocumentClassifier(vit_model_path, vit_mlb_path)
12
- vgg_model = VGGDocumentClassifier(vgg_model_path, vgg_mlb_path)
13
-
14
- def predict_vit(image, cut_off):
15
- if image is None:
16
- return "Please upload an image."
17
- temp_path = "temp_vit_image.png"
18
- image.save(temp_path)
19
- result = vit_model.predict(Path(temp_path), cut_off)
20
- return f"ViT Prediction: {result}"
21
-
22
- def predict_vgg(image):
23
- if image is None:
24
- return "Please upload an image."
25
- temp_path = "temp_vgg_image.png"
26
- image.save(temp_path)
27
- result = vgg_model.predict(Path(temp_path))
28
- return f"VGG16 Prediction: {result}"
29
-
30
- with gr.Blocks() as demo:
31
- gr.Markdown("# Document Classification Demo\nUpload an image and choose a model to classify it.")
32
- with gr.Row():
33
- with gr.Column():
34
- image_input = gr.Image(type="pil", label="Upload Image")
35
- cut_off = gr.Slider(0, 1, value=0.5, label="ViT Cutoff Threshold")
36
- with gr.Column():
37
- result_output = gr.Textbox(label="Prediction Result", interactive=False)
38
- with gr.Row():
39
- vit_btn = gr.Button("Predict with ViT Model")
40
- vgg_btn = gr.Button("Predict with VGG16 Model")
41
-
42
- vit_btn.click(fn=predict_vit, inputs=[image_input, cut_off], outputs=result_output)
43
- vgg_btn.click(fn=predict_vgg, inputs=image_input, outputs=result_output)
44
-
45
- if __name__ == "__main__":
46
- demo.launch()
 
1
+ import gradio as gr
2
+ from pathlib import Path
3
+ from PIL import Image
4
+
5
+ # Import your model classes (adjust import paths as needed)
6
+ from app.src.vit_load import VITDocumentClassifier
7
+ from app.src.vgg16_load import VGGDocumentClassifier
8
+ from app.src.constant import vit_model_path, vit_mlb_path, vgg_model_path, vgg_mlb_path
9
+
10
+ # Load models once at startup
11
+ vit_model = VITDocumentClassifier(vit_model_path, vit_mlb_path)
12
+ vgg_model = VGGDocumentClassifier(vgg_model_path, vgg_mlb_path)
13
+
14
+ def predict_vit(image, cut_off):
15
+ if image is None:
16
+ return "Please upload an image."
17
+ temp_path = "temp_vit_image.png"
18
+ image.save(temp_path)
19
+ result = vit_model.predict(Path(temp_path), cut_off)
20
+ return f"ViT Prediction: {result}"
21
+
22
+ def predict_vgg(image):
23
+ if image is None:
24
+ return "Please upload an image."
25
+ temp_path = "temp_vgg_image.png"
26
+ image.save(temp_path)
27
+ result = vgg_model.predict(Path(temp_path))
28
+ return f"VGG16 Prediction: {result}"
29
+
30
+ with gr.Blocks() as demo:
31
+ gr.Markdown("# Document Classification Demo\nUpload an image and choose a model to classify it.")
32
+ with gr.Row():
33
+ with gr.Column():
34
+ image_input = gr.Image(type="pil", label="Upload Image")
35
+ cut_off = gr.Slider(0, 1, value=0.5, label="ViT Cutoff Threshold")
36
+ with gr.Column():
37
+ result_output = gr.Textbox(label="Prediction Result", interactive=False)
38
+ with gr.Row():
39
+ vit_btn = gr.Button("Predict with ViT Model")
40
+ vgg_btn = gr.Button("Predict with VGG16 Model")
41
+
42
+ vit_btn.click(fn=predict_vit, inputs=[image_input, cut_off], outputs=result_output)
43
+ vgg_btn.click(fn=predict_vgg, inputs=image_input, outputs=result_output)
44
+
45
+ if __name__ == "__main__":
46
+ demo.launch(ssr_mode=False, share=True)