File size: 8,915 Bytes
b7393db
 
 
 
 
 
 
 
 
 
61fb24c
34059a0
 
 
61fb24c
34059a0
c8b2ed7
61fb24c
34059a0
 
 
 
 
61fb24c
 
 
 
c8b2ed7
61fb24c
 
 
 
 
 
 
c8b2ed7
b7393db
c8b2ed7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34059a0
b7393db
34059a0
61fb24c
 
34059a0
61fb24c
 
34059a0
61fb24c
 
34059a0
 
61fb24c
34059a0
 
61fb24c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34059a0
 
 
 
b7393db
61fb24c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b7393db
c8b2ed7
 
61fb24c
c8b2ed7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61fb24c
 
 
 
 
 
 
c8b2ed7
b7393db
34059a0
 
 
 
61fb24c
 
 
34059a0
 
 
 
 
 
 
 
b7393db
34059a0
 
 
 
61fb24c
34059a0
 
 
61fb24c
 
 
34059a0
 
 
 
 
 
 
b7393db
34059a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b7393db
34059a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c8b2ed7
61fb24c
34059a0
c8b2ed7
34059a0
 
 
 
 
 
 
 
 
 
 
 
c8b2ed7
34059a0
 
 
 
 
 
 
 
 
 
 
 
b7393db
34059a0
61fb24c
 
 
34059a0
 
b7393db
34059a0
 
 
 
b7393db
34059a0
 
 
 
b7393db
34059a0
 
 
61fb24c
34059a0
 
 
c8b2ed7
34059a0
b7393db
34059a0
 
 
 
 
 
 
 
 
61fb24c
 
b7393db
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
# --- CLEANUP CODE ADDED HERE ---
# Best-effort cache cleanup. It runs at import time, i.e. once every time the
# process (the Space) starts up.
import os
import shutil
import subprocess


def _run_startup_cleanup(timeout_seconds=60):
    """Clear the Hugging Face and pip caches to free disk space.

    The two steps are independent and best-effort: a failure is reported on
    stdout but never prevents the application from starting.

    Args:
        timeout_seconds: Upper bound on how long the Hugging Face CLI may run.

    Returns:
        True if the Hugging Face CLI ran and exited with status 0, else False.
    """
    succeeded = True

    # Use an argument list instead of a shell string, and detach stdin so the
    # command cannot sit waiting for input when no terminal is attached.
    # NOTE(review): `huggingface-cli delete-cache` appears to be an
    # interactive selector -- confirm it deletes anything when run unattended,
    # otherwise replace this step.
    try:
        completed = subprocess.run(
            ["huggingface-cli", "delete-cache"],
            stdin=subprocess.DEVNULL,
            timeout=timeout_seconds,
            check=False,
        )
    except (OSError, subprocess.TimeoutExpired) as exc:
        print(f"⚠ Hugging Face cache cleanup skipped: {exc}")
        succeeded = False
    else:
        if completed.returncode != 0:
            print(f"⚠ Hugging Face cache cleanup exited with code {completed.returncode}")
            succeeded = False

    # Equivalent of `rm -rf ~/.cache/pip`. Unlike the former `&&` chain this
    # still runs when the Hugging Face step fails.
    pip_cache = os.path.join(os.path.expanduser("~"), ".cache", "pip")
    shutil.rmtree(pip_cache, ignore_errors=True)

    return succeeded


print("🧹 Running cleanup command...")
if _run_startup_cleanup():
    print("✅ Cleanup complete.")
else:
    print("⚠ Cleanup finished with warnings.")
# -----------------------------


"""Marathi Translation API with AI4Bharat IndicTrans2
High-quality translation service for English <-> Marathi
"""
from flask import Flask, request, jsonify
from flask_cors import CORS
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from pathlib import Path
import json
import torch

# Flask application object; every route and error handler below registers on it.
app = Flask(__name__)
CORS(app)  # Enable CORS for all routes

