Spaces:
Running
Running
Commit ·
1df75cb
1
Parent(s): 8bdfa24
read_image tool for text extraction from images
Browse files- all_code.txt +0 -0
- app/agents/adk_mathminds.py +55 -11
- app/core/ocr.py +77 -85
- tests/test_ocr_simple.py +37 -0
- tests/test_ocr_tool.py +49 -0
all_code.txt
DELETED
|
The diff for this file is too large to render.
See raw diff
|
|
|
app/agents/adk_mathminds.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
| 1 |
|
|
|
|
| 2 |
import logging
|
| 3 |
import asyncio
|
| 4 |
import base64
|
|
@@ -13,6 +14,8 @@ from app.core.settings import settings
|
|
| 13 |
from app.tools.web_scraper import WebScraper
|
| 14 |
from app.tools.symbolic_solver import SymbolicSolver
|
| 15 |
from app.tools.similarity_search import SimilarProblemFinder
|
|
|
|
|
|
|
| 16 |
from app.core.math_normalizer import MathQueryNormalizer
|
| 17 |
|
| 18 |
logger = logging.getLogger(__name__)
|
|
@@ -23,7 +26,7 @@ class MathMindsADKAgent:
|
|
| 23 |
Refined to match official Multitool Agent documentation patterns.
|
| 24 |
"""
|
| 25 |
|
| 26 |
-
def __init__(self, model_name: str = "gemini-2.5-
|
| 27 |
self.api_key = settings.GOOGLE_API_KEY
|
| 28 |
if not self.api_key:
|
| 29 |
logger.warning("No Google API Key found. Agent will fail.")
|
|
@@ -33,6 +36,8 @@ class MathMindsADKAgent:
|
|
| 33 |
self.symbolic_solver = SymbolicSolver()
|
| 34 |
self.normalizer = MathQueryNormalizer()
|
| 35 |
self.similar_finder = SimilarProblemFinder()
|
|
|
|
|
|
|
| 36 |
|
| 37 |
# Define Tools as simpler closures
|
| 38 |
# Docs pattern: simple functions, passed in a list.
|
|
@@ -82,20 +87,51 @@ class MathMindsADKAgent:
|
|
| 82 |
formatted += f"Problem: {item.get('problem_text')}\nSolution: {item.get('solution_text')}\n---\n"
|
| 83 |
return formatted
|
| 84 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
# Initialize Agent
|
| 86 |
# Using 'Agent' class as per official docs, passing functions directly.
|
| 87 |
self.agent = Agent(
|
| 88 |
name="math_minds_core",
|
| 89 |
model=model_name,
|
| 90 |
-
tools=[web_search, math_solver, find_similar_problems], # Passed directly as function list
|
| 91 |
instruction=(
|
| 92 |
"You are MathMinds AI, a helpful and precise mathematical assistant. "
|
| 93 |
-
"You
|
| 94 |
-
"
|
| 95 |
-
"
|
| 96 |
-
"
|
| 97 |
-
"
|
| 98 |
-
"
|
|
|
|
|
|
|
|
|
|
| 99 |
)
|
| 100 |
)
|
| 101 |
|
|
@@ -141,11 +177,19 @@ class MathMindsADKAgent:
|
|
| 141 |
|
| 142 |
if image_data:
|
| 143 |
try:
|
| 144 |
-
|
| 145 |
-
mime_type = "image/png"
|
| 146 |
if image_data.startswith("/9j/"):
|
| 147 |
mime_type = "image/jpeg"
|
| 148 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 149 |
parts.append(types.Part.from_bytes(data=img_bytes, mime_type=mime_type))
|
| 150 |
logger.info("Attached image to agent request.")
|
| 151 |
except Exception as e:
|
|
|
|
| 1 |
|
| 2 |
+
|
| 3 |
import logging
|
| 4 |
import asyncio
|
| 5 |
import base64
|
|
|
|
| 14 |
from app.tools.web_scraper import WebScraper
|
| 15 |
from app.tools.symbolic_solver import SymbolicSolver
|
| 16 |
from app.tools.similarity_search import SimilarProblemFinder
|
| 17 |
+
from app.core.ocr import OCRProcessor
|
| 18 |
+
from app.tools.vision_analyzer import VisionAnalyzer
|
| 19 |
from app.core.math_normalizer import MathQueryNormalizer
|
| 20 |
|
| 21 |
logger = logging.getLogger(__name__)
|
|
|
|
| 26 |
Refined to match official Multitool Agent documentation patterns.
|
| 27 |
"""
|
| 28 |
|
| 29 |
+
def __init__(self, model_name: str = "gemini-2.5-pro"):
|
| 30 |
self.api_key = settings.GOOGLE_API_KEY
|
| 31 |
if not self.api_key:
|
| 32 |
logger.warning("No Google API Key found. Agent will fail.")
|
|
|
|
| 36 |
self.symbolic_solver = SymbolicSolver()
|
| 37 |
self.normalizer = MathQueryNormalizer()
|
| 38 |
self.similar_finder = SimilarProblemFinder()
|
| 39 |
+
self.ocr = OCRProcessor()
|
| 40 |
+
self.vision_analyzer = VisionAnalyzer()
|
| 41 |
|
| 42 |
# Define Tools as simpler closures
|
| 43 |
# Docs pattern: simple functions, passed in a list.
|
|
|
|
| 87 |
formatted += f"Problem: {item.get('problem_text')}\nSolution: {item.get('solution_text')}\n---\n"
|
| 88 |
return formatted
|
| 89 |
|
| 90 |
+
def read_image(image_data: str) -> str:
|
| 91 |
+
"""
|
| 92 |
+
Useful for reading text, numbers, or equations from an image when you cannot see it clearly or need the exact text.
|
| 93 |
+
|
| 94 |
+
Args:
|
| 95 |
+
image_data: The base64 string of the image.
|
| 96 |
+
"""
|
| 97 |
+
try:
|
| 98 |
+
text = self.ocr.extract_text(image_data=image_data)
|
| 99 |
+
return text if text else "No text found in image."
|
| 100 |
+
except Exception as e:
|
| 101 |
+
return f"Error reading image: {str(e)}"
|
| 102 |
+
|
| 103 |
+
async def analyze_image(image_data: str, focus: str = "") -> str:
|
| 104 |
+
"""
|
| 105 |
+
Analyzes an image mathematically: extracts equations, counts objects, describes graphs, etc.
|
| 106 |
+
Use this when the user uploaded an image and wants to count items or understand the visual content.
|
| 107 |
+
|
| 108 |
+
Args:
|
| 109 |
+
image_data: The base64 string of the image.
|
| 110 |
+
focus: Option string to focus analysis (e.g. "count red balls").
|
| 111 |
+
"""
|
| 112 |
+
try:
|
| 113 |
+
result = self.vision_analyzer.analyze(image_data, focus)
|
| 114 |
+
return str(result)
|
| 115 |
+
except Exception as e:
|
| 116 |
+
return f"Image analysis failed: {str(e)}"
|
| 117 |
+
|
| 118 |
# Initialize Agent
|
| 119 |
# Using 'Agent' class as per official docs, passing functions directly.
|
| 120 |
self.agent = Agent(
|
| 121 |
name="math_minds_core",
|
| 122 |
model=model_name,
|
| 123 |
+
tools=[web_search, math_solver, find_similar_problems, read_image, analyze_image], # Passed directly as function list
|
| 124 |
instruction=(
|
| 125 |
"You are MathMinds AI, a helpful and precise mathematical assistant. "
|
| 126 |
+
"You can receive BOTH text instructions AND images in the same query. "
|
| 127 |
+
"When an image is provided, ALWAYS analyze it first — describe what you see, "
|
| 128 |
+
"extract equations if present, count objects if it's a probability/statistics question, "
|
| 129 |
+
"or interpret graphs/charts/diagrams mathematically. "
|
| 130 |
+
"Then combine the image analysis with the text prompt to give a complete answer. "
|
| 131 |
+
"Use tools only when necessary (e.g. 'Math Solver' for symbolic work, 'Web Search' for facts). "
|
| 132 |
+
"Use 'Read Image' to extract text from images if it's blurry or you need exact wording. "
|
| 133 |
+
"Use 'Analyze Image' to count objects or detect items. "
|
| 134 |
+
"Always explain your steps clearly and show reasoning."
|
| 135 |
)
|
| 136 |
)
|
| 137 |
|
|
|
|
| 177 |
|
| 178 |
if image_data:
|
| 179 |
try:
|
| 180 |
+
# Better MIME type detection
|
|
|
|
| 181 |
if image_data.startswith("/9j/"):
|
| 182 |
mime_type = "image/jpeg"
|
| 183 |
+
elif image_data.startswith("iVBORw"):
|
| 184 |
+
mime_type = "image/png"
|
| 185 |
+
elif image_data.startswith("R0lGOD"):
|
| 186 |
+
mime_type = "image/gif"
|
| 187 |
+
elif image_data.startswith("UklGR"):
|
| 188 |
+
mime_type = "image/webp"
|
| 189 |
+
else:
|
| 190 |
+
mime_type = "image/png" # Default fallback
|
| 191 |
+
|
| 192 |
+
img_bytes = base64.b64decode(image_data)
|
| 193 |
parts.append(types.Part.from_bytes(data=img_bytes, mime_type=mime_type))
|
| 194 |
logger.info("Attached image to agent request.")
|
| 195 |
except Exception as e:
|
app/core/ocr.py
CHANGED
|
@@ -1,21 +1,90 @@
|
|
|
|
|
| 1 |
import base64
|
| 2 |
import requests
|
| 3 |
import io
|
| 4 |
import logging
|
| 5 |
-
|
| 6 |
-
from
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
|
| 8 |
logger = logging.getLogger(__name__)
|
| 9 |
|
| 10 |
class OCRProcessor:
|
| 11 |
"""
|
| 12 |
-
Handles
|
| 13 |
-
Note: PaddleOCR has been removed. This class now acts as an image helper.
|
| 14 |
"""
|
| 15 |
|
|
|
|
|
|
|
| 16 |
def __init__(self, max_size_bytes: int = 5 * 1024 * 1024): # 5MB limit
|
| 17 |
self.max_size = max_size_bytes
|
| 18 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
|
| 20 |
def optimize_base64(self, b64_string: str) -> str:
|
| 21 |
"""
|
|
@@ -23,114 +92,37 @@ class OCRProcessor:
|
|
| 23 |
Returns optimized base64 string.
|
| 24 |
"""
|
| 25 |
try:
|
| 26 |
-
# Basic strip
|
| 27 |
if ";base64," in b64_string:
|
| 28 |
-
|
| 29 |
else:
|
| 30 |
-
header = None
|
| 31 |
data = b64_string
|
| 32 |
|
| 33 |
img_data = base64.b64decode(data)
|
| 34 |
img = Image.open(io.BytesIO(img_data))
|
| 35 |
|
| 36 |
-
# Resize if too large
|
| 37 |
max_dim = 1024
|
| 38 |
if max(img.size) > max_dim:
|
| 39 |
img.thumbnail((max_dim, max_dim), Image.Resampling.LANCZOS)
|
| 40 |
|
| 41 |
-
# Convert to JPEG for compression (if RGBA, convert to RGB)
|
| 42 |
if img.mode in ('RGBA', 'P'):
|
| 43 |
img = img.convert('RGB')
|
| 44 |
|
| 45 |
buffer = io.BytesIO()
|
| 46 |
-
# Quality 85 is good balance
|
| 47 |
img.save(buffer, format="JPEG", quality=85)
|
| 48 |
-
|
| 49 |
return base64.b64encode(buffer.getvalue()).decode('utf-8')
|
| 50 |
except Exception as e:
|
| 51 |
-
logger.warning(f"Image optimization failed
|
| 52 |
return b64_string
|
| 53 |
|
| 54 |
def download_image_as_base64(self, url: str) -> Optional[str]:
|
| 55 |
-
"""
|
| 56 |
-
Download image from URL and return as base64 string.
|
| 57 |
-
"""
|
| 58 |
try:
|
| 59 |
response = requests.get(url, timeout=10, stream=True)
|
| 60 |
response.raise_for_status()
|
| 61 |
-
|
| 62 |
-
# Size check
|
| 63 |
if len(response.content) > self.max_size:
|
| 64 |
-
logger.warning(f"Downloaded image bytes {len(response.content)} exceed limit.")
|
| 65 |
return None
|
| 66 |
-
|
| 67 |
-
# Optimize immediately
|
| 68 |
b64 = base64.b64encode(response.content).decode('utf-8')
|
| 69 |
return self.optimize_base64(b64)
|
| 70 |
-
|
| 71 |
except Exception as e:
|
| 72 |
logger.error(f"Image download failed: {e}")
|
| 73 |
return None
|
| 74 |
-
|
| 75 |
-
def _preprocess_image(self, img: Image.Image) -> Image.Image:
|
| 76 |
-
"""
|
| 77 |
-
Applies preprocessing to improve image quality for Vision model.
|
| 78 |
-
- Grayscale conversion
|
| 79 |
-
- Contrast enhancement
|
| 80 |
-
- Binarization (Thresholding)
|
| 81 |
-
"""
|
| 82 |
-
try:
|
| 83 |
-
# 1. Convert to grayscale
|
| 84 |
-
img = img.convert('L')
|
| 85 |
-
|
| 86 |
-
# 2. Enhance contrast
|
| 87 |
-
enhancer = ImageEnhance.Contrast(img)
|
| 88 |
-
img = enhancer.enhance(2.0)
|
| 89 |
-
|
| 90 |
-
# 3. Apply thresholding (binarization)
|
| 91 |
-
# This makes the image pure black and white, removing noise
|
| 92 |
-
img = img.point(lambda x: 0 if x < 128 else 255, '1')
|
| 93 |
-
|
| 94 |
-
return img
|
| 95 |
-
except Exception as e:
|
| 96 |
-
logger.warning(f"Image preprocessing failed, using original: {e}")
|
| 97 |
-
return img
|
| 98 |
-
|
| 99 |
-
def _process_image_data(self, image_bytes: bytes) -> Optional[str]:
|
| 100 |
-
"""
|
| 101 |
-
Validate image format.
|
| 102 |
-
Returns dummy string or None.
|
| 103 |
-
DEPRECATED: Used to do OCR. Now just validates.
|
| 104 |
-
"""
|
| 105 |
-
# 1. Size Check
|
| 106 |
-
if len(image_bytes) > self.max_size:
|
| 107 |
-
logger.warning("Image data exceeds size limit.")
|
| 108 |
-
return None
|
| 109 |
-
|
| 110 |
-
# 2. Format Validation (using Pillow)
|
| 111 |
-
try:
|
| 112 |
-
img = Image.open(io.BytesIO(image_bytes))
|
| 113 |
-
img.verify() # Verify it's an image
|
| 114 |
-
|
| 115 |
-
# Re-open for processing (verify closes the file)
|
| 116 |
-
img = Image.open(io.BytesIO(image_bytes))
|
| 117 |
-
|
| 118 |
-
if img.format.upper() not in ('JPEG', 'JPG', 'PNG', 'BMP', 'WEBP'):
|
| 119 |
-
logger.warning(f"Unsupported image format: {img.format}")
|
| 120 |
-
return None
|
| 121 |
-
|
| 122 |
-
return "VALID_IMAGE"
|
| 123 |
-
|
| 124 |
-
except Exception as e:
|
| 125 |
-
logger.warning(f"Invalid image file: {e}")
|
| 126 |
-
return None
|
| 127 |
-
|
| 128 |
-
# Legacy methods stubbed out or removed.
|
| 129 |
-
# process_base64 and process_url were used for text extraction.
|
| 130 |
-
# Calling them now should return None to indicate no text extracted.
|
| 131 |
-
|
| 132 |
-
def process_base64(self, b64_string: str) -> Optional[str]:
|
| 133 |
-
return None
|
| 134 |
-
|
| 135 |
-
def process_url(self, url: str) -> Optional[str]:
|
| 136 |
-
return None
|
|
|
|
| 1 |
+
|
| 2 |
import base64
|
| 3 |
import requests
|
| 4 |
import io
|
| 5 |
import logging
|
| 6 |
+
import numpy as np
|
| 7 |
+
from typing import Optional, List
|
| 8 |
+
from PIL import Image, ImageEnhance
|
| 9 |
+
try:
|
| 10 |
+
from paddleocr import PaddleOCR
|
| 11 |
+
except ImportError:
|
| 12 |
+
PaddleOCR = None
|
| 13 |
|
| 14 |
logger = logging.getLogger(__name__)
|
| 15 |
|
| 16 |
class OCRProcessor:
|
| 17 |
"""
|
| 18 |
+
Handles OCR text extraction using PaddleOCR and image preprocessing.
|
|
|
|
| 19 |
"""
|
| 20 |
|
| 21 |
+
_os_instance = None # Singleton for OCR engine
|
| 22 |
+
|
| 23 |
def __init__(self, max_size_bytes: int = 5 * 1024 * 1024): # 5MB limit
|
| 24 |
self.max_size = max_size_bytes
|
| 25 |
+
self.ocr_engine = None
|
| 26 |
+
|
| 27 |
+
@property
|
| 28 |
+
def engine(self):
|
| 29 |
+
"""Lazy load PaddleOCR engine."""
|
| 30 |
+
if self.ocr_engine is None:
|
| 31 |
+
if PaddleOCR:
|
| 32 |
+
logger.info("Initializing PaddleOCR engine...")
|
| 33 |
+
# deterministic=True ensures consistent results
|
| 34 |
+
self.ocr_engine = PaddleOCR(use_angle_cls=True, lang='en')
|
| 35 |
+
else:
|
| 36 |
+
logger.error("PaddleOCR not installed.")
|
| 37 |
+
return None
|
| 38 |
+
return self.ocr_engine
|
| 39 |
+
|
| 40 |
+
def extract_text(self, headers_b64: Optional[str] = None, image_data: Optional[str] = None) -> str:
|
| 41 |
+
"""
|
| 42 |
+
Extract text from base64 image data.
|
| 43 |
+
Arg 'headers_b64' is for backward compat/legacy signature matching if any,
|
| 44 |
+
but we expect 'image_data' (base64 string).
|
| 45 |
+
|
| 46 |
+
Args:
|
| 47 |
+
image_data: Base64 string of the image.
|
| 48 |
+
|
| 49 |
+
Returns:
|
| 50 |
+
Extracted text string or empty string on failure.
|
| 51 |
+
"""
|
| 52 |
+
# Handle positional args if someone calls extract_text(b64)
|
| 53 |
+
target_b64 = image_data or headers_b64
|
| 54 |
+
if not target_b64:
|
| 55 |
+
return ""
|
| 56 |
+
|
| 57 |
+
if not self.engine:
|
| 58 |
+
return "OCR Engine Unavailable"
|
| 59 |
+
|
| 60 |
+
try:
|
| 61 |
+
# 1. Decode Base64 to Array
|
| 62 |
+
if ";base64," in target_b64:
|
| 63 |
+
_, target_b64 = target_b64.split(";base64,")
|
| 64 |
+
|
| 65 |
+
img_bytes = base64.b64decode(target_b64)
|
| 66 |
+
img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
|
| 67 |
+
img_arr = np.array(img)
|
| 68 |
+
|
| 69 |
+
# 2. Run OCR
|
| 70 |
+
result = self.engine.ocr(img_arr, cls=True)
|
| 71 |
+
|
| 72 |
+
# 3. Parse Results
|
| 73 |
+
extracted_lines = []
|
| 74 |
+
if result and result[0]:
|
| 75 |
+
for line in result[0]:
|
| 76 |
+
text = line[1][0]
|
| 77 |
+
confidence = line[1][1]
|
| 78 |
+
if confidence > 0.5: # Confidence threshold
|
| 79 |
+
extracted_lines.append(text)
|
| 80 |
+
|
| 81 |
+
full_text = "\n".join(extracted_lines)
|
| 82 |
+
logger.info(f"OCR extracted {len(full_text)} chars.")
|
| 83 |
+
return full_text
|
| 84 |
+
|
| 85 |
+
except Exception as e:
|
| 86 |
+
logger.error(f"OCR Failed: {e}")
|
| 87 |
+
return f"Error reading image: {e}"
|
| 88 |
|
| 89 |
def optimize_base64(self, b64_string: str) -> str:
|
| 90 |
"""
|
|
|
|
| 92 |
Returns optimized base64 string.
|
| 93 |
"""
|
| 94 |
try:
|
|
|
|
| 95 |
if ";base64," in b64_string:
|
| 96 |
+
_, data = b64_string.split(";base64,")
|
| 97 |
else:
|
|
|
|
| 98 |
data = b64_string
|
| 99 |
|
| 100 |
img_data = base64.b64decode(data)
|
| 101 |
img = Image.open(io.BytesIO(img_data))
|
| 102 |
|
|
|
|
| 103 |
max_dim = 1024
|
| 104 |
if max(img.size) > max_dim:
|
| 105 |
img.thumbnail((max_dim, max_dim), Image.Resampling.LANCZOS)
|
| 106 |
|
|
|
|
| 107 |
if img.mode in ('RGBA', 'P'):
|
| 108 |
img = img.convert('RGB')
|
| 109 |
|
| 110 |
buffer = io.BytesIO()
|
|
|
|
| 111 |
img.save(buffer, format="JPEG", quality=85)
|
|
|
|
| 112 |
return base64.b64encode(buffer.getvalue()).decode('utf-8')
|
| 113 |
except Exception as e:
|
| 114 |
+
logger.warning(f"Image optimization failed: {e}")
|
| 115 |
return b64_string
|
| 116 |
|
| 117 |
def download_image_as_base64(self, url: str) -> Optional[str]:
|
| 118 |
+
"""Download image from URL and return as base64 string."""
|
|
|
|
|
|
|
| 119 |
try:
|
| 120 |
response = requests.get(url, timeout=10, stream=True)
|
| 121 |
response.raise_for_status()
|
|
|
|
|
|
|
| 122 |
if len(response.content) > self.max_size:
|
|
|
|
| 123 |
return None
|
|
|
|
|
|
|
| 124 |
b64 = base64.b64encode(response.content).decode('utf-8')
|
| 125 |
return self.optimize_base64(b64)
|
|
|
|
| 126 |
except Exception as e:
|
| 127 |
logger.error(f"Image download failed: {e}")
|
| 128 |
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tests/test_ocr_simple.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import sys
|
| 3 |
+
import os
|
| 4 |
+
sys.path.insert(0, os.getcwd())
|
| 5 |
+
|
| 6 |
+
import base64
|
| 7 |
+
import io
|
| 8 |
+
from PIL import Image, ImageDraw
|
| 9 |
+
from app.core.ocr import OCRProcessor
|
| 10 |
+
|
| 11 |
+
def create_test_image_b64(text: str) -> str:
|
| 12 |
+
img = Image.new('RGB', (400, 100), color=(255, 255, 255))
|
| 13 |
+
d = ImageDraw.Draw(img)
|
| 14 |
+
d.text((10, 40), text, fill=(0, 0, 0))
|
| 15 |
+
buffer = io.BytesIO()
|
| 16 |
+
img.save(buffer, format="PNG")
|
| 17 |
+
return base64.b64encode(buffer.getvalue()).decode()
|
| 18 |
+
|
| 19 |
+
def test_ocr_direct():
|
| 20 |
+
print("Initializing OCRProcessor...")
|
| 21 |
+
ocr = OCRProcessor()
|
| 22 |
+
|
| 23 |
+
text = "Hello OCR World"
|
| 24 |
+
b64 = create_test_image_b64(text)
|
| 25 |
+
|
| 26 |
+
print(f"Extracting text from image with '{text}'...")
|
| 27 |
+
result = ocr.extract_text(image_data=b64)
|
| 28 |
+
|
| 29 |
+
print(f"Result: {result}")
|
| 30 |
+
|
| 31 |
+
if text in result:
|
| 32 |
+
print("SUCCESS: OCR worked correctly.")
|
| 33 |
+
else:
|
| 34 |
+
print("FAILURE: Text not found.")
|
| 35 |
+
|
| 36 |
+
if __name__ == "__main__":
|
| 37 |
+
test_ocr_direct()
|
tests/test_ocr_tool.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import pytest
|
| 3 |
+
import asyncio
|
| 4 |
+
import base64
|
| 5 |
+
import io
|
| 6 |
+
from PIL import Image, ImageDraw, ImageFont
|
| 7 |
+
from app.agents.adk_mathminds import MathMindsADKAgent
|
| 8 |
+
|
| 9 |
+
def create_test_image_b64(text: str) -> str:
|
| 10 |
+
"""Creates a simple image with text and returns base64 string."""
|
| 11 |
+
img = Image.new('RGB', (400, 100), color=(255, 255, 255))
|
| 12 |
+
d = ImageDraw.Draw(img)
|
| 13 |
+
# default font or simple drawing
|
| 14 |
+
d.text((10, 40), text, fill=(0, 0, 0))
|
| 15 |
+
|
| 16 |
+
buffer = io.BytesIO()
|
| 17 |
+
img.save(buffer, format="PNG")
|
| 18 |
+
return base64.b64encode(buffer.getvalue()).decode()
|
| 19 |
+
|
| 20 |
+
@pytest.mark.asyncio
|
| 21 |
+
async def test_ocr_tool_usage():
|
| 22 |
+
"""
|
| 23 |
+
Verifies that the ADK agent can use the 'read_image' tool to extract text.
|
| 24 |
+
"""
|
| 25 |
+
agent = MathMindsADKAgent()
|
| 26 |
+
|
| 27 |
+
secret_text = "The secret number is 999."
|
| 28 |
+
image_b64 = create_test_image_b64(secret_text)
|
| 29 |
+
|
| 30 |
+
print("\n--- Starting OCR Tool Test ---")
|
| 31 |
+
print(f"Generated image with text: '{secret_text}'")
|
| 32 |
+
|
| 33 |
+
# Ask the agent to read it
|
| 34 |
+
# We specifically ask to "read the text" to encourage tool usage
|
| 35 |
+
# over just vision model (though both might work).
|
| 36 |
+
response = await agent.solve(
|
| 37 |
+
problem="What is the secret number written in this image? Use your read_image tool if needed.",
|
| 38 |
+
image_data=image_b64,
|
| 39 |
+
session_id="test_ocr_session",
|
| 40 |
+
user_id="test_user"
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
print(f"Agent Response: {response}")
|
| 44 |
+
|
| 45 |
+
assert "999" in response, f"Agent failed to extract the number. Response: {response}"
|
| 46 |
+
print("\nSUCCESS: OCR Tool verified!")
|
| 47 |
+
|
| 48 |
+
if __name__ == "__main__":
|
| 49 |
+
asyncio.run(test_ocr_tool_usage())
|