ghadgemadhuri92 commited on
Commit
1df75cb
·
1 Parent(s): 8bdfa24

read_image tool for text extraction from images

Browse files
all_code.txt DELETED
The diff for this file is too large to render. See raw diff
 
app/agents/adk_mathminds.py CHANGED
@@ -1,4 +1,5 @@
1
 
 
2
  import logging
3
  import asyncio
4
  import base64
@@ -13,6 +14,8 @@ from app.core.settings import settings
13
  from app.tools.web_scraper import WebScraper
14
  from app.tools.symbolic_solver import SymbolicSolver
15
  from app.tools.similarity_search import SimilarProblemFinder
 
 
16
  from app.core.math_normalizer import MathQueryNormalizer
17
 
18
  logger = logging.getLogger(__name__)
@@ -23,7 +26,7 @@ class MathMindsADKAgent:
23
  Refined to match official Multitool Agent documentation patterns.
24
  """
25
 
26
- def __init__(self, model_name: str = "gemini-2.5-flash"):
27
  self.api_key = settings.GOOGLE_API_KEY
28
  if not self.api_key:
29
  logger.warning("No Google API Key found. Agent will fail.")
@@ -33,6 +36,8 @@ class MathMindsADKAgent:
33
  self.symbolic_solver = SymbolicSolver()
34
  self.normalizer = MathQueryNormalizer()
35
  self.similar_finder = SimilarProblemFinder()
 
 
36
 
37
  # Define Tools as simpler closures
38
  # Docs pattern: simple functions, passed in a list.
@@ -82,20 +87,51 @@ class MathMindsADKAgent:
82
  formatted += f"Problem: {item.get('problem_text')}\nSolution: {item.get('solution_text')}\n---\n"
83
  return formatted
84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  # Initialize Agent
86
  # Using 'Agent' class as per official docs, passing functions directly.
87
  self.agent = Agent(
88
  name="math_minds_core",
89
  model=model_name,
90
- tools=[web_search, math_solver, find_similar_problems], # Passed directly as function list
91
  instruction=(
92
  "You are MathMinds AI, a helpful and precise mathematical assistant. "
93
- "You have access to tools for solving symbolic math problems, searching the web, and finding similar solved problems. "
94
- "If an image is provided, analyze it mathematically. "
95
- "Use 'Math Solver' for distinct math problems (equations, calculus, etc.). "
96
- "Use 'Web Search' for real-world data (prices, weather, facts). "
97
- "Use 'Find Similar Problems' to look up examples if you are unsure how to solve a problem. "
98
- "Always explain your steps clearly."
 
 
 
99
  )
100
  )
101
 
@@ -141,11 +177,19 @@ class MathMindsADKAgent:
141
 
142
  if image_data:
143
  try:
144
- img_bytes = base64.b64decode(image_data)
145
- mime_type = "image/png"
146
  if image_data.startswith("/9j/"):
147
  mime_type = "image/jpeg"
148
-
 
 
 
 
 
 
 
 
 
149
  parts.append(types.Part.from_bytes(data=img_bytes, mime_type=mime_type))
150
  logger.info("Attached image to agent request.")
151
  except Exception as e:
 
1
 
2
+
3
  import logging
4
  import asyncio
5
  import base64
 
14
  from app.tools.web_scraper import WebScraper
15
  from app.tools.symbolic_solver import SymbolicSolver
16
  from app.tools.similarity_search import SimilarProblemFinder
17
+ from app.core.ocr import OCRProcessor
18
+ from app.tools.vision_analyzer import VisionAnalyzer
19
  from app.core.math_normalizer import MathQueryNormalizer
20
 
21
  logger = logging.getLogger(__name__)
 
26
  Refined to match official Multitool Agent documentation patterns.
27
  """
28
 
29
+ def __init__(self, model_name: str = "gemini-2.5-pro"):
30
  self.api_key = settings.GOOGLE_API_KEY
31
  if not self.api_key:
32
  logger.warning("No Google API Key found. Agent will fail.")
 
36
  self.symbolic_solver = SymbolicSolver()
37
  self.normalizer = MathQueryNormalizer()
