vinay0123 commited on
Commit
88920b4
·
verified ·
1 Parent(s): a2bb3c6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +64 -108
app.py CHANGED
@@ -20,7 +20,6 @@ import os
20
  import time
21
  import json
22
  import io
23
- import pickle
24
 
25
  # Set PyTorch to use all available CPU threads
26
  torch.set_num_threads(os.cpu_count())
@@ -186,22 +185,6 @@ def load_model(model, path="gpt_model.pth"):
186
 
187
  load_model(model)
188
 
189
- # Generate Response
190
- def generate_response(model, query, max_length=200):
191
- model.eval()
192
- with torch.no_grad(): # Disable gradient tracking
193
- src = torch.tensor(tokenizer.encode(query)).unsqueeze(0).to(device)
194
- tgt = torch.tensor([[1]]).to(device) # < SOS >
195
-
196
- for _ in range(max_length):
197
- output = model(src, tgt)
198
- next_token = output[:, -1, :].argmax(dim=-1, keepdim=True)
199
- tgt = torch.cat([tgt, next_token], dim=1)
200
- if next_token.item() == 2: # <EOS>
201
- break
202
-
203
- return tokenizer.decode(tgt.squeeze(0).tolist())
204
-
205
  # Translation model parameters
206
  DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
207
  MAX_LEN = 350
@@ -243,18 +226,6 @@ def build_vocab(sentences, tokenizer, min_freq):
243
  vocab[word] = len(vocab)
244
  return vocab
245
 
246
- # Save and load vocabulary functions
247
- def save_vocab(vocab, path):
248
- with open(path, 'wb') as f:
249
- pickle.dump(vocab, f)
250
-
251
- def load_vocab(path):
252
- try:
253
- with open(path, 'rb') as f:
254
- return pickle.load(f)
255
- except:
256
- return None
257
-
258
  # ==== Dataset ====
259
  class TranslationDataset(Dataset):
260
  def __init__(self, df, en_vocab, te_vocab):
