Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -8,7 +8,6 @@ from torch.autograd import Variable
|
|
| 8 |
from huggingface_hub import hf_hub_download
|
| 9 |
import streamlit as st
|
| 10 |
|
| 11 |
-
# Constants and environment variables
|
| 12 |
TOKEN = os.getenv('hf_read_token')
|
| 13 |
repo_id = "Nischay103/captcha_recognition"
|
| 14 |
model_files = {
|
|
@@ -42,13 +41,11 @@ char_sets = {
|
|
| 42 |
"11": "0123456789$"
|
| 43 |
}
|
| 44 |
|
| 45 |
-
# Load models
|
| 46 |
models = {}
|
| 47 |
for key, model_file in model_files.items():
|
| 48 |
model_path = hf_hub_download(repo_id=repo_id, filename=model_file, token=TOKEN)
|
| 49 |
models[key] = torch.jit.load(model_path)
|
| 50 |
|
| 51 |
-
# Function to transform image
|
| 52 |
def transform_image(image_path):
|
| 53 |
transform = T.Compose([T.ToTensor()])
|
| 54 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
|
@@ -60,7 +57,6 @@ def transform_image(image_path):
|
|
| 60 |
image = image.unsqueeze(1)
|
| 61 |
return image
|
| 62 |
|
| 63 |
-
# Function to get label from model prediction
|
| 64 |
def get_label(model_prediction, model_version):
|
| 65 |
max_captcha_len, cls_dim_encoded = len_dim_pair[model_version]
|
| 66 |
_cls = char_sets[cls_dim_encoded]
|
|
@@ -73,22 +69,20 @@ def get_label(model_prediction, model_version):
|
|
| 73 |
lab += get_char
|
| 74 |
return lab
|
| 75 |
|
| 76 |
-
|
| 77 |
-
st.
|
| 78 |
-
st.write("Recognize captchas using different models")
|
| 79 |
|
| 80 |
-
uploaded_file = st.file_uploader("
|
| 81 |
-
model_version = st.selectbox("
|
| 82 |
|
| 83 |
if uploaded_file is not None:
|
| 84 |
-
# Save uploaded file to a temporary location
|
| 85 |
with open("temp_captcha_image.png", "wb") as f:
|
| 86 |
f.write(uploaded_file.getbuffer())
|
| 87 |
|
| 88 |
input_image_path = "temp_captcha_image.png"
|
| 89 |
-
st.image(input_image_path, caption='
|
| 90 |
|
| 91 |
-
if st.button('
|
| 92 |
input = transform_image(input_image_path)
|
| 93 |
model = models[model_version]
|
| 94 |
with torch.no_grad():
|
|
@@ -96,9 +90,9 @@ if uploaded_file is not None:
|
|
| 96 |
output = get_label(model_prediction, model_version)
|
| 97 |
st.write(f"Recognized Character Sequence: {output}")
|
| 98 |
|
| 99 |
-
st.write("##
|
| 100 |
-
cols = st.columns(
|
| 101 |
for idx,(model_variant,captcha_path) in enumerate(example_captchas.items()):
|
| 102 |
-
col = cols[idx %
|
| 103 |
-
col.image(captcha_path,caption=f'
|
| 104 |
col.write(f"Model Version: {model_variant}")
|
|
|
|
| 8 |
from huggingface_hub import hf_hub_download
|
| 9 |
import streamlit as st
|
| 10 |
|
|
|
|
| 11 |
TOKEN = os.getenv('hf_read_token')
|
| 12 |
repo_id = "Nischay103/captcha_recognition"
|
| 13 |
model_files = {
|
|
|
|
| 41 |
"11": "0123456789$"
|
| 42 |
}
|
| 43 |
|
|
|
|
| 44 |
models = {}
|
| 45 |
for key, model_file in model_files.items():
|
| 46 |
model_path = hf_hub_download(repo_id=repo_id, filename=model_file, token=TOKEN)
|
| 47 |
models[key] = torch.jit.load(model_path)
|
| 48 |
|
|
|
|
| 49 |
def transform_image(image_path):
|
| 50 |
transform = T.Compose([T.ToTensor()])
|
| 51 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
|
|
|
| 57 |
image = image.unsqueeze(1)
|
| 58 |
return image
|
| 59 |
|
|
|
|
| 60 |
def get_label(model_prediction, model_version):
|
| 61 |
max_captcha_len, cls_dim_encoded = len_dim_pair[model_version]
|
| 62 |
_cls = char_sets[cls_dim_encoded]
|
|
|
|
| 69 |
lab += get_char
|
| 70 |
return lab
|
| 71 |
|
| 72 |
+
st.title("char-seq recognition from scene-image (captcha)")
|
| 73 |
+
st.write("recognize captchas using different models")
|
|
|
|
| 74 |
|
| 75 |
+
uploaded_file = st.file_uploader("choose a captcha image...", type=["jpg", "png"])
|
| 76 |
+
model_version = st.selectbox("model variant", list(model_files.keys()), index=0)
|
| 77 |
|
| 78 |
if uploaded_file is not None:
|
|
|
|
| 79 |
with open("temp_captcha_image.png", "wb") as f:
|
| 80 |
f.write(uploaded_file.getbuffer())
|
| 81 |
|
| 82 |
input_image_path = "temp_captcha_image.png"
|
| 83 |
+
st.image(input_image_path, caption='uploaded captcha image', use_column_width=True)
|
| 84 |
|
| 85 |
+
if st.button('recognize'):
|
| 86 |
input = transform_image(input_image_path)
|
| 87 |
model = models[model_version]
|
| 88 |
with torch.no_grad():
|
|
|
|
| 90 |
output = get_label(model_prediction, model_version)
|
| 91 |
st.write(f"Recognized Character Sequence: {output}")
|
| 92 |
|
| 93 |
+
st.write("## examples")
|
| 94 |
+
cols = st.columns(4)
|
| 95 |
for idx,(model_variant,captcha_path) in enumerate(example_captchas.items()):
|
| 96 |
+
col = cols[idx % 4]
|
| 97 |
+
col.image(captcha_path,caption=f'{captcha_path.split("/")[-1]}', width=100)
|
| 98 |
col.write(f"Model Version: {model_variant}")
|