Varshithdharmajv commited on
Commit
dbd32c5
·
verified ·
1 Parent(s): 99f7550

Upload handwriting_transcriber.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. handwriting_transcriber.py +155 -0
handwriting_transcriber.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""
Handwriting Transcriber Module
Wrapper for handwritten-math-transcription repository
"""

import sys
import os
import torch
from typing import Optional, Tuple

# Add handwritten-math-transcription to path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'handwritten-math-transcription'))

try:
    from model import Encoder, Decoder, Seq2Seq
    from dataset.hme_ink import read_inkml_file
    from utils import tokenize_latex
    from corrector import correct_latex
    # NOTE(review): star import is relied on for LATEX_VOCAB_SIZE below in
    # HandwritingTranscriber.load_model — confirm config defines it.
    from config import *
except ImportError:
    # Optional dependency: when the sibling repository is missing, fall back
    # to None sentinels so this module still imports cleanly. Methods that
    # need the model check for None and raise ImportError at call time.
    Encoder = None
    Decoder = None
    Seq2Seq = None
    read_inkml_file = None
    tokenize_latex = None
    correct_latex = None
28
+
29
+ class HandwritingTranscriber:
30
+ """
31
+ Handwriting transcriber for mathematical expressions.
32
+ Converts handwritten math (InkML format) to LaTeX.
33
+ """
34
+
35
+ def __init__(self,
36
+ model_path: str = None,
37
+ device: str = None,
38
+ use_corrector: bool = True):
39
+ """
40
+ Initialize handwriting transcriber.
41
+
42
+ Args:
43
+ model_path: Path to trained model checkpoint
44
+ device: Device to run model on ('cpu', 'cuda', 'mps')
45
+ use_corrector: Whether to use LLM corrector for post-processing
46
+ """
47
+ self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
48
+ self.use_corrector = use_corrector
49
+ self.model = None
50
+ self.model_path = model_path
51
+
52
+ if model_path and os.path.exists(model_path):
53
+ self.load_model(model_path)
54
+
55
+ def load_model(self, model_path: str):
56
+ """
57
+ Load trained model from checkpoint.
58
+
59
+ Args:
60
+ model_path: Path to model checkpoint
61
+ """
62
+ if Seq2Seq is None:
63
+ raise ImportError("Handwriting transcription model not available")
64
+
65
+ try:
66
+ # Model architecture parameters (from config or defaults)
67
+ input_dim = 11
68
+ enc_hidden_dim = 256
69
+ dec_hidden_dim = 256
70
+ embed_dim = 128
71
+ output_dim = LATEX_VOCAB_SIZE if 'LATEX_VOCAB_SIZE' in globals() else 300
72
+ encoder_num_layers = 2
73
+ decoder_num_layers = 2
74
+
75
+ # Create model
76
+ encoder = Encoder(input_dim, enc_hidden_dim,
77
+ num_layers=encoder_num_layers, bidirectional=True)
78
+ decoder = Decoder(output_dim, embed_dim,
79
+ enc_hidden_dim, dec_hidden_dim,
80
+ num_layers=decoder_num_layers)
81
+ self.model = Seq2Seq(encoder, decoder, self.device).to(self.device)
82
+
83
+ # Load weights
84
+ checkpoint = torch.load(model_path, map_location=self.device)
85
+ self.model.load_state_dict(checkpoint)
86
+ self.model.eval()
87
+
88
+ print(f"Model loaded from {model_path}")
89
+ except Exception as e:
90
+ print(f"Error loading model: {e}")
91
+ self.model = None
92
+
93
+ def transcribe_inkml(self, inkml_path: str) -> Tuple[str, str]:
94
+ """
95
+ Transcribe an InkML file to LaTeX.
96
+
97
+ Args:
98
+ inkml_path: Path to InkML file
99
+
100
+ Returns:
101
+ Tuple of (predicted_latex, ground_truth_latex if available)
102
+ """
103
+ if self.model is None:
104
+ raise ValueError("Model not loaded. Please load a model first.")
105
+
106
+ if read_inkml_file is None:
107
+ raise ImportError("InkML reading functionality not available")
108
+
109
+ try:
110
+ # Read InkML file
111
+ strokes, ground_truth = read_inkml_file(inkml_path)
112
+
113
+ # Convert to model input format
114
+ # This is a simplified version - actual implementation would need
115
+ # proper feature extraction and tensor conversion
116
+ # For now, return placeholder
117
+ predicted_latex = "\\placeholder"
118
+
119
+ # Apply corrector if enabled
120
+ if self.use_corrector and correct_latex:
121
+ try:
122
+ predicted_latex = correct_latex(predicted_latex)
123
+ except Exception as e:
124
+ print(f"Corrector error: {e}")
125
+
126
+ return predicted_latex, ground_truth
127
+ except Exception as e:
128
+ print(f"Error transcribing InkML: {e}")
129
+ return "", ""
130
+
131
+ def transcribe_image(self, image_path: str) -> str:
132
+ """
133
+ Transcribe a handwritten math image to LaTeX.
134
+ Note: This is a placeholder - actual implementation would require
135
+ image preprocessing and conversion to InkML or direct image processing.
136
+
137
+ Args:
138
+ image_path: Path to image file
139
+
140
+ Returns:
141
+ Predicted LaTeX string
142
+ """
143
+ # This would require additional image processing
144
+ # For now, return placeholder
145
+ return "\\placeholder"
146
+
147
+ def is_model_loaded(self) -> bool:
148
+ """
149
+ Check if model is loaded.
150
+
151
+ Returns:
152
+ True if model is loaded, False otherwise
153
+ """
154
+ return self.model is not None
155
+