vinay0123 committed
Commit cd9d203 · verified · 1 Parent(s): 5154835

Update app.py

Files changed (1)
  1. app.py +25 -45
app.py CHANGED
@@ -13,9 +13,6 @@ import json
 torch.set_num_threads(os.cpu_count())
 torch.set_num_interop_threads(os.cpu_count())
 
-# Enable optimizations
-torch.backends.mkldnn.enabled = True if hasattr(torch.backends, 'mkldnn') else False
-
 url = "https://drive.google.com/uc?id=1RCZShB5ohy1HdU-mogcP16TbeVv9txpY"
 df = pd.read_csv(url)
 
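Note: the dropped mkldnn toggle was effectively a no-op; CPU builds of PyTorch enable the oneDNN (mkldnn) backend by default whenever it is compiled in. A minimal check to confirm this on the target machine (a sketch, not part of the commit):

    import torch

    # oneDNN/mkldnn is enabled by default when available, so the removed
    # `torch.backends.mkldnn.enabled = True` assignment changed nothing
    if hasattr(torch.backends, "mkldnn"):
        print("mkldnn available:", torch.backends.mkldnn.is_available())
        print("mkldnn enabled:", torch.backends.mkldnn.enabled)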
@@ -76,9 +73,6 @@ def load_model(model, path="gpt_model.pth"):
     if os.path.exists(path):
         model.load_state_dict(torch.load(path, map_location=device, weights_only=True))
         model.eval()
-        # Enable inference optimizations
-        if hasattr(torch.jit, 'optimize_for_inference'):
-            model = torch.jit.optimize_for_inference(torch.jit.script(model))
         print("Model loaded successfully.")
     else:
         print("Model file not found!")
@@ -92,42 +86,30 @@ def generate_response_stream(model, query, max_length=200):
     src = torch.tensor(src_tokens).unsqueeze(0).to(device)
     tgt = torch.tensor([[1]], dtype=torch.long).to(device) # < SOS >
 
-    # Pre-allocate tensor for better memory efficiency
-    max_tgt_len = min(max_length, 200)
-
     with torch.no_grad():
-        # Use torch.inference_mode for better performance
-        with torch.inference_mode():
-            for step in range(max_length):
-                # Forward pass
-                output = model(src, tgt)
-
-                # Get next token more efficiently
-                logits = output[:, -1, :]
-                next_token = torch.argmax(logits, dim=-1, keepdim=True)
-
-                # Check for EOS early
-                if next_token.item() == 2: # <EOS>
-                    break
-
-                # Concatenate token
-                tgt = torch.cat([tgt, next_token], dim=1)
-
-                # Get the current word
-                current_word = tokenizer.idx2word.get(next_token.item(), "<UNK>")
-                if current_word not in ["<PAD>", "<EOS>", "< SOS >"]:
-                    yield current_word + " "
-
-                # Prevent infinite loops
-                if tgt.size(1) >= max_tgt_len:
-                    break
-
-# Flask App with threading optimizations
+        for step in range(max_length):
+            # Forward pass
+            output = model(src, tgt)
+
+            # Get next token more efficiently
+            logits = output[:, -1, :]
+            next_token = torch.argmax(logits, dim=-1, keepdim=True)
+
+            # Check for EOS early
+            if next_token.item() == 2: # <EOS>
+                break
+
+            # Concatenate token
+            tgt = torch.cat([tgt, next_token], dim=1)
+
+            # Get the current word
+            current_word = tokenizer.idx2word.get(next_token.item(), "<UNK>")
+            if current_word not in ["<PAD>", "<EOS>", "< SOS >"]:
+                yield current_word + " "
+
+# Flask App
 app = Flask(__name__)
 
-# Configure Flask for better performance
-app.config['THREADED'] = True
-
 @app.route("/")
 def home():
     return {"message": "Streaming Transformer-based Response Generator API is running!"}
@@ -160,20 +142,18 @@ def query_model():
         mimetype='text/event-stream',
         headers={
             'Cache-Control': 'no-cache',
-            'Connection': 'keep-alive',
-            'X-Accel-Buffering': 'no' # Disable nginx buffering if present
+            'Connection': 'keep-alive'
         }
     )
 
 if __name__ == "__main__":
-    # Load and optimize model
+    # Load model
     model = load_model(model)
 
-    # Run Flask with threading enabled and optimized worker settings
+    # Run Flask with optimizations
     app.run(
         host="0.0.0.0",
         port=7860,
         threaded=True,
-        processes=1, # Use threading instead of multiprocessing for better memory sharing
-        debug=False # Disable debug mode for better performance
+        debug=False
     )
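Note: the endpoint streams plain-text chunks with an SSE-style `text/event-stream` mimetype over a keep-alive connection. A minimal consumer (a sketch: the `/query` route path and the `query` JSON field are assumptions, since the `query_model` decorator and request parsing sit outside this diff):

    import requests

    # Hypothetical endpoint and payload shape; adjust to the actual route
    with requests.post("http://localhost:7860/query",
                       json={"query": "hello"}, stream=True) as resp:
        for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
            print(chunk, end="", flush=True)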