| import io | |
| import time | |
| import os | |
| import re | |
| import numpy as np | |
| from PIL import Image, ImageFilter | |
| from cairosvg import svg2png | |
| from transformers import VisionEncoderDecoderModel, TrOCRProcessor | |
| import gradio as gr | |
# Hugging Face checkpoint shared by the image processor and the OCR model.
_MODEL_ID = "anuashok/ocr-captcha-v3"

# Both are loaded once, at import time, and used by the OCR helper below.
processor = TrOCRProcessor.from_pretrained(_MODEL_ID)
model = VisionEncoderDecoderModel.from_pretrained(_MODEL_ID)

# NOTE(review): "outputs" is created at import time but nothing visible here
# writes to it — confirm it is still needed.
os.makedirs("outputs", exist_ok=True)
def _single_ocr_from_image(image: Image.Image) -> str:
    """Run the TrOCR model on one image and return a short captcha guess.

    Args:
        image: The rendered captcha image (passed straight to ``processor``).

    Returns:
        At most the first 4 characters of the decoded text, restricted to
        upper-case ASCII letters and digits. Lower-case letters in the model
        output are upper-cased rather than discarded.
    """
    pixel_values = processor(image, return_tensors="pt").pixel_values
    generated_ids = model.generate(pixel_values)
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    # Keep ASCII letters of either case, THEN upper-case. Filtering with
    # [^A-Z0-9] first (as before) deleted every lower-case letter and made the
    # trailing .upper() a no-op. Restricting to ASCII before upper-casing also
    # avoids Unicode case expansions such as 'ß' -> 'SS'.
    sanitized = re.sub(r'[^A-Za-z0-9]', '', generated_text).upper()
    return sanitized[:4]
# Matches rotate(angle, cx, cy) transforms; groups are (angle, cx, cy) as strings.
_ROTATE_RE = re.compile(r'rotate\(\s*([+-]?\d+)\s*,\s*([0-9.]+)\s*,\s*([0-9.]+)\s*\)')
# Every raster used for the edge metric and for OCR is resized to (width, height).
_RENDER_SIZE = (600, 400)


def _sanitize_svg(svg_data):
    """Return *svg_data* with styles, local-path references and animations removed.

    ``None`` or ``""`` yields ``""``. Strips ``<style>`` blocks, the substrings
    ``file:///`` and ``/app/``, blanks ``url(app/...)`` references, and removes
    ``<animateTransform>`` elements so only static transforms remain.
    """
    svg = svg_data or ""
    svg = re.sub(r'<style>.*?</style>', '', svg, flags=re.DOTALL)
    svg = svg.replace('file:///', '')
    svg = svg.replace('/app/', '')
    svg = re.sub(r'url\(["\']?\/?app\/[^)"\']*["\']?\)', 'url()', svg)
    return re.sub(r'<animateTransform\b[^>]*>(?:.*?</animateTransform>)?', '', svg, flags=re.DOTALL)


def _rotation_centers(svg_text):
    """Return the distinct (cx, cy) string pairs of rotate() transforms, in document order."""
    return list(dict.fromkeys((cx, cy) for _, cx, cy in _ROTATE_RE.findall(svg_text)))


def _apply_rotation(svg_text, cx, cy, angle):
    """Rewrite ``rotate(1, cx, cy)`` to ``rotate(angle, ...)`` and ``rotate(-1, cx, cy)`` to ``rotate(-angle, ...)``."""
    cx_pat, cy_pat = re.escape(cx), re.escape(cy)
    svg_text = re.sub(rf'rotate\(\s*1\s*,\s*{cx_pat}\s*,\s*{cy_pat}\s*\)', f'rotate({angle}, {cx}, {cy})', svg_text)
    return re.sub(rf'rotate\(\s*-1\s*,\s*{cx_pat}\s*,\s*{cy_pat}\s*\)', f'rotate(-{angle}, {cx}, {cy})', svg_text)


def _render_on_white(svg_text):
    """Rasterise *svg_text*, resize it and flatten it onto a white background (RGB)."""
    png_bytes = svg2png(bytestring=svg_text.encode('utf-8'))
    image = Image.open(io.BytesIO(png_bytes)).convert("RGBA").resize(_RENDER_SIZE)
    background = Image.new("RGBA", image.size, (255, 255, 255))
    return Image.alpha_composite(background, image).convert("RGB")


def _edge_count(svg_text):
    """Count pixels above 10 in FIND_EDGES of the blurred grayscale render of *svg_text*."""
    png_bytes = svg2png(bytestring=svg_text.encode('utf-8'))
    img = Image.open(io.BytesIO(png_bytes)).convert('L').resize(_RENDER_SIZE)
    img = img.filter(ImageFilter.GaussianBlur(radius=1))
    edges = np.array(img.filter(ImageFilter.FIND_EDGES))
    return int((edges > 10).sum())


