Tharuneshwar committed on
Commit
36e56cd
·
1 Parent(s): 8f260ec
Files changed (2) hide show
  1. gradio.ipynb +192 -0
  2. requirements.txt +1 -1
gradio.ipynb ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "# res = segment_marker(Image.open('notebook/lion.jpg'), '[{\"flag_\":1, \"x_\": 3760.689914766355, \"y_\": 2243.232589377525}]')"
10
+ ]
11
+ },
12
+ {
13
+ "cell_type": "code",
14
+ "execution_count": null,
15
+ "metadata": {},
16
+ "outputs": [
17
+ {
18
+ "name": "stderr",
19
+ "output_type": "stream",
20
+ "text": [
21
+ "e:\\anaconda\\envs\\text-behind-image-env\\Lib\\site-packages\\ultralytics\\nn\\tasks.py:377: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
22
+ " return torch.load(file, map_location='cpu'), file # load\n"
23
+ ]
24
+ },
25
+ {
26
+ "name": "stdout",
27
+ "output_type": "stream",
28
+ "text": [
29
+ "* Running on local URL: http://127.0.0.1:7860\n",
30
+ "* Running on public URL: https://191f0e068e6310368c.gradio.live\n",
31
+ "\n",
32
+ "This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)\n"
33
+ ]
34
+ },
35
+ {
36
+ "data": {
37
+ "text/html": [
38
+ "<div><iframe src=\"https://191f0e068e6310368c.gradio.live\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
39
+ ],
40
+ "text/plain": [
41
+ "<IPython.core.display.HTML object>"
42
+ ]
43
+ },
44
+ "metadata": {},
45
+ "output_type": "display_data"
46
+ },
47
+ {
48
+ "name": "stderr",
49
+ "output_type": "stream",
50
+ "text": [
51
+ "Traceback (most recent call last):\n",
52
+ " File \"e:\\anaconda\\envs\\text-behind-image-env\\Lib\\site-packages\\gradio\\queueing.py\", line 624, in process_events\n",
53
+ " response = await route_utils.call_process_api(\n",
54
+ " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
55
+ " File \"e:\\anaconda\\envs\\text-behind-image-env\\Lib\\site-packages\\gradio\\route_utils.py\", line 323, in call_process_api\n",
56
+ " output = await app.get_blocks().process_api(\n",
57
+ " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
58
+ " File \"e:\\anaconda\\envs\\text-behind-image-env\\Lib\\site-packages\\gradio\\blocks.py\", line 2028, in process_api\n",
59
+ " data = await self.postprocess_data(block_fn, result[\"prediction\"], state)\n",
60
+ " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
61
+ " File \"e:\\anaconda\\envs\\text-behind-image-env\\Lib\\site-packages\\gradio\\blocks.py\", line 1784, in postprocess_data\n",
62
+ " self.validate_outputs(block_fn, predictions) # type: ignore\n",
63
+ " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
64
+ " File \"e:\\anaconda\\envs\\text-behind-image-env\\Lib\\site-packages\\gradio\\blocks.py\", line 1739, in validate_outputs\n",
65
+ " raise ValueError(\n",
66
+ "ValueError: A function (segment_marker) didn't return enough output values (needed: 2, returned: 1).\n",
67
+ " Output components:\n",
68
+ " [image, image]\n",
69
+ " Output values returned:\n",
70
+ " [\"Invalid marker coordinates format. Ensure it's valid JSON.\"]\n"
71
+ ]
72
+ },
73
+ {
74
+ "name": "stdout",
75
+ "output_type": "stream",
76
+ "text": [
77
+ "Processing image with 1 marker points...\n"
78
+ ]
79
+ },
80
+ {
81
+ "name": "stderr",
82
+ "output_type": "stream",
83
+ "text": [
84
+ "\n",
85
+ "0: 736x1024 17 objects, 3409.9ms\n",
86
+ "Speed: 127.7ms preprocess, 3409.9ms inference, 1800.1ms postprocess per image at shape (1, 3, 1024, 1024)\n"
87
+ ]
88
+ }
89
+ ],
90
+ "source": [
91
+ "import base64\n",
92
+ "from io import BytesIO\n",
93
+ "import gradio as gr\n",
94
+ "from PIL import Image\n",
95
+ "import json\n",
96
+ "\n",
97
+ "from tools.tools import convertToBuffer\n",
98
+ "from visualize.visualize import removeBgFromSegmentImage, removeOnlyBg\n",
99
+ "from models.model import getMask, loadModel\n",
100
+ "from models.preprocess import preprocess\n",
101
+ "\n",
102
+ "FAST_SAM = loadModel()\n",
103
+ "\n",
104
+ "# Helper function to convert base64 to PIL image\n",
105
+ "def base64_to_image(base64_str):\n",
106
+ " image_data = base64.b64decode(base64_str)\n",
107
+ " image = Image.open(BytesIO(image_data))\n",
108
+ " return image\n",
109
+ "\n",
110
+ "# Main processing function\n",
111
+ "def segment_marker(img_rgb: Image.Image, marker_coordinates: str):\n",
112
+ " # Parse marker coordinates from JSON string\n",
113
+ " try:\n",
114
+ " marker_coordinates = json.loads(marker_coordinates)\n",
115
+ " except json.JSONDecodeError:\n",
116
+ " return \"Invalid marker coordinates format. Ensure it's valid JSON.\", None\n",
117
+ "\n",
118
+ " try:\n",
119
+ " # Process marker points and labels\n",
120
+ " input_points, input_labels = preprocess(marker_coordinates)\n",
121
+ "\n",
122
+ " print(f\"Processing image with {len(input_points)} marker points...\")\n",
123
+ " # Get mask for segmentation\n",
124
+ " masks = getMask(img_rgb, FAST_SAM, input_points, input_labels)\n",
125
+ "\n",
126
+ " # Generate the segmented images\n",
127
+ " bg_removed_segmented_img = removeBgFromSegmentImage(img_rgb, masks[0])\n",
128
+ " img_base64_bg_segmented = convertToBuffer(bg_removed_segmented_img)\n",
129
+ "\n",
130
+ " bg_only_removed_img = removeOnlyBg(img_rgb, masks[0])\n",
131
+ " img_base64_only_bg = convertToBuffer(bg_only_removed_img)\n",
132
+ "\n",
133
+ " # Convert base64 strings to PIL images for Gradio\n",
134
+ " img_bg_segmented = base64_to_image(img_base64_bg_segmented)\n",
135
+ " img_bg_only_removed = base64_to_image(img_base64_only_bg)\n",
136
+ "\n",
137
+ " return img_bg_segmented, img_bg_only_removed # Return as two separate images\n",
138
+ "\n",
139
+ " except Exception as e:\n",
140
+ " print(f\"An error occurred: {str(e)}\")\n",
141
+ " return \"An error occurred while processing the image.\", None\n",
142
+ "\n",
143
+ "# Set up the Gradio interface\n",
144
+ "iface = gr.Interface(\n",
145
+ " fn=segment_marker,\n",
146
+ " inputs=[\n",
147
+ " gr.Image(type=\"pil\", label=\"Upload Image\"),\n",
148
+ " gr.Textbox(label=\"Markers Coordinates (JSON format)\")\n",
149
+ " ],\n",
150
+ " outputs=[\n",
151
+ " gr.Image(type=\"pil\", label=\"Background Removed with Segmentation\"),\n",
152
+ " gr.Image(type=\"pil\", label=\"Only Background Removed\")\n",
153
+ " ],\n",
154
+ " title=\"Image Segmentation with Background Removal\",\n",
155
+ " description=\"Upload an image and JSON-formatted marker coordinates to perform image segmentation and background removal.\"\n",
156
+ ")\n",
157
+ "\n",
158
+ "# Run the Gradio app\n",
159
+ "if __name__ == \"__main__\":\n",
160
+ " iface.launch(share=True)\n"
161
+ ]
162
+ },
163
+ {
164
+ "cell_type": "code",
165
+ "execution_count": null,
166
+ "metadata": {},
167
+ "outputs": [],
168
+ "source": []
169
+ }
170
+ ],
171
+ "metadata": {
172
+ "kernelspec": {
173
+ "display_name": "text-behind-image-env",
174
+ "language": "python",
175
+ "name": "python3"
176
+ },
177
+ "language_info": {
178
+ "codemirror_mode": {
179
+ "name": "ipython",
180
+ "version": 3
181
+ },
182
+ "file_extension": ".py",
183
+ "mimetype": "text/x-python",
184
+ "name": "python",
185
+ "nbconvert_exporter": "python",
186
+ "pygments_lexer": "ipython3",
187
+ "version": "3.11.10"
188
+ }
189
+ },
190
+ "nbformat": 4,
191
+ "nbformat_minor": 2
192
+ }
requirements.txt CHANGED
@@ -11,5 +11,5 @@ git+https://github.com/openai/CLIP.git
11
  # uvicorn
12
  # serverless_wsgi
13
  # gunicorn
14
- git+https://github.com/CASIA-IVA-Lab/FastSAM.git
15
  ultralytics==8.0.100
 
11
  # uvicorn
12
  # serverless_wsgi
13
  # gunicorn
14
+ git+https://github.com/CASIA-IVA-Lab/FastSAM.git@v0.0.2
15
  ultralytics==8.0.100