Stanley03 committed
Commit 8d90dfc · verified · 1 Parent(s): e996f05

Update app.py

Files changed (1)
  1. app.py +21 -126
app.py CHANGED
@@ -2,139 +2,70 @@ from flask import Flask, request, jsonify
 from flask_cors import CORS
 from transformers import AutoModelForCausalLM, AutoTokenizer
 import torch
-import os
 import time
 
 app = Flask(__name__)
 CORS(app)
 
-# Global variables for model caching
 model = None
 tokenizer = None
 model_loaded = False
 
-# Simba system message
 SIMBA_SYSTEM = """You are Simba from The Lion King. You're brave, playful, and wise.
 Speak with royal confidence but also warmth and humor. Remember: "Hakuna Matata",
 relationships with Nala, Timon, Pumbaa, and your journey to reclaim Pride Rock.
 Keep responses under 2 sentences and stay in character."""
 
 def load_model():
-    """Load model with optimizations - called once at startup"""
     global model, tokenizer, model_loaded
-
     if model_loaded:
         return
 
-    print("🚀 Loading optimized Qwen2.5-0.5B model...")
-    start_time = time.time()
-
+    print("Loading Qwen2.5-0.5B model...")
     model_name = "Qwen/Qwen2.5-0.5B-Instruct"
 
-    # 🎯 SPEED OPTIMIZATION 1: Use bfloat16 for faster inference
-    torch_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
-
-    # 🎯 SPEED OPTIMIZATION 2: Faster tokenizer
-    tokenizer = AutoTokenizer.from_pretrained(
-        model_name,
-        trust_remote_code=True,
-        padding_side="left"  # Better for batch processing
-    )
-
+    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
     if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token
 
-    # 🎯 SPEED OPTIMIZATION 3: Optimized model loading
     model = AutoModelForCausalLM.from_pretrained(
         model_name,
-        torch_dtype=torch_dtype,
+        torch_dtype=torch.float16,
         device_map="auto",
-        trust_remote_code=True,
-        attn_implementation="sdpa",  # Flash Attention 2 for speed
-        use_cache=True,  # Faster generation
-        low_cpu_mem_usage=True,
+        trust_remote_code=True
     )
-
-    # 🎯 SPEED OPTIMIZATION 4: Compile model for faster inference (PyTorch 2.0+)
-    if hasattr(torch, 'compile') and torch.cuda.is_available():
-        print("🔧 Compiling model for maximum speed...")
-        model = torch.compile(model, mode="reduce-overhead", fullgraph=False)
-
     model_loaded = True
-    load_time = time.time() - start_time
-    print(f"✅ Model loaded in {load_time:.2f} seconds!")
+    print("Model loaded!")
 
-# Load model when app starts
 load_model()
 
 def generate_response(user_message):
-    """Generate optimized response with speed enhancements"""
-    start_time = time.time()
-
-    # Create conversation format for Qwen
     messages = [
         {"role": "system", "content": SIMBA_SYSTEM},
         {"role": "user", "content": user_message}
     ]
 
-    # 🎯 SPEED OPTIMIZATION 5: Efficient template application
-    text = tokenizer.apply_chat_template(
-        messages,
-        tokenize=False,
-        add_generation_prompt=True
-    )
-
-    # 🎯 SPEED OPTIMIZATION 6: Optimized tokenization
-    inputs = tokenizer(
-        text,
-        return_tensors="pt",
-        padding=True,
-        truncation=True,
-        max_length=512
-    ).to(model.device)
+    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+    inputs = tokenizer(text, return_tensors="pt").to(model.device)
 
-    # 🎯 SPEED OPTIMIZATION 7: Faster generation parameters
     with torch.no_grad():
-        # Use inference mode for speed
-        with torch.inference_mode():
-            outputs = model.generate(
-                **inputs,
-                max_new_tokens=100,  # Reduced for speed
-                temperature=0.7,
-                do_sample=True,
-                top_p=0.9,
-                top_k=40,
-                repetition_penalty=1.1,
-                pad_token_id=tokenizer.eos_token_id,
-                eos_token_id=tokenizer.eos_token_id,
-                num_return_sequences=1,
-                early_stopping=True
-            )
-
-    # 🎯 SPEED OPTIMIZATION 8: Efficient decoding
-    response = tokenizer.decode(
-        outputs[0][inputs['input_ids'].shape[1]:],
-        skip_special_tokens=True
-    )
-
-    generation_time = time.time() - start_time
-    print(f"⚡ Response generated in {generation_time:.2f} seconds")
-
+        outputs = model.generate(
+            **inputs,
+            max_new_tokens=150,
+            temperature=0.7,
+            do_sample=True,
+            pad_token_id=tokenizer.eos_token_id
+        )
+
+    response = tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
     return response.strip()
 
 @app.route('/')
 def home():
-    return jsonify({
-        "message": "Simba AI API is running! 🦁",
-        "status": "optimized",
-        "model": "Qwen2.5-0.5B-Instruct"
-    })
+    return jsonify({"message": "Simba AI API is running! 🦁"})
 
-@app.route('/api/chat', methods=['POST', 'OPTIONS'])
+@app.route('/api/chat', methods=['POST'])
 def chat():
-    if request.method == 'OPTIONS':
-        return '', 200
-
     try:
         data = request.get_json()
         user_message = data.get('message', '')
@@ -142,54 +73,18 @@ def chat():
 
         if not user_message:
             return jsonify({"error": "No message provided"}), 400
 
-        # 🎯 SPEED OPTIMIZATION 9: Input validation and truncation
-        if len(user_message) > 500:
-            user_message = user_message[:500] + "..."
-
        response = generate_response(user_message)
 
        return jsonify({
            "response": response,
-            "status": "success",
-            "model": "Qwen2.5-0.5B"
+            "status": "success"
        })
 
    except Exception as e:
-        print(f"❌ Error: {str(e)}")
        return jsonify({
-            "error": "Hakuna Matata! Even kings have technical issues. Try again!",
+            "error": str(e),
            "status": "error"
        }), 500
 
-@app.route('/health')
-def health():
-    return jsonify({
-        "status": "healthy",
-        "model_loaded": model_loaded,
-        "device": str(model.device) if model else "none"
-    })
-
-@app.route('/status')
-def status():
-    gpu_info = "CPU"
-    if torch.cuda.is_available():
-        gpu_info = f"GPU: {torch.cuda.get_device_name()}, Memory: {torch.cuda.memory_allocated()/1024**3:.2f}GB"
-
-    return jsonify({
-        "status": "running",
-        "model": "Qwen2.5-0.5B-Instruct",
-        "optimizations": "enabled",
-        "hardware": gpu_info,
-        "torch_version": torch.__version__
-    })
-
-# CORS headers
-@app.after_request
-def after_request(response):
-    response.headers.add('Access-Control-Allow-Origin', '*')
-    response.headers.add('Access-Control-Allow-Headers', 'Content-Type,Authorization')
-    response.headers.add('Access-Control-Allow-Methods', 'GET,PUT,POST,DELETE,OPTIONS')
-    return response
-
 if __name__ == '__main__':
-    app.run(debug=False, host='0.0.0.0', port=7860)
+    app.run(debug=True, host='0.0.0.0', port=7860)
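A quick way to exercise the simplified endpoints is a small requests client. This is a minimal sketch, not part of the commit: it assumes the app is reachable at http://localhost:7860 (the host and port bound in app.run above) and that the requests package is installed.

import requests

BASE_URL = "http://localhost:7860"  # assumption: app running locally on the port above

# GET / returns the banner: {"message": "Simba AI API is running! 🦁"}
print(requests.get(f"{BASE_URL}/", timeout=10).json())

# POST /api/chat expects a JSON body with a "message" key
resp = requests.post(
    f"{BASE_URL}/api/chat",
    json={"message": "Simba, who rules Pride Rock?"},
    timeout=120,  # generation on modest hardware can take a while
)
data = resp.json()
if resp.ok:
    print(data["response"])  # the in-character reply; "status" will be "success"
else:
    print(data.get("error"), data.get("status"))  # error path returns HTTP 500

Note that on the error path the handler now returns the raw exception string under "error", so a client can surface it directly.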