38
  self.similar_finder = SimilarProblemFinder()
39
+ self.ocr = OCRProcessor()
40
+ self.vision_analyzer = VisionAnalyzer()
41
 
42
  # Define Tools as simpler closures
43
  # Docs pattern: simple functions, passed in a list.
 
87
  formatted += f"Problem: {item.get('problem_text')}\nSolution: {item.get('solution_text')}\n---\n"
88
  return formatted
89
 
90
+ def read_image(image_data: str) -> str:
91
+ """
92
+ Useful for reading text, numbers, or equations from an image when you cannot see it clearly or need the exact text.
93
+
94
+ Args:
95
+ image_data: The base64 string of the image.
96
+ """
97
+ try:
98
+ text = self.ocr.extract_text(image_data=image_data)
99
+ return text if text else "No text found in image."
100
+ except Exception as e:
101
+ return f"Error reading image: {str(e)}"
102
+
103
+ async def analyze_image(image_data: str, focus: str = "") -> str:
104
+ """
105
+ Analyzes an image mathematically: extracts equations, counts objects, describes graphs, etc.
106
+ Use this when the user uploaded an image and wants to count items or understand the visual content.
107
+
108
+ Args:
109
+ image_data: The base64 string of the image.
110
+ focus: Option string to focus analysis (e.g. "count red balls").
111
+ """
112
+ try:
113
+ result = self.vision_analyzer.analyze(image_data, focus)
114
+ return str(result)
115
+ except Exception as e:
116
+ return f"Image analysis failed: {str(e)}"
117
+
118
  # Initialize Agent
119
  # Using 'Agent' class as per official docs, passing functions directly.
120
  self.agent = Agent(
121
  name="math_minds_core",
122
  model=model_name,
123
+ tools=[web_search, math_solver, find_similar_problems, read_image, analyze_image], # Passed directly as function list
124
  instruction=(
125
  "You are MathMinds AI, a helpful and precise mathematical assistant. "
126
+ "You can receive BOTH text instructions AND images in the same query. "
127
+ "When an image is provided, ALWAYS analyze it first — describe what you see, "
128
+ "extract equations if present, count objects if it's a probability/statistics question, "
129
+ "or interpret graphs/charts/diagrams mathematically. "
130
+ "Then combine the image analysis with the text prompt to give a complete answer. "
131
+ "Use tools only when necessary (e.g. 'Math Solver' for symbolic work, 'Web Search' for facts). "
132
+ "Use 'Read Image' to extract text from images if it's blurry or you need exact wording. "
133
+ "Use 'Analyze Image' to count objects or detect items. "
134
+ "Always explain your steps clearly and show reasoning."
135
  )
136
  )
137
 
 
177
 
178
  if image_data:
179
  try:
180
+ # Better MIME type detection
 
181
  if image_data.startswith("/9j/"):
182
  mime_type = "image/jpeg"
183
+ elif image_data.startswith("iVBORw"):
184
+ mime_type = "image/png"
185
+ elif image_data.startswith("R0lGOD"):
186
+ mime_type = "image/gif"
187
+ elif image_data.startswith("UklGR"):
188
+ mime_type = "image/webp"
189
+ else:
190
+ mime_type = "image/png" # Default fallback
191
+
192
+ img_bytes = base64.b64decode(image_data)
193
  parts.append(types.Part.from_bytes(data=img_bytes, mime_type=mime_type))
194
  logger.info("Attached image to agent request.")
195
  except Exception as e:
app/core/ocr.py CHANGED
@@ -1,21 +1,90 @@
 
1
  import base64
2
  import requests
3
  import io
4
  import logging
5
- from typing import Optional
6
- from PIL import Image, ImageEnhance, ImageOps
 
 
 
 
 
7
 
8
  logger = logging.getLogger(__name__)
9
 
10
  class OCRProcessor:
11
  """
12
- Handles image validation and download.
13
- Note: PaddleOCR has been removed. This class now acts as an image helper.
14
  """
15
 
 
 
16
  def __init__(self, max_size_bytes: int = 5 * 1024 * 1024): # 5MB limit