# Module-level state shared by the helpers and request handlers below.
# `model` and `tokenizer` stay None until load_model() populates them.
model = None
tokenizer = None
# Directory load_model() reads the model and tokenizer from (relative path).
MODEL_DIR = Path("models/indictrans2-en-mr")
# Cache filled by load_translations_dict() on its first call.
translations_dict = None
# Chosen once at import time: GPU when torch reports CUDA, otherwise CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Maps the API's short language codes to the language tags that
# translate_with_indictrans2() prepends to the model input.
LANG_CODE_MAP = {
    "en": "eng_Latn",
    "mr": "mar_Deva"
}


def load_translations_dict():
    """Load and cache the custom translations dictionary.

    On the first call this reads ``translations_dict.json`` from the current
    working directory and stores the result in the module-level
    ``translations_dict``; later calls return that cached object.

    The dictionary is optional. If the file is missing, unreadable, malformed
    or not a JSON object, a warning is printed and empty lookup tables are
    used so that translation falls back to the model.

    Returns:
        dict: Always contains the keys ``"en_to_mr"`` and ``"mr_to_en"``,
        each mapping to a dict (possibly empty).
    """
    global translations_dict

    if translations_dict is not None:
        return translations_dict

    dict_path = Path("translations_dict.json")
    loaded = None
    if dict_path.exists():
        try:
            with open(dict_path, 'r', encoding='utf-8') as f:
                loaded = json.load(f)
        except (OSError, ValueError) as exc:
            # ValueError covers both JSON syntax errors and bad UTF-8.
            print(f"⚠ Could not read translations dictionary ({exc}), using model only")
        else:
            if isinstance(loaded, dict):
                print("✓ Translations dictionary loaded")
            else:
                print("⚠ Translations dictionary is not a JSON object, using model only")
                loaded = None
    else:
        print("⚠ Translations dictionary not found, using model only")

    if loaded is None:
        loaded = {}

    # Guarantee both direction tables exist so lookups never hit a KeyError
    # when the file only defines one direction.
    for direction in ("en_to_mr", "mr_to_en"):
        if not isinstance(loaded.get(direction), dict):
            loaded[direction] = {}

    translations_dict = loaded
    return translations_dict


def load_model():
    """Return the ``(model, tokenizer)`` pair, loading it on first use.

    When both module-level globals are already populated they are returned
    unchanged. Otherwise the tokenizer and the seq2seq model are read from
    ``MODEL_DIR`` (local files only), the model is moved to ``device`` and
    switched to evaluation mode, and both are stored in the globals.

    Returns:
        tuple: ``(model, tokenizer)``.

    Raises:
        FileNotFoundError: If ``MODEL_DIR`` does not exist.
        Exception: Any error raised while loading is printed and re-raised.
    """
    global model, tokenizer

    already_loaded = model is not None and tokenizer is not None
    if already_loaded:
        return model, tokenizer

    print("Loading IndicTrans2 model...")
    print(f"Device: {device}")

    if not MODEL_DIR.exists():
        raise FileNotFoundError(f"Model not found at {MODEL_DIR}. Run: python scripts/download_indictrans2.py")

    model_path = str(MODEL_DIR)
    # The tokenizer and the model are loaded with the same options.
    pretrained_options = {"local_files_only": True, "trust_remote_code": True}

    try:
        tokenizer = AutoTokenizer.from_pretrained(model_path, **pretrained_options)
        model = AutoModelForSeq2SeqLM.from_pretrained(model_path, **pretrained_options).to(device)

        # Evaluation mode
        model.eval()

        print("✓ IndicTrans2 model loaded successfully")
        return model, tokenizer
    except Exception as load_error:
        print(f"✗ Failed to load model: {load_error}")
        raise

