vinay0123 committed on
Commit
1dd4ebf
·
verified ·
1 Parent(s): 6a3e700

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +273 -137
app.py CHANGED
@@ -1,4 +1,3 @@
1
-
2
  import textwrap
3
  import torch
4
  from datetime import datetime
@@ -10,29 +9,47 @@ import pandas as pd
10
  from torch.utils.data import Dataset, DataLoader
11
  from torch.nn.utils.rnn import pad_sequence
12
  from sklearn.model_selection import train_test_split
13
- from flask import Flask ,request, jsonify,send_file,after_this_request
14
  from collections import Counter
15
  from flask_cors import CORS
16
  import requests
17
  from gtts import gTTS
18
- from googletrans import Translator
19
  import uuid
20
  import os
21
  import time
 
 
 
 
 
 
22
 
 
 
23
 
24
  # Load Dataset
25
- df = pd.read_csv("https://drive.google.com/uc?id=1RCZShB5ohy1HdU-mogcP16TbeVv9txpY")
26
- df = df.dropna(subset=['instruction', 'response'])
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
- # Ensure all entries are strings
29
- df['instruction'] = df['instruction'].astype(str)
30
- df['response'] = df['response'].astype(str)
31
  # Tokenizer (Scratch)
32
  class ScratchTokenizer:
33
  def __init__(self):
34
- self.word2idx = {"<PAD>": 0, "<SOS>": 1, "<EOS>": 2, "<UNK>": 3}
35
- self.idx2word = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}
36
  self.vocab_size = 4
37
 
38
  def build_vocab(self, texts):
@@ -81,40 +98,61 @@ test_dataset = TextDataset(test_data, tokenizer)
81
  train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
82
  test_loader = DataLoader(test_dataset, batch_size=8)
83
 
84
- # Improved GPT-Style Transformer Model
85
-
86
  class GPTModel(nn.Module):
87
  def __init__(self, vocab_size, embed_size=256, num_heads=8, num_layers=6, max_len=200):
88
  super(GPTModel, self).__init__()
89
  self.embedding = nn.Embedding(vocab_size, embed_size)
90
  self.pos_embedding = nn.Parameter(torch.randn(1, max_len, embed_size))
91
- # The problem was here, setting num_encoder_layers to 0
92
- # makes the model try to access a non-existent layer.
93
- # The solution is to remove the encoder completely.
94
- self.transformer = nn.TransformerDecoder(nn.TransformerDecoderLayer(d_model=embed_size, nhead=num_heads), num_layers=num_layers)
 
 
 
 
 
 
95
  self.fc_out = nn.Linear(embed_size, vocab_size)
 
 
 
 
 
 
 
 
 
 
 
96
 
97
  def forward(self, src, tgt):
98
  src_emb = self.embedding(src) + self.pos_embedding[:, :src.size(1), :]
99
  tgt_emb = self.embedding(tgt) + self.pos_embedding[:, :tgt.size(1), :]
100
-
101
- # Causal Mask for Auto-Regressive Decoding
102
  tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt.size(1)).to(tgt.device)
103
- output = self.transformer(tgt_emb.permute(1, 0, 2), src_emb.permute(1, 0, 2), tgt_mask=tgt_mask)
104
- return self.fc_out(output.permute(1, 0, 2))
105
 
106
- # Initialize Model
107
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 
108
  model = GPTModel(tokenizer.vocab_size).to(device)
109
- optimizer = optim.AdamW(model.parameters(), lr=2e-4)
110
  criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
111
 
 
 
112
 
113
  def load_model(model, path="gpt_model.pth"):
114
  if os.path.exists(path):
115
- model.load_state_dict(torch.load(path, map_location=device))
116
- model.eval()
117
- print("Model loaded successfully.")
 
 
 
118
  else:
119
  print("Model file not found!")
120
 
@@ -125,7 +163,7 @@ def generate_response(model, query, max_length=200):
125
  model.eval()
126
  with torch.no_grad(): # Disable gradient tracking
127
  src = torch.tensor(tokenizer.encode(query)).unsqueeze(0).to(device)
128
- tgt = torch.tensor([[1]]).to(device) # <SOS>
129
 
130
  for _ in range(max_length):
131
  output = model(src, tgt)
@@ -136,7 +174,7 @@ def generate_response(model, query, max_length=200):
136
 
137
  return tokenizer.decode(tgt.squeeze(0).tolist())
