HusainHG committed on
Commit 29601ae · verified · 1 Parent(s): 0e86bf8

Upload 5 files

Files changed (1)
  app.py +30 -9
app.py CHANGED
@@ -4,6 +4,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
 import torch
 import sys
 import os
+import gc  # For garbage collection optimization
 
 app = Flask(__name__)
 CORS(app)
@@ -21,13 +22,18 @@ if tokenizer.pad_token is None:
 
 print("✅ Tokenizer loaded!")
 
+# Optimized quantization for 2 vCPU + 18GB RAM
 quantization_config = BitsAndBytesConfig(
     load_in_4bit=True,
     bnb_4bit_compute_dtype=torch.float16,
     bnb_4bit_quant_type="nf4",
-    bnb_4bit_use_double_quant=True,
+    bnb_4bit_use_double_quant=False,  # Disabled for CPU efficiency
 )
 
+# Set CPU threads BEFORE loading model to reduce startup CPU spike
+torch.set_num_threads(2)
+torch.set_num_interop_threads(1)
+
 model = AutoModelForCausalLM.from_pretrained(
     MODEL_NAME,
     quantization_config=quantization_config,
@@ -37,8 +43,14 @@ model = AutoModelForCausalLM.from_pretrained(
     torch_dtype=torch.float16,
 )
 
-print("✅ Model loaded!")
+# Set model to eval mode and optimize for inference
+model.eval()
+for param in model.parameters():
+    param.requires_grad = False
+
+print("✅ Model loaded and optimized!")
 print(f"Device: {model.device}")
+print(f"Threads: {torch.get_num_threads()}")
 print("="*80 + "\n")
 
 HTML_TEMPLATE = """
@@ -146,7 +158,7 @@ HTML_TEMPLATE = """
         <button onclick="generate()" id="generateBtn">💬 Send</button>
 
         <div class="loading" id="loading">
-            <p>⏳ Generating response... Please wait (this may take 30-60 seconds on CPU)</p>
+            <p>⏳ Generating response... Please wait (typically 15-30 seconds on 2 vCPU)</p>
         </div>
 
         <div class="output" id="output"></div>
@@ -237,16 +249,18 @@ def generate():
         sys.stdout.flush()
 
         with torch.no_grad():
-            torch.set_num_threads(2)
             outputs = model.generate(
                 **inputs,
-                max_new_tokens=200,
+                max_new_tokens=150,  # Reduced for faster response
                 do_sample=True,
-                temperature=0.7,
-                top_p=0.9,
+                temperature=0.3,  # Lower temp = faster, more focused
+                top_p=0.85,  # Slightly lower for efficiency
+                top_k=40,  # Limit sampling space
+                repetition_penalty=1.1,  # Prevent loops
                 pad_token_id=tokenizer.pad_token_id,
                 eos_token_id=tokenizer.eos_token_id,
-                use_cache=False
+                use_cache=True,  # Enable KV cache for speed
+                num_beams=1,  # Greedy = faster
             )
 
         full_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
@@ -273,6 +287,13 @@ def generate():
         return jsonify({'error': str(e)}), 500
 
 if __name__ == '__main__':
+    # Force garbage collection after model load
+    import gc
+    gc.collect()
+
     port = int(os.environ.get('PORT', 7860))
     print(f"🌐 Starting server on port {port}...\n")
-    app.run(host='0.0.0.0', port=port, debug=False)
+    print("💡 CPU usage should normalize after initial model load\n")
+
+    # Use threaded mode for better concurrency on 2 vCPU
+    app.run(host='0.0.0.0', port=port, debug=False, threaded=True, use_reloader=False)
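
A quick way to sanity-check the new generation settings once the Space is running is to hit the Flask endpoint from a small client script. The sketch below assumes the route that wraps generate() is POST /generate and that it accepts a JSON body with a "prompt" field; neither the route path nor the payload schema is visible in this diff, so adjust both to match the actual app.py. The port (7860) is the PORT default used above.

# Hypothetical smoke test for the updated endpoint (route path and payload key are assumptions, not shown in this diff).
import requests

resp = requests.post(
    "http://localhost:7860/generate",        # 7860 = default PORT in app.py
    json={"prompt": "Hello, how are you?"},  # assumed request field name
    timeout=120,                             # generation on 2 vCPU can take tens of seconds
)
print(resp.status_code)
print(resp.json())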