Gaston895 committed on
Commit f5b1522 · verified · 1 Parent(s): d9bb268

Upload app.py

Files changed (1): app.py +96 -13
app.py CHANGED
@@ -53,6 +53,9 @@ HTML_TEMPLATE = """
         <div id="chat-container" class="chat-container">
             <div class="message ai-message">
                 Hello! I'm AEGIS Economics AI. Ask me about economic policies, market analysis, or financial strategies.
+                <div id="model-status" style="font-size: 0.8em; color: #666; margin-top: 5px;">
+                    Checking model status...
+                </div>
             </div>
         </div>
 
@@ -63,6 +66,49 @@ HTML_TEMPLATE = """
         </div>
 
     <script>
+        // Check model status on page load
+        async function checkModelStatus() {
+            try {
+                const response = await fetch('/health');
+                const data = await response.json();
+                const statusDiv = document.getElementById('model-status');
+
+                if (data.model_loaded) {
+                    statusDiv.textContent = '✅ Model loaded and ready!';
+                    statusDiv.style.color = '#28a745';
+                } else {
+                    statusDiv.textContent = '⏳ Model loading... Please wait.';
+                    statusDiv.style.color = '#ffc107';
+                    // Try to load model
+                    setTimeout(tryLoadModel, 2000);
+                }
+            } catch (error) {
+                const statusDiv = document.getElementById('model-status');
+                statusDiv.textContent = '❌ Connection error';
+                statusDiv.style.color = '#dc3545';
+            }
+        }
+
+        async function tryLoadModel() {
+            try {
+                const response = await fetch('/load_model', { method: 'POST' });
+                const data = await response.json();
+
+                if (data.success) {
+                    const statusDiv = document.getElementById('model-status');
+                    statusDiv.textContent = '✅ Model loaded successfully!';
+                    statusDiv.style.color = '#28a745';
+                } else {
+                    setTimeout(checkModelStatus, 5000); // Check again in 5 seconds
+                }
+            } catch (error) {
+                setTimeout(checkModelStatus, 5000);
+            }
+        }
+
+        // Call on page load
+        window.onload = checkModelStatus;
+
        function handleKeyPress(event) {
            if (event.key === 'Enter') {
                sendMessage();
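
Note on the polling loop above: checkModelStatus() drives everything off GET /health, and tryLoadModel() falls back to re-polling every 5 seconds. The same plumbing can be smoke-tested outside the browser; a minimal sketch, assuming the server from this commit is running locally on the port configured in app.run (7860):

    # Poll the /health endpoint the status widget uses and print the flags.
    import json
    import urllib.request

    with urllib.request.urlopen("http://localhost:7860/health") as resp:
        payload = json.load(resp)
    print(payload["model_loaded"], payload["tokenizer_loaded"])
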
@@ -152,8 +198,8 @@ def load_model():
         logger.info(f"Loading model from {model_repo}...")
         model = AutoModelForCausalLM.from_pretrained(
             model_repo,
-            torch_dtype=torch.bfloat16,
-            device_map="auto",
+            torch_dtype=torch.float16,  # Changed from bfloat16 for better compatibility
+            device_map="cpu",  # Force CPU for HF Spaces compatibility
             trust_remote_code=True,
             use_auth_token=False,
             low_cpu_mem_usage=True
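
A caveat on the dtype switch: torch.float16 saves memory, but on CPU many kernels are slow or missing in half precision, so float32 is usually the safer CPU default. A hedged alternative to the hard-coded values above (not what this commit does) is to pick dtype and device together:

    import torch

    # Sketch only: half precision on GPU, full precision on CPU.
    if torch.cuda.is_available():
        dtype, device_map = torch.float16, "auto"
    else:
        dtype, device_map = torch.float32, "cpu"
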
@@ -164,13 +210,31 @@ def load_model():
 
     except Exception as e:
         logger.error(f"Error loading model from HF: {str(e)}")
-        return False
+        # Try alternative loading method
+        try:
+            logger.info("Trying alternative loading method...")
+            tokenizer = AutoTokenizer.from_pretrained(
+                "Qwen/Qwen2-1.5B",  # Fallback to base model
+                trust_remote_code=True
+            )
+            model = AutoModelForCausalLM.from_pretrained(
+                "Qwen/Qwen2-1.5B",
+                torch_dtype=torch.float16,
+                device_map="cpu",
+                trust_remote_code=True,
+                low_cpu_mem_usage=True
+            )
+            logger.info("Fallback model loaded successfully!")
+            return True
+        except Exception as e2:
+            logger.error(f"Fallback loading also failed: {str(e2)}")
+            return False
 
 def generate_response(prompt):
     """Generate response using the loaded model"""
     try:
         if model is None or tokenizer is None:
-            return "Model not loaded. Please wait..."
+            return "Model is still loading, please wait a moment and try again..."
 
         # Economics-focused system prompt
         system_prompt = """You are AEGIS Economics AI, an expert economic analyst and policy advisor.
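
The fallback path reassigns the module-level model and tokenizer, which only works if load_model() declares `global model, tokenizer` earlier in the file (that part is outside this diff). The nested try/except can also be written as a loop over candidate repos; a sketch under that same assumption, with a hypothetical helper name and the two repo names that appear in this diff:

    import logging
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    logger = logging.getLogger(__name__)

    def load_model_with_fallback():
        """Try the fine-tuned repo first, then the base model (sketch)."""
        global model, tokenizer
        for repo in ("Gaston895/Aegisecon1", "Qwen/Qwen2-1.5B"):
            try:
                tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
                model = AutoModelForCausalLM.from_pretrained(
                    repo,
                    torch_dtype=torch.float16,
                    device_map="cpu",
                    trust_remote_code=True,
                    low_cpu_mem_usage=True,
                )
                logger.info(f"Loaded {repo}")
                return True
            except Exception as exc:
                logger.error(f"Loading {repo} failed: {exc}")
        return False
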
@@ -180,17 +244,18 @@ def generate_response(prompt):
         full_prompt = f"{system_prompt}\n\nUser: {prompt}\nAssistant:"
 
         # Tokenize input
-        inputs = tokenizer(full_prompt, return_tensors="pt")
+        inputs = tokenizer(full_prompt, return_tensors="pt", truncation=True, max_length=1024)
 
         # Generate response
         with torch.no_grad():
             outputs = model.generate(
                 inputs.input_ids,
-                max_new_tokens=512,
+                max_new_tokens=256,  # Reduced for faster generation
                 temperature=0.7,
                 do_sample=True,
                 pad_token_id=tokenizer.eos_token_id,
-                repetition_penalty=1.1
+                repetition_penalty=1.1,
+                no_repeat_ngram_size=3
             )
 
         # Decode response
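
One caveat on the generate() call: with pad_token_id set to the EOS token and only input_ids passed in, transformers typically warns that it cannot distinguish padding from end-of-sequence. Passing the attention mask the tokenizer already produced avoids that; a sketch of the same call with that one addition (variable names as in this diff):

    outputs = model.generate(
        inputs.input_ids,
        attention_mask=inputs.attention_mask,  # explicit mask; pad id == eos id here
        max_new_tokens=256,
        temperature=0.7,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
        repetition_penalty=1.1,
        no_repeat_ngram_size=3,
    )
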
@@ -204,7 +269,7 @@ def generate_response(prompt):
 
     except Exception as e:
         logger.error(f"Error generating response: {str(e)}")
-        return "I apologize, but I'm having trouble processing your request right now."
+        return "I apologize, but I'm having trouble processing your request right now. Please try again in a moment."
 
 @app.route('/')
 def home():
@@ -236,16 +301,34 @@ def health():
     return jsonify({
         'status': 'healthy',
         'model_loaded': model is not None,
-        'tokenizer_loaded': tokenizer is not None
+        'tokenizer_loaded': tokenizer is not None,
+        'model_info': 'Gaston895/Aegisecon1' if model is not None else 'Not loaded'
     })
 
+@app.route('/load_model', methods=['POST'])
+def load_model_endpoint():
+    """Endpoint to trigger model loading"""
+    try:
+        success = load_model()
+        return jsonify({
+            'success': success,
+            'model_loaded': model is not None,
+            'tokenizer_loaded': tokenizer is not None
+        })
+    except Exception as e:
+        return jsonify({'error': str(e)}), 500
+
 if __name__ == '__main__':
     # Load model on startup
     logger.info("Starting AEGIS Economics AI...")
 
-    if load_model():
+    # Try to load model, but don't fail if it doesn't work
+    logger.info("Attempting to load model...")
+    model_loaded = load_model()
+
+    if model_loaded:
         logger.info("Model loaded successfully, starting server...")
-        app.run(host='0.0.0.0', port=7860, debug=False)
     else:
-        logger.error("Failed to load model, exiting...")
-        exit(1)
+        logger.warning("Model failed to load, starting server anyway. Model can be loaded via /load_model endpoint.")
+
+    app.run(host='0.0.0.0', port=7860, debug=False)
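
With the new endpoint in place, a load can also be triggered by hand; a minimal sketch, again assuming the server is reachable on localhost:7860:

    # POST to the new /load_model endpoint and print the JSON result.
    import json
    import urllib.request

    req = urllib.request.Request("http://localhost:7860/load_model", method="POST")
    with urllib.request.urlopen(req) as resp:
        print(json.load(resp))

An alternative the commit does not take would be to start load_model() in a daemon thread before app.run(), so the UI comes up immediately while weights download. Note also that two concurrent POSTs to /load_model could race, since load_model() is not guarded by a lock.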
 
 