138
 
139
-
140
  DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
141
  MAX_LEN = 350
142
  BATCH_SIZE = 8
@@ -149,9 +187,19 @@ NUM_EPOCHS = 18
149
  MIN_FREQ = 2
150
 
151
  # ==== Tokenizers ====
152
- spacy_eng = spacy.load("en_core_web_sm")
 
 
 
 
 
 
153
  def tokenize_en(text):
154
- return [tok.text.lower() for tok in spacy_eng.tokenizer(text)]
 
 
 
 
155
 
156
  def tokenize_te(text):
157
  return text.strip().split(" ")
@@ -189,13 +237,6 @@ class TranslationDataset(Dataset):
189
 
190
  return torch.tensor(en_ids), torch.tensor(te_ids)
191
 
192
- # ==== Collate Function ====
193
- def collate_fn(batch):
194
- src_batch, tgt_batch = zip(*batch)
195
- src_batch = pad_sequence(src_batch, padding_value=en_vocab['<pad>'], batch_first=True)
196
- tgt_batch = pad_sequence(tgt_batch, padding_value=te_vocab['<pad>'], batch_first=True)
197
- return src_batch, tgt_batch
198
-
199
  # ==== Transformer Model ====
200
  class Seq2SeqTransformer(nn.Module):
201
  def __init__(self, num_encoder_layers, num_decoder_layers,
@@ -237,144 +278,239 @@ def translate(model, sentence, en_vocab, te_vocab, te_inv_vocab, max_len=MAX_LEN
237
  translated = [te_inv_vocab[idx.item()] for idx in tgt_ids[0][1:]]
238
  return ' '.join(translated[:-1]) if translated[-1] == '<eos>' else ' '.join(translated)
239
 
240
- # ==== Load Data ====
241
- df_telugu = pd.read_csv("merged_translated_responses.csv") # columns: 'en', 'te'
242
- # Clean NaN or non-string entries
243
- df_telugu = df_telugu.dropna(subset=['response', 'translated_response'])
244
-
245
- # Ensure all entries are strings
246
- df_telugu['response'] = df_telugu['response'].astype(str)
247
- df_telugu['translated_response'] = df_telugu['translated_response'].astype(str)
248
-
249
- # Build vocabularies
250
- en_vocab = build_vocab(df_telugu['response'], tokenize_en, MIN_FREQ)
251
- te_vocab = build_vocab(df_telugu['translated_response'], tokenize_te, MIN_FREQ)
252
- te_inv_vocab = {idx: tok for tok, idx in te_vocab.items()}
253
-
254
- # Prepare Dataset & DataLoader
255
- dataset = TranslationDataset(df_telugu, en_vocab, te_vocab)
256
- dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
257
-
258
- # Initialize Model
259
- # model = Seq2SeqTransformer(NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS, EMB_SIZE,
260
- # len(en_vocab), len(te_vocab), NHEAD, FFN_HID_DIM).to(DEVICE)
261
-
262
- pad_idx = te_vocab['<pad>']
263
- criterion_telugu = nn.CrossEntropyLoss(ignore_index=pad_idx)
264
- optimizer_telugu = optim.Adam(model.parameters(), lr=0.0005)
265
-
266
- # ==== Training ====
267
- # for epoch in range(NUM_EPOCHS):
268
- # loss = train(model, dataloader, optimizer, criterion)
269
- # print(f"Epoch {epoch+1}, Loss: {loss:.4f}")
270
-
271
- # ==== Try Translation ====
272
-
273
- model_telugu = Seq2SeqTransformer(NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS, EMB_SIZE,len(en_vocab), len(te_vocab), NHEAD, FFN_HID_DIM).to(DEVICE)
274
-
275
- # Load saved weights
276
- model_telugu.load_state_dict(torch.load("english_telugu_transformer.pth",map_location = torch.device('cpu')))
277
- model_telugu.eval()
278
- app=Flask(__name__)
279
  CORS(app)
280
 
281
-
282
  @app.route("/")
283
  def home():
284
  current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
285
  return jsonify({"message": f"Welcome to TRAVIS API, Time : {current_time}"})
286
 
287
-
288
  @app.route("/intent")
289
  def intents():
290
- return jsonify({"intents" :list(set(df['intent'].dropna()))})
291
-
292
-
 
 
 
 
 
293
 
294
  @app.route("/translate", methods=["POST"])
295
  def translate_text():
 
 
 
296
  data = request.get_json()
297
  text = data.get("text", "")
298
- current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
299
- print("Entered '/translate' at time: ",current_time)
300
  if not text:
301
  return jsonify({"error": "Text cannot be empty"}), 400
302
 
303
- # First generate English response
304
- english_response = text
305
- start=time.time()
306
- # Then translate to Telugu
307
- telugu_response = translate(model_telugu, english_response, en_vocab, te_vocab, te_inv_vocab)
308
- end=time.time()
309
- return jsonify({
310
- "english": english_response,
311
- "telugu": telugu_response,
312
- "time": end-start
313
- })
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
314
 
315
  @app.route("/generate", methods=["POST"])
316
  def generate_text():
317
  data = request.get_json()
318
  query = data.get("query", "")
319
- current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
320
- print("Entered '/generate' at time: ",current_time)
321
-
322
  if not query:
323
  return jsonify({"error": "Query cannot be empty"}), 400
324
- start=time.time()
325
- response = generate_response(model, query)
326
- end=time.time()
327
- # Clean the response
328
- def clean_response(response):
329
- return response.replace("<EOS>", "").replace("<SOS>", "").strip()
330
-
331
- response = clean_response(response)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
332
 
333
- return jsonify({
334
- "response": response,
335
- "time": end-start
336
- })
 
 
 
 
337
 
338
  @app.route("/query", methods=["POST"])
339
  def query_model():
340
- global audio_telugu_response
341
  data = request.get_json()
342
- current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
343
- print("Entered '/query' at time: ",current_time)
344
  query = data.get("query", "")
345
-
346
  if not query:
347
  return jsonify({"error": "Query cannot be empty"}), 400
348
-
349
- start_eng = time.time()
350
- # Assuming `generate_response` is a function that processes the query
351
- response = generate_response(model, query)
352
- end_eng = time.time()
353
- def clean_response(response):
354
- return response.replace("<EOS>", "").replace("<SOS>", "").strip()
355
- response=clean_response(response)
356
- start_te = time.time()
357
- telugu_response = translate(model_telugu, response, en_vocab, te_vocab, te_inv_vocab)
358
- end_te = time.time()
359
- audio_telugu_response=telugu_response
360
- return jsonify({"telugu":(telugu_response),"english":(response),"eng_time":(end_eng-start_eng),"telugu_time":(end_te-start_te)})
361
-
362
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
363
  @app.route("/audio", methods=["POST"])
364
  def get_audio():
365
  data = request.get_json()
366
  text = data.get("text")
367
- start_te = time.time()
368
-
369
  if not text:
370
  return jsonify({"error": "No Response To convert to speech"}), 400
371
 
372
- # Convert text to Telugu speech using in-memory file
373
- speech = gTTS(text=text, lang="te")
374
- audio_io = io.BytesIO()
375
- speech.write_to_fp(audio_io)
376
- audio_io.seek(0)
377
- end_te = time.time()
378
- print("telugu_time: ",(end_te-start_te))
379
-
380
- return send_file(audio_io, mimetype="audio/mpeg", as_attachment=False)
 
 
 
 
 
 
 
 
 
 
 
1
  import textwrap
2
  import torch
3
  from datetime import datetime
 
9
  from torch.utils.data import Dataset, DataLoader
10
  from torch.nn.utils.rnn import pad_sequence
11
  from sklearn.model_selection import train_test_split
12
+ from flask import Flask, request, jsonify, send_file, after_this_request, Response, stream_with_context
13
  from collections import Counter
14
  from flask_cors import CORS
15
  import requests
16
  from gtts import gTTS
17
+
18
  import uuid
19
  import os
20
  import time
21
+ import json
22
+ import io
23
+
24
# Use every available CPU core for PyTorch's intra-op and inter-op thread pools.
# os.cpu_count() returns None when the core count cannot be determined, and the
# torch setters require a positive int, so fall back to a single thread.
CPU_THREAD_COUNT = os.cpu_count() or 1
torch.set_num_threads(CPU_THREAD_COUNT)
try:
    torch.set_num_interop_threads(CPU_THREAD_COUNT)
except RuntimeError as exc:
    # PyTorch only accepts this setting once, before any inter-op parallel work
    # has started; keep the existing pool size instead of aborting start-up.
    print(f"Could not set inter-op thread count: {exc}")

# Enable oneDNN Graph fusion for TorchScript code.
# NOTE(review): this presumably only benefits scripted/traced modules, and the
# model in this file is used eagerly (scripting is commented out) - confirm
# whether this call is still wanted.
torch.jit.enable_onednn_fusion(True)
30
 
31
  # Load Dataset
32
try:
    # Download the instruction/response corpus and drop incomplete rows.
    df = pd.read_csv(
        "https://drive.google.com/uc?id=1RCZShB5ohy1HdU-mogcP16TbeVv9txpY"
    ).dropna(subset=['instruction', 'response'])
    # Ensure all entries are strings
    df = df.astype({'instruction': str, 'response': str})
    print("Main dataset loaded successfully")
except Exception as e:
    print(f"Error loading main dataset: {e}")
    # Fall back to a tiny in-memory dataset so the app can still start.
    _fallback_rows = [
        ('Hello', 'Hi there!', 'greeting'),
        ('How are you?', 'I am doing well, thank you!', 'greeting'),
    ]
    df = pd.DataFrame(_fallback_rows, columns=['instruction', 'response', 'intent'])
47
 
 
 
 
48
  # Tokenizer (Scratch)
49
  class ScratchTokenizer:
50
  def __init__(self):
51
+ self.word2idx = {"<PAD>": 0, "< SOS >": 1, "<EOS>": 2, "<UNK>": 3}
52
+ self.idx2word = {0: "<PAD>", 1: "< SOS >", 2: "<EOS>", 3: "<UNK>"}
53
  self.vocab_size = 4
54
 
55
  def build_vocab(self, texts):
 
98
  train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
99
  test_loader = DataLoader(test_dataset, batch_size=8)
100
 
101
# GPT-style Transformer decoder model
class GPTModel(nn.Module):
    """Transformer-decoder language model conditioned on an embedded query.

    There is no separate encoder: the query tokens (``src``) are embedded with
    the same embedding table and learned positional embeddings as the target
    tokens, and the decoder stack cross-attends to that embedded query.

    Args:
        vocab_size: Number of token ids; also the size of the output logits.
        embed_size: Embedding / model width.
        num_heads: Attention heads per decoder layer.
        num_layers: Number of stacked decoder layers.
        max_len: Longest ``src`` or ``tgt`` sequence the positional
            embeddings can cover.
    """

    def __init__(self, vocab_size, embed_size=256, num_heads=8, num_layers=6, max_len=200):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        # Learned positional embeddings, one vector per position up to max_len.
        self.pos_embedding = nn.Parameter(torch.randn(1, max_len, embed_size))
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=embed_size,
            nhead=num_heads,
            dim_feedforward=embed_size * 4,
            dropout=0.1,
            batch_first=True,  # tensors are (batch, seq, embed)
        )
        self.transformer = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
        self.fc_out = nn.Linear(embed_size, vocab_size)

        # Re-initialise every Linear/Embedding submodule. pos_embedding is a
        # bare Parameter, so it keeps its randn initialisation.
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Xavier-uniform Linear weights with zero bias; N(0, 0.02) embeddings."""
        if isinstance(module, nn.Linear):
            torch.nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, src, tgt):
        """Return next-token logits of shape ``(batch, tgt_len, vocab_size)``.

        Args:
            src: ``(batch, src_len)`` tensor of query token ids.
            tgt: ``(batch, tgt_len)`` tensor of target token ids generated so far.

        Raises:
            ValueError: If ``src`` or ``tgt`` is longer than ``max_len``.
        """
        max_len = self.pos_embedding.size(1)
        for label, tokens in (("src", src), ("tgt", tgt)):
            # Without this check an over-long input fails later with an
            # opaque broadcasting error when positions are added.
            if tokens.size(1) > max_len:
                raise ValueError(
                    f"{label} sequence length {tokens.size(1)} exceeds max_len {max_len}"
                )

        src_emb = self.embedding(src) + self.pos_embedding[:, :src.size(1), :]
        tgt_emb = self.embedding(tgt) + self.pos_embedding[:, :tgt.size(1), :]
        # Causal mask: each target position may only attend to earlier ones.
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt.size(1)).to(tgt.device)
        output = self.transformer(tgt_emb, src_emb, tgt_mask=tgt_mask)
        return self.fc_out(output)
136
 
137
# Pick the compute device (GPU when available), then build the model,
# optimizer and loss on it.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

model = GPTModel(tokenizer.vocab_size)
model = model.to(device)
optimizer = optim.AdamW(
    model.parameters(),
    lr=2e-4,
    weight_decay=0.01,
)
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
144
 
145
+ # Remove JIT compilation as it can cause issues with dynamic models
146
+ # model = torch.jit.script(model) # Commented out
147
 
148
def load_model(model, path="gpt_model.pth"):
    """Load saved weights from ``path`` into ``model`` and switch it to eval mode.

    Args:
        model: Module whose ``load_state_dict`` receives the checkpoint.
        path: Checkpoint file to read.

    Returns:
        True when the weights were loaded, False when the file is missing or
        loading failed. In the failure cases ``model`` keeps whatever weights
        it already had, so callers can use the result to tell the difference.
    """
    if not os.path.exists(path):
        print("Model file not found!")
        return False

    try:
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source (consider weights_only=True where supported).
        model.load_state_dict(torch.load(path, map_location=device))
    except Exception as e:
        print(f"Error loading model: {e}")
        return False

    model.eval()
    print("Model loaded successfully.")
    return True
158
 
 
163
  model.eval()
164
  with torch.no_grad(): # Disable gradient tracking
165
  src = torch.tensor(tokenizer.encode(query)).unsqueeze(0).to(device)
166
+ tgt = torch.tensor([[1]]).to(device) # < SOS >
167
 
168
  for _ in range(max_length):
169
  output = model(src, tgt)
 
174
 
175
  return tokenizer.decode(tgt.squeeze(0).tolist())
176
 
177
+ # Translation model parameters
178
  DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
179
  MAX_LEN = 350
180
  BATCH_SIZE = 8
 
187
  MIN_FREQ = 2
188
 
189
  # ==== Tokenizers ====
190
# Load the spaCy English pipeline; leave spacy_eng as None when it is not
# installed so tokenize_en can fall back to a simple tokenizer.
spacy_eng = None
try:
    spacy_eng = spacy.load("en_core_web_sm")
except OSError:
    print("Warning: Spacy English model not found. Using simple tokenizer.")
else:
    print("Spacy English model loaded successfully")
196
+
197
def tokenize_en(text):
    """Return lower-cased English tokens for ``text``.

    Uses the spaCy tokenizer when ``spacy_eng`` was loaded, otherwise a plain
    whitespace split.
    """
    if not spacy_eng:
        # Simple fallback tokenizer
        return text.lower().split()
    spacy_tokens = spacy_eng.tokenizer(text)
    return [token.text.lower() for token in spacy_tokens]
203
 
204
def tokenize_te(text):
    """Split Telugu text into tokens on single spaces after trimming the ends."""
    trimmed = text.strip()
    # Split on a literal " " rather than on arbitrary whitespace: consecutive
    # spaces therefore yield empty-string tokens, and the Telugu vocabulary is
    # built with this exact tokenization.
    return trimmed.split(" ")
 
237
 
238
  return torch.tensor(en_ids), torch.tensor(te_ids)
239
 
 
 
 
 
 
 
 
240
  # ==== Transformer Model ====
241
  class Seq2SeqTransformer(nn.Module):
242
  def __init__(self, num_encoder_layers, num_decoder_layers,
 
278
  translated = [te_inv_vocab[idx.item()] for idx in tgt_ids[0][1:]]
279
  return ' '.join(translated[:-1]) if translated[-1] == '<eos>' else ' '.join(translated)
280
 
281
+ # ==== Load Translation Data ====
282
try:
    # Read the English/Telugu pairs, drop incomplete rows and force both
    # text columns to str.
    df_telugu = pd.read_csv("merged_translated_responses.csv").dropna(
        subset=['response', 'translated_response']
    )
    df_telugu = df_telugu.astype({'response': str, 'translated_response': str})

    # Build vocabularies
    en_vocab = build_vocab(df_telugu['response'], tokenize_en, MIN_FREQ)
    te_vocab = build_vocab(df_telugu['translated_response'], tokenize_te, MIN_FREQ)
    te_inv_vocab = {index: token for token, index in te_vocab.items()}
except Exception as e:
    print(f"Error loading Telugu dataset: {e}")
    # Create dummy vocabularies so the model below can still be constructed.
    en_vocab = {'<pad>': 0, '<sos>': 1, '<eos>': 2, '<unk>': 3, 'hello': 4, 'world': 5}
    te_vocab = {'<pad>': 0, '<sos>': 1, '<eos>': 2, '<unk>': 3, 'హలో': 4, 'ప్రపంచం': 5}
    te_inv_vocab = {index: token for token, index in te_vocab.items()}
    translation_available = False
else:
    print("Telugu translation dataset loaded successfully")
    translation_available = True

# Initialize Translation Model
model_telugu = Seq2SeqTransformer(
    NUM_ENCODER_LAYERS,
    NUM_DECODER_LAYERS,
    EMB_SIZE,
    len(en_vocab),
    len(te_vocab),
    NHEAD,
    FFN_HID_DIM,
)
model_telugu = model_telugu.to(DEVICE)

# Load saved weights for translation model; any failure disables translation.
try:
    model_telugu.load_state_dict(
        torch.load("english_telugu_transformer.pth", map_location=torch.device('cpu'))
    )
    model_telugu.eval()
except Exception as e:
    print(f"Error loading Telugu translation model: {e}")
    translation_available = False
else:
    print("Telugu translation model loaded successfully")
315
+
316
+ # Flask App
317
+ app = Flask(__name__)
 
 
318
  CORS(app)
319
 
 
320
@app.route("/")
def home():
    """Return a JSON welcome message stamped with the current server time."""
    now = datetime.now()
    stamp = now.strftime("%Y-%m-%d %H:%M:%S")
    payload = {"message": f"Welcome to TRAVIS API, Time : {stamp}"}
    return jsonify(payload)
324
 
 
325
@app.route("/intent")
def intents():
    """Return the distinct non-null values of the dataset's ``intent`` column.

    Responds with ``["general"]`` when the column is absent, and with a 500
    carrying the error text plus the same fallback list on any failure.
    """
    try:
        if 'intent' not in df.columns:
            # fallback
            return jsonify({"intents": ["general"]})
        distinct = set(df['intent'].dropna())
        return jsonify({"intents": list(distinct)})
    except Exception as e:
        return jsonify({"error": str(e), "intents": ["general"]}), 500
335
 
336
# Token ids reserved by ScratchTokenizer for start/end of sequence.
_SOS_ID = 1
_EOS_ID = 2
# Special tokens that are never streamed to the client.
_HIDDEN_TOKENS = ("<PAD>", "<EOS>", "< SOS >")


def _json_body():
    """Return the request's JSON object, or {} if it is missing, invalid or not an object."""
    payload = request.get_json(silent=True)
    return payload if isinstance(payload, dict) else {}


def _sse_event(payload):
    """Serialize ``payload`` as one server-sent-event ``data:`` frame."""
    return f"data: {json.dumps(payload)}\n\n"


def _word_event(word, start, word_count, language):
    """Build the SSE frame for one streamed word of the given language type."""
    return _sse_event({
        "word": word.strip(),
        "timestamp": time.time() - start,
        "word_count": word_count,
        "type": language,
    })


def _sse_response(events):
    """Wrap a generator of SSE frames in a streaming ``text/event-stream`` response."""
    return Response(
        stream_with_context(events),
        mimetype='text/event-stream',
        headers={
            'Cache-Control': 'no-cache',
            'Connection': 'keep-alive'
        }
    )


def _generate_english_words(query, max_length=200):
    """Greedily decode a reply to ``query``, yielding one visible word at a time.

    Stops at <EOS> or after ``max_length`` steps. Gradient tracking is
    disabled only around each model call, so the no-grad state is never held
    open while this generator is suspended at a ``yield``.
    """
    model.eval()
    src = torch.tensor(tokenizer.encode(query)).unsqueeze(0).to(device)
    tgt = torch.tensor([[_SOS_ID]]).to(device)

    for _ in range(max_length):
        with torch.no_grad():
            output = model(src, tgt)
            next_token = output[:, -1, :].argmax(dim=-1, keepdim=True)
        tgt = torch.cat([tgt, next_token], dim=1)

        token_id = next_token.item()
        if token_id == _EOS_ID:
            break

        word = tokenizer.idx2word.get(token_id, "<UNK>")
        if word not in _HIDDEN_TOKENS:
            yield word.strip()


@app.route("/translate", methods=["POST"])
def translate_text():
    """Translate the posted ``text`` to Telugu and stream the result as SSE.

    Returns 503 when the translation model is unavailable and 400 when
    ``text`` is missing, empty or not a string. Errors raised while
    translating are reported as an SSE frame of type ``error``.
    """
    if not translation_available:
        return jsonify({"error": "Translation service not available"}), 503

    text = _json_body().get("text", "")
    if not text or not isinstance(text, str):
        return jsonify({"error": "Text cannot be empty"}), 400

    def generate():
        try:
            start = time.time()
            # The whole sentence is translated first; only the streaming of
            # the finished translation happens word by word.
            telugu_response = translate(model_telugu, text, en_vocab, te_vocab, te_inv_vocab)
            for word_count, word in enumerate(telugu_response.split(), start=1):
                yield _word_event(word, start, word_count, "telugu")
        except Exception as e:
            yield _sse_event({"error": str(e), "type": "error"})

    return _sse_response(generate())


@app.route("/generate", methods=["POST"])
def generate_text():
    """Stream the model's English reply to the posted ``query`` as SSE.

    Returns 400 when ``query`` is missing, empty or not a string. Errors
    raised during generation are reported as an SSE frame of type ``error``.
    """
    query = _json_body().get("query", "")
    if not query or not isinstance(query, str):
        return jsonify({"error": "Query cannot be empty"}), 400

    def generate():
        try:
            start = time.time()
            for word_count, word in enumerate(_generate_english_words(query), start=1):
                yield _word_event(word, start, word_count, "english")
        except Exception as e:
            yield _sse_event({"error": str(e), "type": "error"})

    return _sse_response(generate())


@app.route("/query", methods=["POST"])
def query_model():
    """Stream the English reply to ``query``, then its Telugu translation, as SSE.

    English words are streamed as they are generated. When translation is
    available, the joined English reply is then translated and streamed with
    type ``telugu``; ``word_count`` keeps counting across both phases.
    Returns 400 when ``query`` is missing, empty or not a string.
    """
    query = _json_body().get("query", "")
    if not query or not isinstance(query, str):
        return jsonify({"error": "Query cannot be empty"}), 400

    def generate():
        try:
            start = time.time()
            word_count = 0
            english_words = []

            # Generate English response
            for word in _generate_english_words(query):
                english_words.append(word)
                word_count += 1
                yield _word_event(word, start, word_count, "english")

            # Translate to Telugu if available
            if translation_available:
                english_response = " ".join(english_words)
                with torch.no_grad():
                    telugu_response = translate(
                        model_telugu, english_response, en_vocab, te_vocab, te_inv_vocab
                    )
                for word in telugu_response.split():
                    word_count += 1
                    yield _word_event(word, start, word_count, "telugu")
        except Exception as e:
            yield _sse_event({"error": str(e), "type": "error"})

    return _sse_response(generate())
490
+
491
@app.route("/audio", methods=["POST"])
def get_audio():
    """Convert the posted ``text`` to Telugu speech and return it as MP3 audio.

    Returns 400 when the body is not a JSON object or ``text`` is missing,
    empty or not a string, and 500 with the error text when speech synthesis
    fails.
    """
    # get_json() can return None or a non-dict value; treat both as "no text"
    # instead of crashing on .get().
    payload = request.get_json(silent=True)
    text = payload.get("text") if isinstance(payload, dict) else None

    if not text or not isinstance(text, str):
        return jsonify({"error": "No Response To convert to speech"}), 400

    try:
        start_te = time.time()
        # Convert text to Telugu speech using in-memory file
        speech = gTTS(text=text, lang="te")
        audio_io = io.BytesIO()
        speech.write_to_fp(audio_io)
        audio_io.seek(0)  # rewind so send_file streams from the beginning
        end_te = time.time()
        print("telugu_time: ", (end_te - start_te))

        return send_file(audio_io, mimetype="audio/mpeg", as_attachment=False)
    except Exception as e:
        return jsonify({"error": f"Audio generation failed: {str(e)}"}), 500
512
+
513
if __name__ == "__main__":
    print("Starting Flask application...")
    print(f"Translation service available: {translation_available}")
    # The server listens on all interfaces, so debug mode must be opt-in:
    # debug=True enables the Werkzeug interactive debugger, which lets anyone
    # who can reach the port run code. Set FLASK_DEBUG=1 for local development.
    debug_enabled = os.environ.get("FLASK_DEBUG", "").lower() in ("1", "true", "yes")
    app.run(host="0.0.0.0", debug=debug_enabled)