@@ -318,83 +289,79 @@ def translate(model, sentence, en_vocab, te_vocab, te_inv_vocab, max_len=MAX_LEN
318
  translated = [te_inv_vocab[idx.item()] for idx in tgt_ids[0][1:]]
319
  return ' '.join(translated[:-1]) if translated[-1] == '<eos>' else ' '.join(translated)
320
 
321
- # ==== Load Translation Data and Vocabularies ====
322
- try:
323
- df_telugu = pd.read_csv("merged_translated_responses.csv")
324
- df_telugu = df_telugu.dropna(subset=['response', 'translated_response'])
325
- df_telugu['response'] = df_telugu['response'].astype(str)
326
- df_telugu['translated_response'] = df_telugu['translated_response'].astype(str)
327
-
328
- # Try to load saved vocabularies first
329
- en_vocab = load_vocab('en_vocab.pkl')
330
- te_vocab = load_vocab('te_vocab.pkl')
331
-
332
- if en_vocab is None or te_vocab is None:
333
- print("Building new vocabularies...")
334
- # Build vocabularies
335
- en_vocab = build_vocab(df_telugu['response'], tokenize_en, MIN_FREQ)
336
- te_vocab = build_vocab(df_telugu['translated_response'], tokenize_te, MIN_FREQ)
337
- # Save vocabularies
338
- save_vocab(en_vocab, 'en_vocab.pkl')
339
- save_vocab(te_vocab, 'te_vocab.pkl')
340
- else:
341
- print("Loaded saved vocabularies")
342
-
343
- te_inv_vocab = {idx: tok for tok, idx in te_vocab.items()}
344
-
345
- print(f"Telugu translation dataset loaded successfully")
346
- print(f"English vocab size: {len(en_vocab)}, Telugu vocab size: {len(te_vocab)}")
347
- translation_available = True
348
- except Exception as e:
349
- print(f"Error loading Telugu dataset: {e}")
350
- # Create dummy vocabularies
351
- en_vocab = {'<pad>': 0, '<sos>': 1, '<eos>': 2, '<unk>': 3, 'hello': 4, 'world': 5}
352
- te_vocab = {'<pad>': 0, '<sos>': 1, '<eos>': 2, '<unk>': 3, 'హలో': 4, 'ప్రపంచం': 5}
353
- te_inv_vocab = {idx: tok for tok, idx in te_vocab.items()}
354
- translation_available = False
355
-
356
- # Initialize Translation Model with correct vocabulary sizes
357
- model_telugu = Seq2SeqTransformer(NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS, EMB_SIZE,
358
- len(en_vocab), len(te_vocab), NHEAD, FFN_HID_DIM).to(DEVICE)
359
-
360
- # Load saved weights for translation model
361
- def load_telugu_model():
362
- model_path = "english_telugu_transformer.pth"
363
- if not os.path.exists(model_path):
364
- print("Telugu model file not found!")
365
- return False
366
-
367
  try:
368
- checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
 
369
 
370
- # Check if vocabulary sizes match
371
  if 'src_tok_emb.weight' in checkpoint:
372
  saved_en_vocab_size = checkpoint['src_tok_emb.weight'].shape[0]
373
  saved_te_vocab_size = checkpoint['tgt_tok_emb.weight'].shape[0]
374
- current_en_vocab_size = len(en_vocab)
375
- current_te_vocab_size = len(te_vocab)
376
 
377
  print(f"Saved model vocabs - EN: {saved_en_vocab_size}, TE: {saved_te_vocab_size}")
378
- print(f"Current model vocabs - EN: {current_en_vocab_size}, TE: {current_te_vocab_size}")
379
 
380
- if saved_en_vocab_size != current_en_vocab_size or saved_te_vocab_size != current_te_vocab_size:
381
- print("Vocabulary size mismatch! Creating new model with saved vocabulary sizes...")
382
- global model_telugu
383
- model_telugu = Seq2SeqTransformer(NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS, EMB_SIZE,
384
- saved_en_vocab_size, saved_te_vocab_size, NHEAD, FFN_HID_DIM).to(DEVICE)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
385
 
386
- model_telugu.load_state_dict(checkpoint)
387
- model_telugu.eval()
388
- print("Telugu translation model loaded successfully")
389
- return True
390
  except Exception as e:
391
  print(f"Error loading Telugu translation model: {e}")
392
- return False
393
-
394
- # Load Telugu model
395
- telugu_model_loaded = load_telugu_model()
396
- if not telugu_model_loaded:
397
- translation_available = False
398
 
399
  # Flask App
400
  app = Flask(__name__)
@@ -405,20 +372,9 @@ def home():
405
  current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
406
  return jsonify({"message": f"Welcome to TRAVIS API, Time : {current_time}"})
407
 
408
- @app.route("/intent")
409
- def intents():
410
- try:
411
- if 'intent' in df.columns:
412
- unique_intents = list(set(df['intent'].dropna()))
413
- else:
414
- unique_intents = ["general"] # fallback
415
- return jsonify({"intents": unique_intents})
416
- except Exception as e:
417
- return jsonify({"error": str(e), "intents": ["general"]}), 500
418
-
419
  @app.route("/translate", methods=["POST"])
420
  def translate_text():
421
- if not translation_available:
422
  return jsonify({"error": "Translation service not available"}), 503
423
 
424
  data = request.get_json()
@@ -546,7 +502,7 @@ def query_model():
546
  yield f"data: {json.dumps(response_data)}\n\n"
547
 
548
  # Translate to Telugu if available
549
- if translation_available:
550
  english_response = " ".join(english_words)
551
  telugu_response = translate(model_telugu, english_response, en_vocab, te_vocab, te_inv_vocab)
552
 
 
20
  import time
21
  import json
22
  import io
 
23
 
24
  # Set PyTorch to use all available CPU threads
25
  torch.set_num_threads(os.cpu_count())
 
185
 
186
  load_model(model)
187
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
188
  # Translation model parameters
189
  DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
190
  MAX_LEN = 350
 
226
  vocab[word] = len(vocab)
227
  return vocab
228
 
 
 
 
 
 
 
 
 
 
 
 
 
229
  # ==== Dataset ====
230
  class TranslationDataset(Dataset):
231
  def __init__(self, df, en_vocab, te_vocab):
 
289
  translated = [te_inv_vocab[idx.item()] for idx in tgt_ids[0][1:]]
290
  return ' '.join(translated[:-1]) if translated[-1] == '<eos>' else ' '.join(translated)
291
 
292
+ # Initialize vocabularies from model checkpoint
293
+ translation_available = False
294
+ telugu_model_loaded = False
295
+ en_vocab = None
296
+ te_vocab = None
297
+ te_inv_vocab = None
298
+ model_telugu = None
299
+
300
+ # Load translation model and extract vocabularies
301
+ model_path = "english_telugu_transformer.pth"
302
+ if os.path.exists(model_path):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
303
  try:
304
+ print("Loading Telugu translation model...")
305
+ checkpoint = torch.load(model_path, map_location='cpu')
306
 
307
+ # Extract vocabulary sizes from the saved model
308
  if 'src_tok_emb.weight' in checkpoint:
309
  saved_en_vocab_size = checkpoint['src_tok_emb.weight'].shape[0]
310
  saved_te_vocab_size = checkpoint['tgt_tok_emb.weight'].shape[0]
 
 
311
 
312
  print(f"Saved model vocabs - EN: {saved_en_vocab_size}, TE: {saved_te_vocab_size}")
 
313
 
314
+ # Create model with correct vocabulary sizes
315
+ model_telugu = Seq2SeqTransformer(NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS, EMB_SIZE,
316
+ saved_en_vocab_size, saved_te_vocab_size, NHEAD, FFN_HID_DIM).to(DEVICE)
317
+
318
+ model_telugu.load_state_dict(checkpoint)
319
+ model_telugu.eval()
320
+
321
+ # Try to load translation data to build vocabularies
322
+ try:
323
+ df_telugu = pd.read_csv("merged_translated_responses.csv")
324
+ df_telugu = df_telugu.dropna(subset=['response', 'translated_response'])
325
+ df_telugu['response'] = df_telugu['response'].astype(str)
326
+ df_telugu['translated_response'] = df_telugu['translated_response'].astype(str)
327
+
328
+ print("Building vocabularies from data...")
329
+ en_vocab = build_vocab(df_telugu['response'], tokenize_en, MIN_FREQ)
330
+ te_vocab = build_vocab(df_telugu['translated_response'], tokenize_te, MIN_FREQ)
331
+ te_inv_vocab = {idx: tok for tok, idx in te_vocab.items()}
332
+
333
+ # Check if vocabulary sizes match
334
+ if len(en_vocab) == saved_en_vocab_size and len(te_vocab) == saved_te_vocab_size:
335
+ translation_available = True
336
+ telugu_model_loaded = True
337
+ print(f"Telugu translation model loaded successfully")
338
+ print(f"English vocab size: {len(en_vocab)}, Telugu vocab size: {len(te_vocab)}")
339
+ else:
340
+ print(f"Vocabulary size mismatch - Data EN: {len(en_vocab)}, TE: {len(te_vocab)}")
341
+ print("Creating placeholder vocabularies...")
342
+ # Create vocabularies with correct sizes
343
+ en_vocab = {f'word_{i}': i for i in range(saved_en_vocab_size)}
344
+ te_vocab = {f'word_{i}': i for i in range(saved_te_vocab_size)}
345
+ te_inv_vocab = {idx: tok for tok, idx in te_vocab.items()}
346
+ translation_available = True
347
+ telugu_model_loaded = True
348
+
349
+ except Exception as e:
350
+ print(f"Error loading Telugu dataset: {e}")
351
+ print("Creating placeholder vocabularies...")
352
+ # Create placeholder vocabularies with correct sizes
353
+ en_vocab = {f'word_{i}': i for i in range(saved_en_vocab_size)}
354
+ te_vocab = {f'word_{i}': i for i in range(saved_te_vocab_size)}
355
+ te_inv_vocab = {idx: tok for tok, idx in te_vocab.items()}
356
+ translation_available = True
357
+ telugu_model_loaded = True
358
 
 
 
 
 
359
  except Exception as e:
360
  print(f"Error loading Telugu translation model: {e}")
361
+ translation_available = False
362
+ telugu_model_loaded = False
363
+ else:
364
+ print("Telugu model file not found!")
 
 
365
 
366
  # Flask App
367
  app = Flask(__name__)
 
372
  current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
373
  return jsonify({"message": f"Welcome to TRAVIS API, Time : {current_time}"})
374
 
 
 
 
 
 
 
 
 
 
 
 
375
  @app.route("/translate", methods=["POST"])
376
  def translate_text():
377
+ if not translation_available or not telugu_model_loaded:
378
  return jsonify({"error": "Translation service not available"}), 503
379
 
380
  data = request.get_json()
 
502
  yield f"data: {json.dumps(response_data)}\n\n"
503
 
504
  # Translate to Telugu if available
505
+ if translation_available and telugu_model_loaded:
506
  english_response = " ".join(english_words)
507
  telugu_response = translate(model_telugu, english_response, en_vocab, te_vocab, te_inv_vocab)
508