def _best_angles(svg_text, cx, cy, angle_step=15, top_k=2):
    """Return the *top_k* sweep angles for (cx, cy) whose render has the fewest edge pixels.

    Heuristic: presumably the correctly-oriented glyphs produce the fewest
    edges — TODO confirm. Angles whose render fails are skipped. With no
    successful render, ``[0] * top_k`` is returned; a short list is padded by
    repeating its best angle.
    """
    metrics = []
    for angle in range(0, 360, angle_step):
        try:
            metrics.append((_edge_count(_apply_rotation(svg_text, cx, cy, angle)), angle))
        except Exception as exc:
            print(f"Rotation sweep error at {angle}:", exc)
    # Stable sort on edge count only, so ties keep ascending-angle order.
    metrics.sort(key=lambda m: m[0])
    picked = [angle for _, angle in metrics[:top_k]] if metrics else [0] * top_k
    if len(picked) < top_k:
        picked += [picked[0]] * (top_k - len(picked))
    return picked


def solve_svg_captcha(svg_data: str) -> str:
    """Solve an SVG text captcha whose glyph groups may be rotated.

    Steps:
      1. Sanitise the SVG and drop animations.
      2. Collect the distinct rotate() centers. With none, OCR the static
         render directly.
      3. Otherwise, for up to two centers, pick the two lowest-edge-count
         angles each.
      4. Render each DISTINCT angle combination and OCR it, returning the
         first 4-character read.

    Args:
        svg_data: Raw SVG markup; ``None``/``""`` are tolerated.

    Returns:
        The first 4-character OCR result; otherwise the longest (first
        longest) result; ``""`` if nothing could be rendered or read.
    """
    svg_static = _sanitize_svg(svg_data)
    centers = _rotation_centers(svg_static)

    if not centers:
        try:
            return _single_ocr_from_image(_render_on_white(svg_static))
        except Exception as e:
            print("OCR fallback error:", e)
            return ""

    centers = centers[:2]
    best = {center: _best_angles(svg_static, *center) for center in centers}

    if len(centers) == 1:
        (c0,) = centers
        a1, a2 = best[c0][:2]
        combos = [((c0, a1),), ((c0, a2),)]
    else:
        c0, c1 = centers
        a1, a2 = best[c0][:2]
        b1, b2 = best[c1][:2]
        combos = [
            ((c0, a1), (c1, b1)),
            ((c0, a2), (c1, b1)),
            ((c0, a1), (c1, b2)),
            ((c0, a2), (c1, b2)),
        ]
    # Identical combos (e.g. a1 == a2 after padding) render identical images,
    # so try each only once; order is preserved.
    combos = list(dict.fromkeys(combos))

    best_text = ""
    for combo in combos:
        candidate = svg_static
        for (cx, cy), angle in combo:
            candidate = _apply_rotation(candidate, cx, cy, angle)
        try:
            image = _render_on_white(candidate)
        except Exception as exc:
            print("Combo render error:", exc)
            continue
        try:
            text = _single_ocr_from_image(image)
        except Exception as exc:
            print("OCR error:", exc)
            text = ""
        # Short-circuit: later combos cannot change the answer once we have
        # a full-length read, so skip their (expensive) OCR.
        if len(text) == 4:
            return text
        if len(text) > len(best_text):
            best_text = text
    return best_text
def predict(svgdata):
    """Gradio click handler: validate the pasted SVG and return the solver's answer.

    Returns a fixed message for empty input, for input longer than 50000
    characters, and when the solver raises or returns an empty answer.
    """
    size_limit = 50000
    if not svgdata:
        return "No SVG provided"
    if len(svgdata) > size_limit:
        return "SVG too large"

    try:
        answer = solve_svg_captcha(svgdata)
    except Exception as exc:
        print(f"Error in predict: {exc}")
        answer = ""
    return answer if answer else "Model could not predict"
# Minimal Gradio UI: paste SVG text, press the button, read the predicted answer.
with gr.Blocks() as demo:
    gr.Markdown(value="Enter SVG data and receive model answer")
    svg_input = gr.Textbox(label="SVG Data", lines=10)
    predict_btn = gr.Button(value="Get Model Answer")
    # Read-only output box filled by `predict`.
    model_answer = gr.Textbox(label="Model Answer", interactive=False)
    predict_btn.click(fn=predict, inputs=[svg_input], outputs=[model_answer])

if __name__ == "__main__":
    demo.launch()