17
  self.max_size = max_size_bytes
18
- # No OCR engine init needed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
  def optimize_base64(self, b64_string: str) -> str:
21
  """
@@ -23,114 +92,37 @@ class OCRProcessor:
23
  Returns optimized base64 string.
24
  """
25
  try:
26
- # Basic strip
27
  if ";base64," in b64_string:
28
- header, data = b64_string.split(";base64,")
29
  else:
30
- header = None
31
  data = b64_string
32
 
33
  img_data = base64.b64decode(data)
34
  img = Image.open(io.BytesIO(img_data))
35
 
36
- # Resize if too large
37
  max_dim = 1024
38
  if max(img.size) > max_dim:
39
  img.thumbnail((max_dim, max_dim), Image.Resampling.LANCZOS)
40
 
41
- # Convert to JPEG for compression (if RGBA, convert to RGB)
42
  if img.mode in ('RGBA', 'P'):
43
  img = img.convert('RGB')
44
 
45
  buffer = io.BytesIO()
46
- # Quality 85 is good balance
47
  img.save(buffer, format="JPEG", quality=85)
48
-
49
  return base64.b64encode(buffer.getvalue()).decode('utf-8')
50
  except Exception as e:
51
- logger.warning(f"Image optimization failed, using original: {e}")
52
  return b64_string
53
 
54
  def download_image_as_base64(self, url: str) -> Optional[str]:
55
- """
56
- Download image from URL and return as base64 string.
57
- """
58
  try:
59
  response = requests.get(url, timeout=10, stream=True)
60
  response.raise_for_status()
61
-
62
- # Size check
63
  if len(response.content) > self.max_size:
64
- logger.warning(f"Downloaded image bytes {len(response.content)} exceed limit.")
65
  return None
66
-
67
- # Optimize immediately
68
  b64 = base64.b64encode(response.content).decode('utf-8')
69
  return self.optimize_base64(b64)
70
-
71
  except Exception as e:
72
  logger.error(f"Image download failed: {e}")
73
  return None
74
-
75
- def _preprocess_image(self, img: Image.Image) -> Image.Image:
76
- """
77
- Applies preprocessing to improve image quality for Vision model.
78
- - Grayscale conversion
79
- - Contrast enhancement
80
- - Binarization (Thresholding)
81
- """
82
- try:
83
- # 1. Convert to grayscale
84
- img = img.convert('L')
85
-
86
- # 2. Enhance contrast
87
- enhancer = ImageEnhance.Contrast(img)
88
- img = enhancer.enhance(2.0)
89
-
90
- # 3. Apply thresholding (binarization)
91
- # This makes the image pure black and white, removing noise
92
- img = img.point(lambda x: 0 if x < 128 else 255, '1')
93
-
94
- return img
95
- except Exception as e:
96
- logger.warning(f"Image preprocessing failed, using original: {e}")
97
- return img
98
-
99
- def _process_image_data(self, image_bytes: bytes) -> Optional[str]:
100
- """
101
- Validate image format.
102
- Returns dummy string or None.
103
- DEPRECATED: Used to do OCR. Now just validates.
104
- """
105
- # 1. Size Check
106
- if len(image_bytes) > self.max_size:
107
- logger.warning("Image data exceeds size limit.")
108
- return None
109
-
110
- # 2. Format Validation (using Pillow)
111
- try:
112
- img = Image.open(io.BytesIO(image_bytes))
113
- img.verify() # Verify it's an image
114
-
115
- # Re-open for processing (verify closes the file)
116
- img = Image.open(io.BytesIO(image_bytes))
117
-
118
- if img.format.upper() not in ('JPEG', 'JPG', 'PNG', 'BMP', 'WEBP'):
119
- logger.warning(f"Unsupported image format: {img.format}")
120
- return None
121
-
122
- return "VALID_IMAGE"
123
-
124
- except Exception as e:
125
- logger.warning(f"Invalid image file: {e}")
126
- return None
127
-
128
- # Legacy methods stubbed out or removed.
129
- # process_base64 and process_url were used for text extraction.
130
- # Calling them now should return None to indicate no text extracted.
131
-
132
- def process_base64(self, b64_string: str) -> Optional[str]:
133
- return None
134
-
135
- def process_url(self, url: str) -> Optional[str]:
136
- return None
 