def translate_with_indictrans2(text, source_lang, target_lang):
    """Translate ``text`` with the loaded IndicTrans2 model.

    The model input is built as ``"<src_tag> <tgt_tag> <text>"``, where the
    tags come from ``LANG_CODE_MAP`` (codes not in the map are passed through
    as given). Input and output are both capped at 256 tokens.

    Args:
        text: Text to translate.
        source_lang: Source language code, e.g. ``"en"``.
        target_lang: Target language code, e.g. ``"mr"``.

    Returns:
        str: The decoded translation with surrounding whitespace removed.
    """
    # Lazy load
    if model is None or tokenizer is None:
        load_model()

    src_tag = LANG_CODE_MAP.get(source_lang, source_lang)
    tgt_tag = LANG_CODE_MAP.get(target_lang, target_lang)
    tagged_text = f"{src_tag} {tgt_tag} {text}"

    encoded = tokenizer(
        tagged_text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=256
    )
    encoded = encoded.to(device)

    generation_options = {
        "max_length": 256,
        "num_beams": 4,
        "num_return_sequences": 1,
        "early_stopping": True,
    }
    # Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        output_ids = model.generate(**encoded, **generation_options)

    decoded = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    return decoded[0].strip()


def translate_with_dict(text, source, target):
    """Translate ``text``, preferring the custom dictionary over the model.

    English lookups use the lower-cased, whitespace-stripped text. Marathi
    lookups try the text exactly as given and then with surrounding
    whitespace stripped, so stray padding no longer bypasses the dictionary.
    If no entry matches, the IndicTrans2 model is used.

    Args:
        text: Text to translate.
        source: Source language code (``"en"`` or ``"mr"``).
        target: Target language code (``"mr"`` or ``"en"``).

    Returns:
        tuple: ``(translation, used_dict)``. ``used_dict`` is True when the
        dictionary supplied the result. If the model raises, the error is
        printed and the original ``text`` is returned with ``used_dict``
        False (best-effort behaviour).
    """
    dict_data = load_translations_dict()

    # Pick the lookup table and the candidate keys for this direction.
    # `.get` keeps a dictionary that lacks one direction from raising KeyError.
    if source == "en" and target == "mr":
        table = dict_data.get("en_to_mr", {})
        candidates = (text.lower().strip(),)
    elif source == "mr" and target == "en":
        table = dict_data.get("mr_to_en", {})
        candidates = (text, text.strip())
    else:
        table = {}
        candidates = ()

    for key in candidates:
        if key in table:
            return table[key], True

    # Fallback to IndicTrans2 model
    try:
        translated = translate_with_indictrans2(text, source, target)
        return translated, False
    except Exception as e:
        print(f"Translation error: {e}")
        return text, False


@app.route('/')
def home():
    """Describe the service: name, version, model, languages and endpoints."""
    endpoints = {
        "translate": "/translate (POST)",
        "health": "/health (GET)",
        "languages": "/languages (GET)"
    }
    service_info = {
        "service": "Marathi Translation API (IndicTrans2)",
        "version": "2.0.0",
        "model": "AI4Bharat IndicTrans2",
        "languages": ["en", "mr"],
        "endpoints": endpoints
    }
    return jsonify(service_info)


@app.route('/health')
def health():
    """Report service health, loading the model first if it is not loaded.

    Returns a 200 JSON body describing the model and device on success, or a
    503 JSON body carrying the error message if loading the model fails.
    """
    try:
        ready = model is not None and tokenizer is not None
        if not ready:
            load_model()

        # Read the global again here so the value reflects a load done above.
        status = {
            "status": "healthy",
            "model": "IndicTrans2",
            "model_loaded": model is not None,
            "device": device
        }
        return jsonify(status), 200
    except Exception as failure:
        problem = {
            "status": "unhealthy",
            "error": str(failure)
        }
        return jsonify(problem), 503


@app.route('/languages', methods=['GET'])
def languages():
    """List each supported language with the targets it can translate to."""
    # (code, display name, single supported target)
    supported = (
        ("en", "English", "mr"),
        ("mr", "Marathi", "en"),
    )
    return jsonify([
        {"code": code, "name": name, "targets": [target_code]}
        for code, name, target_code in supported
    ])


