JaydeepR commited on
Commit
343e2be
·
verified ·
1 Parent(s): 0bbde36

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +168 -0
app.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from PIL import Image
3
+ import pandas as pd
4
+ import matplotlib.pyplot as plt
5
+ import numpy as np
6
+ from io import BytesIO
7
+ import uuid
8
+ import gc
9
+
10
+ import sys
11
+ import os
12
+
13
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
14
+ from segmentation_model import load_model,transform_image, run_inference, save_input_image, save_objects_and_metadata, extract_object
15
+ # from models.identification_model import identify_and_extract_objects
16
+ # from models.text_extraction_model import extract_text
17
+ # from models.summarization_model import summarize_text
18
+ # from utils.data_mapping import create_summary_table
19
+
20
+
21
+
22
+ model = load_model()
23
+
24
+ def resize_image(image, size=(800, 800)):
25
+ return image.resize(size, Image.ANTIALIAS)
26
+
27
+ def display_masks(outputs, image, threshold=0.5):
28
+ masks = outputs[0]['masks']
29
+ scores = outputs[0]['scores']
30
+
31
+ fig, ax = plt.subplots()
32
+ ax.imshow(np.array(image))
33
+
34
+ # extracted_objects = []
35
+
36
+ for i in range(len(scores)):
37
+ if scores[i] > threshold:
38
+ mask = masks[i].squeeze().cpu().numpy()
39
+ mask = np.where(mask > 0.5, 1, 0).astype(np.uint8)
40
+
41
+ # object_img = extract_object(image,mask)
42
+ # extracted_objects.append(object_img)
43
+ # Display the mask
44
+ ax.imshow(mask, cmap='jet', alpha=0.5) # Overlay mask on image
45
+
46
+ st.pyplot(fig)
47
+
48
+ # return extracted_objects
49
+
50
+
51
+
52
+ st.title("Image Segmentation with Mask R-CNN")
53
+
54
+ uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
55
+
56
+ if uploaded_file is not None:
57
+ # Convert uploaded file to PIL Image
58
+ image = uploaded_file
59
+ st.image(image, caption='Uploaded Image.', use_column_width=True)
60
+ image = Image.open(uploaded_file).convert('RGB')
61
+ # Generate a unique master ID for the image
62
+ master_id = str(uuid.uuid4())
63
+
64
+ # Save the input image
65
+ save_input_image(image, master_id)
66
+ # Transform image
67
+ image_tensor = transform_image(image)
68
+ outputs = run_inference(model, image_tensor)
69
+
70
+ # extracted_objects = display_masks(outputs, image)
71
+
72
+ # if extracted_objects:
73
+ # # Save the extracted objects and their metadata
74
+ # metadata = save_objects_and_metadata(extracted_objects, master_id)
75
+
76
+ # # Display metadata as a JSON output
77
+ # st.write("Metadata for extracted objects:")
78
+ # st.json(metadata)
79
+
80
+ # # Display each extracted object
81
+ # st.write("Extracted Objects:")
82
+ # for i, obj_img in enumerate(extracted_objects):
83
+ # st.image(obj_img, caption=f'Object {i+1}', use_column_width=True)
84
+ # else:
85
+ # st.write("No objects were detected")
86
+
87
+
88
+ # del extracted_objects
89
+ # gc.collect()
90
+
91
+ # Display results
92
+ display_masks(outputs, image)
93
+
94
+
95
+
96
+
97
+
98
+ # if uploaded_file is not None:
99
+ # image = Image.open(uploaded_file).convert("RGB")
100
+ # st.image(image, caption='Uploaded Image.', use_column_width=True)
101
+
102
+ # image_tensor = transform_image(image)
103
+ # outputs = run_inference(model, image_tensor)
104
+
105
+ # display_masks(outputs, image)
106
+
107
+
108
+
109
+ # def upload_image():
110
+ # uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
111
+ # if uploaded_file is not None:
112
+ # image = Image.open(uploaded_file)
113
+ # return image
114
+ # return None
115
+
116
+
117
+ # # def display_segmentation(image):
118
+ # # st.image(image, caption="Original Image", use_column_width=True)
119
+
120
+ # # Transform and run inference
121
+ # # image_tensor = transform_image(image)
122
+ # # outputs = run_inference(image_tensor)
123
+
124
+ # # # Save segmented objects
125
+ # # output_dir = 'segmented_objects/'
126
+ # # save_segmented_objects(image, outputs, output_dir)
127
+
128
+ # # segmented_images = [Image.open(f"{output_dir}object_{i+1}.png") for i in range(len(outputs[0]['scores']))]
129
+ # # for img in segmented_images:
130
+ # # st.image(img, caption="Segmented Object", use_column_width=True)
131
+
132
+
133
+
134
+
135
+ # def main():
136
+ # st.title("Image Processing Pipeline")
137
+
138
+ # # uploaded_file = st.file_uploader("Upload an image", type=["jpg", "png"])
139
+ # # if uploaded_file:
140
+ # # image_path = f"data/input_images/{uploaded_file.name}"
141
+ # # image = Image.open(uploaded_file)
142
+ # # image.save(image_path) # Save the uploaded image for further processing
143
+ # # st.image(image, caption="Uploaded Image")
144
+
145
+ # # if st.button("Segment Image"):
146
+ # # segmented = segment_image(image_path)
147
+ # # st.image(segmented, caption="Segmented Image", use_column_width=True)
148
+
149
+ # # if st.button("Identify and Extract Objects"):
150
+ # # objects_data = identify_and_extract_objects(image_path)
151
+ # # extracted_objects = []
152
+
153
+ # # for obj_data in objects_data:
154
+ # # object_image = Image.open(obj_data['Image Path'])
155
+ # # text = extract_text(object_image)
156
+ # # summary = summarize_text(text)
157
+ # # obj_data['Text'] = text
158
+ # # obj_data['Summary'] = summary
159
+ # # extracted_objects.append(obj_data)
160
+
161
+ # # st.image(object_image, caption=f"Object {obj_data['ID']} - Label {obj_data['Label']}")
162
+
163
+ # # summary_file = create_summary_table(extracted_objects)
164
+ # # st.write(pd.DataFrame(extracted_objects))
165
+ # # st.download_button(label="Download Summary Table", data=open(summary_file).read(), file_name="summary.csv")
166
+
167
+ # if __name__ == "__main__":
168
+ # main()