1
+
2
  import base64
3
  import requests
4
  import io
5
  import logging
6
+ import numpy as np
7
+ from typing import Optional, List
8
+ from PIL import Image, ImageEnhance
9
+ try:
10
+ from paddleocr import PaddleOCR
11
+ except ImportError:
12
+ PaddleOCR = None
13
 
14
  logger = logging.getLogger(__name__)
15
 
16
  class OCRProcessor:
17
  """
18
+ Handles OCR text extraction using PaddleOCR and image preprocessing.
 
19
  """
20
 
21
+ _os_instance = None # Singleton for OCR engine
22
+
23
  def __init__(self, max_size_bytes: int = 5 * 1024 * 1024): # 5MB limit
24
  self.max_size = max_size_bytes
25
+ self.ocr_engine = None
26
+
27
+ @property
28
+ def engine(self):
29
+ """Lazy load PaddleOCR engine."""
30
+ if self.ocr_engine is None:
31
+ if PaddleOCR:
32
+ logger.info("Initializing PaddleOCR engine...")
33
+ # deterministic=True ensures consistent results
34
+ self.ocr_engine = PaddleOCR(use_angle_cls=True, lang='en')
35
+ else:
36
+ logger.error("PaddleOCR not installed.")
37
+ return None
38
+ return self.ocr_engine
39
+
40
+ def extract_text(self, headers_b64: Optional[str] = None, image_data: Optional[str] = None) -> str:
41
+ """
42
+ Extract text from base64 image data.
43
+ Arg 'headers_b64' is for backward compat/legacy signature matching if any,
44
+ but we expect 'image_data' (base64 string).
45
+
46
+ Args:
47
+ image_data: Base64 string of the image.
48
+
49
+ Returns:
50
+ Extracted text string or empty string on failure.
51
+ """
52
+ # Handle positional args if someone calls extract_text(b64)
53
+ target_b64 = image_data or headers_b64
54
+ if not target_b64:
55
+ return ""
56
+
57
+ if not self.engine:
58
+ return "OCR Engine Unavailable"
59
+
60
+ try:
61
+ # 1. Decode Base64 to Array
62
+ if ";base64," in target_b64:
63
+ _, target_b64 = target_b64.split(";base64,")
64
+
65
+ img_bytes = base64.b64decode(target_b64)
66
+ img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
67
+ img_arr = np.array(img)
68
+
69
+ # 2. Run OCR
70
+ result = self.engine.ocr(img_arr, cls=True)
71
+
72
+ # 3. Parse Results
73
+ extracted_lines = []
74
+ if result and result[0]:
75
+ for line in result[0]:
76
+ text = line[1][0]
77
+ confidence = line[1][1]
78
+ if confidence > 0.5: # Confidence threshold
79
+ extracted_lines.append(text)
80
+
81
+ full_text = "\n".join(extracted_lines)
82
+ logger.info(f"OCR extracted {len(full_text)} chars.")
83
+ return full_text
84
+
85
+ except Exception as e:
86
+ logger.error(f"OCR Failed: {e}")
87
+ return f"Error reading image: {e}"
88
 
89
  def optimize_base64(self, b64_string: str) -> str:
90
  """
 
92
  Returns optimized base64 string.
93
  """
94
  try:
 
95
  if ";base64," in b64_string:
96
+ _, data = b64_string.split(";base64,")
97
  else:
 
98
  data = b64_string
99
 
100
  img_data = base64.b64decode(data)
101
  img = Image.open(io.BytesIO(img_data))
102
 
 
103
  max_dim = 1024
104
  if max(img.size) > max_dim:
105
  img.thumbnail((max_dim, max_dim), Image.Resampling.LANCZOS)
106
 
 
107
  if img.mode in ('RGBA', 'P'):
108
  img = img.convert('RGB')
109
 
110
  buffer = io.BytesIO()
 
111
  img.save(buffer, format="JPEG", quality=85)
 