@app.route('/translate', methods=['POST'])
def translate():
    """
    Translate text between English and Marathi

    Request body:
    {
        "q": "text to translate" or ["text1", "text2"],
        "source": "en" or "mr",
        "target": "mr" or "en"
    }

    Responses:
        200: {"translatedText": <str or list>, "detectedLanguage": {...}}.
             "translatedText" is a list when "q" was a list. Empty or
             whitespace-only items translate to "".
        400: Missing, malformed or wrongly-typed input (including a body
             that is not a JSON object).
        500: Unexpected failure while loading the model or translating.
    """
    try:
        # silent=True makes a malformed body or a non-JSON content type yield
        # None, which is answered with 400 below. Without it the HTTP error
        # raised by Flask is swallowed by the generic handler at the bottom
        # and reported as a 500.
        data = request.get_json(silent=True)

        if not data:
            return jsonify({"error": "No JSON data provided"}), 400

        if not isinstance(data, dict):
            return jsonify({"error": "Request body must be a JSON object"}), 400

        # Extract parameters
        text = data.get('q')
        source = data.get('source', '')
        target = data.get('target', '')

        # Validate parameters
        if not text:
            return jsonify({"error": "Missing 'q' parameter"}), 400

        if not source or not target:
            return jsonify({"error": "Missing 'source' or 'target' parameter"}), 400

        if not isinstance(source, str) or not isinstance(target, str):
            return jsonify({"error": "'source' and 'target' must be strings"}), 400

        source = source.lower()
        target = target.lower()

        if source not in ['en', 'mr'] or target not in ['en', 'mr']:
            return jsonify({"error": "Supported languages: 'en', 'mr'"}), 400

        if source == target:
            return jsonify({"error": "Source and target languages must be different"}), 400

        # Handle batch translation
        is_batch = isinstance(text, list)
        texts = text if is_batch else [text]

        # Falsy items (None, "") are tolerated and translate to "". Any other
        # non-string would fail on .strip(), so reject it as a client error.
        if any(item and not isinstance(item, str) for item in texts):
            return jsonify({"error": "'q' must be a string or a list of strings"}), 400

        # Load model and dictionary if not loaded
        if model is None or tokenizer is None:
            load_model()
        load_translations_dict()

        # Translate; index [0] drops the "came from dictionary" flag.
        translations = [
            translate_with_dict(item, source, target)[0]
            if item and item.strip() else ""
            for item in texts
        ]

        # Return response
        response = {
            "translatedText": translations if is_batch else translations[0],
            "detectedLanguage": {
                "confidence": 100,
                "language": source
            }
        }

        return jsonify(response), 200

    except Exception as e:
        print(f"Error in translate endpoint: {e}")
        import traceback
        traceback.print_exc()
        return jsonify({"error": str(e)}), 500


@app.errorhandler(404)
def not_found(e):
    """Answer requests for unknown routes with a JSON 404 body."""
    body = jsonify({"error": "Endpoint not found"})
    return body, 404


@app.errorhandler(500)
def server_error(e):
    """Answer unhandled server errors with a generic JSON 500 body."""
    body = jsonify({"error": "Internal server error"})
    return body, 500


if __name__ == '__main__':
    # Script entry point: pre-load the dictionary and the model so the first
    # request does not pay the loading cost, then start the Flask server.
    banner = "=" * 60
    print(banner)
    print("Marathi Translation API (IndicTrans2)")
    print(banner)

    try:
        load_translations_dict()
        load_model()

        print("\nStarting server...")

        # Get port from environment or use default (7860 for HF Spaces)
        port = int(os.environ.get('PORT', 7860))
        host = os.environ.get('HOST', '0.0.0.0')

        app.run(host=host, port=port, debug=False)
    except Exception as e:
        print(f"\n✗ Failed to start server: {e}")
        import traceback
        traceback.print_exc()
        # `exit` is a helper injected by the `site` module for interactive
        # use and can be absent (e.g. `python -S`). Raising SystemExit gives
        # the same exit status without relying on it.
        raise SystemExit(1)