Samarth Naik committed
Commit a3debee · 1 Parent(s): 5736c30

added texteller

Files changed (5):
  1. .dockerignore +15 -0
  2. Dockerfile +52 -0
  3. README.md +121 -0
  4. main.py +124 -279
  5. requirements.txt +4 -0
.dockerignore ADDED
@@ -0,0 +1,15 @@
+ PRINTED_TEX_230k/
+ *.pyc
+ __pycache__/
+ .git/
+ .gitignore
+ *.md
+ .DS_Store
+ .env
+ .venv/
+ venv/
+ *.log
+ *.tmp
+ .pytest_cache/
+ .coverage
+ htmlcov/
Dockerfile ADDED
@@ -0,0 +1,52 @@
+ # Use Python slim image as base
+ FROM python:3.11-slim
+
+ # Set working directory
+ WORKDIR /app
+
+ # Install system dependencies needed for TextTeller and image processing
+ RUN apt-get update && apt-get install -y \
+     git \
+     wget \
+     curl \
+     build-essential \
+     libgl1-mesa-glx \
+     libglib2.0-0 \
+     libsm6 \
+     libxext6 \
+     libxrender-dev \
+     libgomp1 \
+     && rm -rf /var/lib/apt/lists/*
+
+ # Install uv for faster Python package management
+ RUN pip install --no-cache-dir uv
+
+ # Copy requirements first for better caching
+ COPY requirements.txt .
+
+ # Install Python dependencies using uv
+ RUN uv pip install --system -r requirements.txt
+
+ # Install TextTeller with ONNX runtime support
+ RUN uv pip install --system texteller[onnxruntime-gpu]
+
+ # Copy the application code
+ COPY main.py .
+
+ # Create directory for temporary files
+ RUN mkdir -p /tmp/image_uploads
+
+ # Set environment variables
+ ENV FLASK_APP=main.py
+ ENV PYTHONUNBUFFERED=1
+ ENV PORT=5000
+
+ # Expose the port the app runs on
+ EXPOSE 5000
+
+ # Add healthcheck
+ HEALTHCHECK --interval=30s --timeout=10s --start-period=60s --retries=3 \
+     CMD curl -f http://localhost:5000/health || exit 1
+
+ # Run the application
+ CMD ["python", "main.py"]
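The HEALTHCHECK above shells out to curl; the same probe can be sketched in pure Python using only the standard library. This is an illustrative equivalent, not part of the committed files; the URL assumes the EXPOSEd port 5000.

```python
# Stand-alone equivalent of the Dockerfile's healthcheck: GET /health and
# treat any non-2xx response, timeout, or connection error as unhealthy.
import urllib.request
import urllib.error


def probe(url: str = "http://localhost:5000/health", timeout: float = 10.0) -> bool:
    """Return True if the health endpoint answers with a 2xx status."""
    try:
        with urllib.request.urlopen(url, timeout=timeout) as resp:
            return 200 <= resp.status < 300
    except (urllib.error.URLError, OSError):
        # Connection refused, DNS failure, or timeout -> unhealthy.
        return False
```

Like the curl version, a falsy result maps to a failing healthcheck (exit 1).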
README.md ADDED
@@ -0,0 +1,121 @@
+ # Image to LaTeX API
+
+ A Flask API that converts mathematical formula images to LaTeX code using TextTeller.
+
+ ## Features
+
+ - **POST /itl**: Upload an image and get LaTeX code back
+ - **GET /health**: Health check endpoint
+ - **GET /**: API documentation
+ - Automatic image cleanup after processing
+ - Support for multiple image formats (PNG, JPG, JPEG, GIF, BMP, TIFF)
+ - Maximum file size: 16MB
+
+ ## Quick Start
+
+ ### Using Docker (Recommended)
+
+ 1. Build the Docker image:
+ ```bash
+ docker build -t image-to-latex-api .
+ ```
+
+ 2. Run the container:
+ ```bash
+ docker run -p 5000:5000 image-to-latex-api
+ ```
+
+ ### Local Development
+
+ 1. Install dependencies:
+ ```bash
+ pip install -r requirements.txt
+ ```
+
+ 2. Run the Flask app:
+ ```bash
+ python main.py
+ ```
+
+ ## API Usage
+
+ ### Convert Image to LaTeX
+
+ **Endpoint:** `POST /itl`
+
+ **Request:** Send image as multipart/form-data with key `"image"`
+
+ **Example using curl:**
+ ```bash
+ curl -X POST -F "image=@formula.png" http://localhost:5000/itl
+ ```
+
+ **Example using Python requests:**
+ ```python
+ import requests
+
+ with open('formula.png', 'rb') as f:
+     response = requests.post('http://localhost:5000/itl', files={'image': f})
+ print(response.json())
+ ```
+
+ **Response:**
+ ```json
+ {
+   "success": true,
+   "latex": "\\frac{x^2 + y^2}{2}"
+ }
+ ```
+
+ ### Health Check
+
+ **Endpoint:** `GET /health`
+
+ **Response:**
+ ```json
+ {
+   "status": "healthy",
+   "texteller_available": true
+ }
+ ```
+
+ ## Deployment
+
+ ### Hugging Face Spaces
+
+ This API can be easily deployed to Hugging Face Spaces:
+
+ 1. Create a new Space on Hugging Face
+ 2. Upload the files: `main.py`, `Dockerfile`, `requirements.txt`
+ 3. Set the Space to use Docker
+ 4. The API will be available at your Space URL
+
+ ### Other Platforms
+
+ The Dockerfile is compatible with most container platforms like:
+ - Google Cloud Run
+ - AWS ECS/Fargate
+ - Azure Container Instances
+ - Railway
+ - Render
+
+ ## Error Handling
+
+ The API returns appropriate HTTP status codes:
+ - `200`: Success
+ - `400`: Bad request (no image, invalid format)
+ - `408`: Request timeout (processing took too long)
+ - `413`: Payload too large (file > 16MB)
+ - `500`: Internal server error
+ - `503`: Service unavailable
+
+ ## Environment Variables
+
+ - `PORT`: Port to run the Flask app on (default: 5000)
+
+ ## Notes
+
+ - Images are automatically deleted after processing to save disk space
+ - Processing timeout is set to 30 seconds
+ - The API uses TextTeller's inference capabilities under the hood
+ - Temporary files are created in the system's temp directory
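The status codes documented in the README's Error Handling section can be mapped on the client side; the sketch below mirrors that table. The `describe_status` helper and its messages are illustrative, not part of the shipped API.

```python
# Client-side mapping of the API's documented HTTP status codes.
# The mapping mirrors the README's Error Handling table; this helper
# is an illustration, not code that ships with the service.
STATUS_MESSAGES = {
    200: "success",
    400: "bad request (no image or invalid format)",
    408: "request timeout (processing took too long)",
    413: "payload too large (file > 16MB)",
    500: "internal server error",
    503: "service unavailable",
}


def describe_status(code: int) -> str:
    """Return a human-readable description for a documented status code."""
    return STATUS_MESSAGES.get(code, f"undocumented status {code}")
```

A client can call this after `requests.post(...)` to turn `response.status_code` into a log-friendly message.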
main.py CHANGED
@@ -1,282 +1,127 @@
- # ----------------------------------------------------
- # Base image with CUDA + PyTorch
- # ----------------------------------------------------
- FROM nvidia/cuda:12.1.0-runtime-ubuntu22.04
-
- # ----------------------------------------------------
- # System dependencies
- # ----------------------------------------------------
- RUN apt-get update && apt-get install -y \
-     python3 python3-pip python3-dev \
-     git wget nano unzip findutils \
-     && rm -rf /var/lib/apt/lists/*
-
- # ----------------------------------------------------
- # Install Python dependencies
- # ----------------------------------------------------
- RUN pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cu121
- RUN pip3 install kagglehub pillow tqdm
-
- # ----------------------------------------------------
- # Set working directory
- # ----------------------------------------------------
- WORKDIR /workspace
-
- # ----------------------------------------------------
- # Write FULL main.py directly into the container
- # ----------------------------------------------------
- RUN cat << 'EOF' > /workspace/main.py
  import os
- import argparse
+ import tempfile
+ import subprocess
  from pathlib import Path
- import torch
- from torch import nn
- from torch.utils.data import Dataset, DataLoader
- from torchvision import transforms, models
- from torch.optim import AdamW
- from torch.cuda.amp import autocast, GradScaler
- import torch.nn.functional as F
- from PIL import Image
- from tqdm import tqdm
-
- # ============================================================
- # TOKENIZER
- # ============================================================
- class CharTokenizer:
-     def __init__(self):
-         self.special = ["<pad>", "<sos>", "<eos>", "<unk>"]
-         self.idx2tok = list(self.special)
-         self.tok2idx = {t: i for i, t in enumerate(self.idx2tok)}
-
-     def build(self, formulas):
-         chars = set()
-         for f in formulas:
-             chars.update(list(f))
-         for c in sorted(chars):
-             if c not in self.tok2idx:
-                 self.tok2idx[c] = len(self.idx2tok)
-                 self.idx2tok.append(c)
-
-     def encode(self, text):
-         return [self.tok2idx.get(t, self.tok2idx["<unk>"]) for t in ["<sos>"] + list(text) + ["<eos>"]]
-
-     @property
-     def pad(self):
-         return self.tok2idx["<pad>"]
-
-     def __len__(self):
-         return len(self.idx2tok)
-
-
- # ============================================================
- # DATASET
- # ============================================================
- class TexDataset(Dataset):
-     def __init__(self, filenames, formulas, root):
-         self.filenames = filenames
-         self.formulas = formulas
-         self.root = Path(root) / "generated_png_images"
-
-         self.transform = transforms.Compose([
-             transforms.Resize((256, 1024)),
-             transforms.ToTensor(),
-             transforms.Normalize([0.5], [0.5]),
-         ])
-
-     def __len__(self):
-         return len(self.filenames)
-
-     def __getitem__(self, idx):
-         img_path = self.root / self.filenames[idx]
-         image = Image.open(img_path).convert("RGB")
-         image = self.transform(image)
-         return image, self.formulas[idx], self.filenames[idx]
-
-
- def collate_fn(batch, tokenizer):
-     images, texts, names = zip(*batch)
-     images = torch.stack(images)
-
-     encoded = [torch.tensor(tokenizer.encode(t)) for t in texts]
-     max_len = max(len(e) for e in encoded)
-
-     padded = torch.full((len(encoded), max_len), tokenizer.pad, dtype=torch.long)
-     for i, e in enumerate(encoded):
-         padded[i, :len(e)] = e
-
-     return images, padded, names
-
-
- # ============================================================
- # MODEL
- # ============================================================
- class Img2Latex(nn.Module):
-     def __init__(self, vocab_size, d_model=512):
-         super().__init__()
-
-         resnet = models.resnet34(weights=models.ResNet34_Weights.DEFAULT)
-         self.encoder = nn.Sequential(*list(resnet.children())[:-2])
-
-         self.proj = nn.Conv2d(512, d_model, 1)
-         self.embed = nn.Embedding(vocab_size, d_model)
-         self.pos = nn.Parameter(torch.randn(1, 1024, d_model))
-
-         dec_layer = nn.TransformerDecoderLayer(d_model, 8)
-         self.decoder = nn.TransformerDecoder(dec_layer, num_layers=4)
-
-         self.fc = nn.Linear(d_model, vocab_size)
-
-     def forward(self, images, tgt):
-         feats = self.encoder(images)
-         feats = self.proj(feats)
-         feats = feats.mean(dim=2)
-         feats = feats.permute(2, 0, 1)
-
-         tgt = tgt.permute(1, 0)
-         emb = self.embed(tgt)
-         emb = emb + self.pos[:, :emb.size(0), :]
-
-         mask = nn.Transformer.generate_square_subsequent_mask(emb.size(0)).to(emb.device)
-
-         dec = self.decoder(emb, feats, tgt_mask=mask)
-         return self.fc(dec)
-
-
- # ============================================================
- # TRAINING LOOP WITH LOGGING
- # ============================================================
- def train_epoch(model, loader, optimizer, scaler, tokenizer, device, epoch):
-     model.train()
-     total_loss = 0
-     processed_images = 0
-
-     progress = tqdm(loader, desc=f"Epoch {epoch} Training", unit="batch")
-
-     for images, tgt, batch_filenames in progress:
-         images, tgt = images.to(device), tgt.to(device)
-
-         processed_images += len(batch_filenames)
-
-         print("\n🖼️ Processing batch images:")
-         for name in batch_filenames:
-             print(" -", name)
-         print(f"📊 Processed {processed_images} / {len(loader.dataset)} images\n")
-
-         optimizer.zero_grad()
-
-         with autocast():
-             logits = model(images, tgt)
-             logits = logits.permute(1, 0, 2)
-
-             loss = F.cross_entropy(
-                 logits.reshape(-1, logits.size(-1)),
-                 tgt.reshape(-1),
-                 ignore_index=tokenizer.pad
+ from flask import Flask, request, jsonify
+ import logging
+
+ # Configure logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ app = Flask(__name__)
+
+ # Configure maximum file size (16MB)
+ app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024
+
+ @app.route('/itl', methods=['POST'])
+ def image_to_latex():
+     """Convert uploaded image to LaTeX code using TextTeller."""
+     try:
+         # Check if image file is present
+         if 'image' not in request.files:
+             return jsonify({'error': 'No image file provided'}), 400
+
+         file = request.files['image']
+         if file.filename == '':
+             return jsonify({'error': 'No file selected'}), 400
+
+         # Validate file type
+         allowed_extensions = {'png', 'jpg', 'jpeg', 'gif', 'bmp', 'tiff'}
+         file_ext = file.filename.rsplit('.', 1)[-1].lower() if '.' in file.filename else ''
+         if file_ext not in allowed_extensions:
+             return jsonify({'error': f'Invalid file type. Allowed: {", ".join(allowed_extensions)}'}), 400
+
+         # Create temporary file to save uploaded image
+         with tempfile.NamedTemporaryFile(delete=False, suffix=f'.{file_ext}') as tmp_file:
+             file.save(tmp_file.name)
+             temp_image_path = tmp_file.name
+
+         try:
+             # Run texteller inference command
+             logger.info(f"Processing image: {temp_image_path}")
+             result = subprocess.run(
+                 ['texteller', 'inference', temp_image_path],
+                 capture_output=True,
+                 text=True,
+                 timeout=30  # 30 second timeout
  )
-
-         scaler.scale(loss).backward()
-         scaler.step(optimizer)
-         scaler.update()
-
-         total_loss += loss.item()
-         progress.set_postfix({"loss": loss.item()})
-
-     return total_loss / len(loader)
-
-
- # ============================================================
- # MAIN
- # ============================================================
- def main():
-     parser = argparse.ArgumentParser()
-     parser.add_argument("--data_dir", type=str, default="PRINTED_TEX_230k")
-     parser.add_argument("--batch", type=int, default=8)
-     parser.add_argument("--epochs", type=int, default=5)
-     args = parser.parse_args()
-
-     root = Path(args.data_dir)
-
-     image_files = open(root / "corresponding_png_images.txt").read().splitlines()
-     formulas = open(root / "final_png_formulas.txt").read().splitlines()
-
-     n = min(len(image_files), len(formulas))
-     image_files = image_files[:n]
-     formulas = formulas[:n]
-
-     print(f"Loaded {n} image-formula pairs")
-
-     tokenizer = CharTokenizer()
-     tokenizer.build(formulas)
-
-     dataset = TexDataset(image_files, formulas, root)
-     loader = DataLoader(
-         dataset,
-         batch_size=args.batch,
-         shuffle=True,
-         collate_fn=lambda b: collate_fn(b, tokenizer),
-         num_workers=2
-     )
-
-     device = "cuda" if torch.cuda.is_available() else "cpu"
-     print("Using device:", device)
-
-     model = Img2Latex(len(tokenizer)).to(device)
-     optimizer = AdamW(model.parameters(), lr=3e-4)
-     scaler = GradScaler()
-
-     for epoch in range(1, args.epochs + 1):
-         loss = train_epoch(model, loader, optimizer, scaler, tokenizer, device, epoch)
-         print(f"Epoch {epoch} complete — Loss: {loss:.4f}")
-
-         torch.save(model.state_dict(), f"model_epoch{epoch}.pth")
-
-     print("Training complete!")
-
-
- if __name__ == "__main__":
-     main()
- EOF
-
- # ----------------------------------------------------
- # ALWAYS DOWNLOAD DATASET FROM KAGGLEHUB AND BUILD STRUCTURE
- # ----------------------------------------------------
- RUN python3 - << 'EOF'
- import kagglehub, os, shutil
-
- print("\n⬇️ Downloading dataset from KaggleHub...\n")
- download_path = kagglehub.dataset_download("gregoryeritsyan/im2latex-230k")
- print("📥 Downloaded to:", download_path)
-
- # Create final dataset structure
- target = "/workspace/PRINTED_TEX_230k"
- if os.path.exists(target):
-     shutil.rmtree(target)
- os.makedirs(target + "/generated_png_images", exist_ok=True)
-
- # Mapping from KaggleHub structure to your structure
- mapping = {
-     "formulas.txt": "final_png_formulas.txt",
-     "formula_images.txt": "corresponding_png_images.txt",
-     "vocab.json": "230k.json",
- }
-
- # Move text files
- for src, dst in mapping.items():
-     src_path = os.path.join(download_path, src)
-     if os.path.exists(src_path):
-         shutil.move(src_path, os.path.join(target, dst))
-         print(f"✔ Mapped {src} → {dst}")
-
- # Move image directory
- images_src = os.path.join(download_path, "images")
- if os.path.exists(images_src):
-     shutil.move(images_src, os.path.join(target, "generated_png_images"))
-     print("✔ Mapped images/ → generated_png_images/")
-
- print("\n🎉 Dataset prepared at:", target)
- EOF
-
- # ----------------------------------------------------
- # Run training by default
- # ----------------------------------------------------
- CMD ["python3", "main.py", "--data_dir", "PRINTED_TEX_230k", "--epochs", "5"]
+
+             if result.returncode == 0:
+                 # Extract LaTeX from output
+                 latex_output = result.stdout.strip()
+                 logger.info(f"Successfully processed image. LaTeX length: {len(latex_output)}")
+                 return jsonify({
+                     'success': True,
+                     'latex': latex_output
+                 })
+             else:
+                 logger.error(f"TextTeller inference failed: {result.stderr}")
+                 return jsonify({
+                     'error': 'Failed to process image',
+                     'details': result.stderr
+                 }), 500
+
+         except subprocess.TimeoutExpired:
+             logger.error("TextTeller inference timed out")
+             return jsonify({'error': 'Processing timed out'}), 408
+
+         except Exception as e:
+             logger.error(f"Error during processing: {str(e)}")
+             return jsonify({'error': f'Processing error: {str(e)}'}), 500
+
+         finally:
+             # Clean up temporary file
+             try:
+                 os.unlink(temp_image_path)
+                 logger.info(f"Cleaned up temporary file: {temp_image_path}")
+             except OSError as e:
+                 logger.warning(f"Failed to delete temporary file {temp_image_path}: {e}")
+
+     except Exception as e:
+         logger.error(f"Unexpected error: {str(e)}")
+         return jsonify({'error': f'Unexpected error: {str(e)}'}), 500
+
+ @app.route('/health', methods=['GET'])
+ def health_check():
+     """Health check endpoint."""
+     try:
+         # Test if texteller is available
+         result = subprocess.run(['texteller', '--help'], capture_output=True, timeout=5)
+         texteller_available = result.returncode == 0
+
+         return jsonify({
+             'status': 'healthy',
+             'texteller_available': texteller_available
+         })
+     except Exception as e:
+         return jsonify({
+             'status': 'unhealthy',
+             'error': str(e)
+         }), 503
+
+ @app.route('/', methods=['GET'])
+ def root():
+     """Root endpoint with API documentation."""
+     return jsonify({
+         'service': 'Image to LaTeX API',
+         'version': '1.0.0',
+         'endpoints': {
+             'POST /itl': 'Convert image to LaTeX. Send image file as multipart/form-data with key "image"',
+             'GET /health': 'Health check endpoint',
+             'GET /': 'This documentation'
+         },
+         'supported_formats': ['png', 'jpg', 'jpeg', 'gif', 'bmp', 'tiff'],
+         'max_file_size': '16MB'
+     })
+
+ if __name__ == '__main__':
+     # Check if texteller is installed
+     try:
+         result = subprocess.run(['texteller', '--help'], capture_output=True)
+         if result.returncode != 0:
+             logger.warning("TextTeller might not be properly installed")
+     except FileNotFoundError:
+         logger.error("TextTeller is not installed. Please install it with: pip install texteller")
+
+     app.run(host='0.0.0.0', port=int(os.environ.get('PORT', 5000)), debug=False)
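The extension check in the new `/itl` handler can be factored into a small helper; this sketch reproduces the same `rsplit`-based logic (the `allowed_file` name is illustrative, since main.py does the check inline):

```python
# Same validation logic as the /itl handler: take the text after the last
# dot, lowercase it, and compare it against the whitelist. Note that a
# filename with no dot yields '' and is rejected.
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif', 'bmp', 'tiff'}


def allowed_file(filename: str) -> bool:
    """Return True if the filename carries a supported image extension."""
    ext = filename.rsplit('.', 1)[-1].lower() if '.' in filename else ''
    return ext in ALLOWED_EXTENSIONS
```

Using `rsplit('.', 1)` means only the final suffix matters, so `archive.tar.png` passes while `formula.svg` is rejected.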
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ flask>=2.3.0
+ texteller
+ pillow>=10.0.0
+ gunicorn>=21.2.0
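requirements.txt pins gunicorn, yet the Dockerfile's CMD runs the Flask development server. If gunicorn is intended for production, a sketch of an alternative Dockerfile entrypoint might look like the following (worker count and timeout are assumptions; `main:app` matches the `app = Flask(__name__)` object in main.py):

```
# Hypothetical production CMD: serve the Flask app with gunicorn instead of
# the built-in dev server. --timeout is set above the 30s inference timeout.
CMD ["gunicorn", "--bind", "0.0.0.0:5000", "--workers", "2", "--timeout", "60", "main:app"]
```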