112
  return base64.b64encode(buffer.getvalue()).decode('utf-8')
113
  except Exception as e:
114
+ logger.warning(f"Image optimization failed: {e}")
115
  return b64_string
116
 
117
  def download_image_as_base64(self, url: str) -> Optional[str]:
118
+ """Download image from URL and return as base64 string."""
 
 
119
  try:
120
  response = requests.get(url, timeout=10, stream=True)
121
  response.raise_for_status()
 
 
122
  if len(response.content) > self.max_size:
 
123
  return None
 
 
124
  b64 = base64.b64encode(response.content).decode('utf-8')
125
  return self.optimize_base64(b64)
 
126
  except Exception as e:
127
  logger.error(f"Image download failed: {e}")
128
  return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tests/test_ocr_simple.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import sys
3
+ import os
4
+ sys.path.insert(0, os.getcwd())
5
+
6
+ import base64
7
+ import io
8
+ from PIL import Image, ImageDraw
9
+ from app.core.ocr import OCRProcessor
10
+
11
+ def create_test_image_b64(text: str) -> str:
12
+ img = Image.new('RGB', (400, 100), color=(255, 255, 255))
13
+ d = ImageDraw.Draw(img)
14
+ d.text((10, 40), text, fill=(0, 0, 0))
15
+ buffer = io.BytesIO()
16
+ img.save(buffer, format="PNG")
17
+ return base64.b64encode(buffer.getvalue()).decode()
18
+
19
+ def test_ocr_direct():
20
+ print("Initializing OCRProcessor...")
21
+ ocr = OCRProcessor()
22
+
23
+ text = "Hello OCR World"
24
+ b64 = create_test_image_b64(text)
25
+
26
+ print(f"Extracting text from image with '{text}'...")
27
+ result = ocr.extract_text(image_data=b64)
28
+
29
+ print(f"Result: {result}")
30
+
31
+ if text in result:
32
+ print("SUCCESS: OCR worked correctly.")
33
+ else:
34
+ print("FAILURE: Text not found.")
35
+
36
+ if __name__ == "__main__":
37
+ test_ocr_direct()
tests/test_ocr_tool.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import pytest
3
+ import asyncio
4
+ import base64
5
+ import io
6
+ from PIL import Image, ImageDraw, ImageFont
7
+ from app.agents.adk_mathminds import MathMindsADKAgent
8
+
9
+ def create_test_image_b64(text: str) -> str:
10
+ """Creates a simple image with text and returns base64 string."""
11
+ img = Image.new('RGB', (400, 100), color=(255, 255, 255))
12
+ d = ImageDraw.Draw(img)
13
+ # default font or simple drawing
14
+ d.text((10, 40), text, fill=(0, 0, 0))
15
+
16
+ buffer = io.BytesIO()
17
+ img.save(buffer, format="PNG")
18
+ return base64.b64encode(buffer.getvalue()).decode()
19
+
20
+ @pytest.mark.asyncio
21
+ async def test_ocr_tool_usage():
22
+ """
23
+ Verifies that the ADK agent can use the 'read_image' tool to extract text.
24
+ """
25
+ agent = MathMindsADKAgent()
26
+
27
+ secret_text = "The secret number is 999."
28
+ image_b64 = create_test_image_b64(secret_text)
29
+
30
+ print("\n--- Starting OCR Tool Test ---")
31
+ print(f"Generated image with text: '{secret_text}'")
32
+
33
+ # Ask the agent to read it
34
+ # We specifically ask to "read the text" to encourage tool usage
35
+ # over just vision model (though both might work).
36
+ response = await agent.solve(
37
+ problem="What is the secret number written in this image? Use your read_image tool if needed.",
38
+ image_data=image_b64,
39
+ session_id="test_ocr_session",
40
+ user_id="test_user"
41
+ )
42
+
43
+ print(f"Agent Response: {response}")
44
+
45
+ assert "999" in response, f"Agent failed to extract the number. Response: {response}"
46
+ print("\nSUCCESS: OCR Tool verified!")
47
+
48
+ if __name__ == "__main__":
49
+ asyncio.run(test_ocr_tool_usage())