Gaston895 committed
Commit 4f57881 · verified · 1 Parent(s): 756e842

🔧 Fix model loading using proven app_gunicorn.py approach - no pipeline, direct generation
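In short, the commit drops the transformers `pipeline("text-generation", ...)` wrapper and calls `model.generate()` directly, as the working app_gunicorn.py does. A minimal, self-contained sketch of that direct-generation pattern (illustrative only: it uses the diff's fallback base model Qwen/Qwen2-1.5B and default float32 on CPU, whereas the app itself pins torch_dtype=torch.float16 and first tries Gaston895/Aegisecon1):

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Fallback base model named in the diff; the app first tries "Gaston895/Aegisecon1".
model_id = "Qwen/Qwen2-1.5B"
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    trust_remote_code=True,
    low_cpu_mem_usage=True,  # stream weights in to cap peak RAM
)
model.eval()

# Same prompt framing as the new generate_response(): "User: ...\nAssistant:"
inputs = tokenizer("User: What drives inflation?\nAssistant:",
                   return_tensors="pt", truncation=True, max_length=1024)

with torch.no_grad():
    outputs = model.generate(
        inputs.input_ids,
        attention_mask=inputs.attention_mask,  # explicit mask avoids a padding warning
        max_new_tokens=64,
        do_sample=True,
        temperature=0.7,
        pad_token_id=tokenizer.eos_token_id,
    )

print(tokenizer.decode(outputs[0], skip_special_tokens=True).split("Assistant:")[-1].strip())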

Files changed (1): app.py (+73 −110)
app.py CHANGED
@@ -7,7 +7,7 @@ CPU-optimized version for Modal deployment
 
 from flask import Flask, request, jsonify, render_template_string
 import torch
-from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
+from transformers import AutoTokenizer, AutoModelForCausalLM
 import os
 import logging
 import json
@@ -38,7 +38,6 @@ app = Flask(__name__)
 # Global variables
 model = None
 tokenizer = None
-chat_pipeline = None
 executor = ThreadPoolExecutor(max_workers=2)
 
 def cleanup_memory():
@@ -848,140 +847,104 @@ Focus on quantitative metrics and actionable insights."""
 langgraph_processor = LangGraphProcessor()
 
 def load_model():
-    """Load the model and tokenizer from Gaston895/Aegisecon1 repository using optimized pipeline approach"""
+    """Load the model and tokenizer from Gaston895/Aegisecon1 repository using the working approach"""
     global model, tokenizer, chat_pipeline
 
     try:
         logger.info("Loading model and tokenizer from Hugging Face...")
 
-        # Model configuration
-        MODEL_NAME = "Gaston895/Aegisecon1"
-        DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+        # Load from the deployed model repository
+        model_repo = "Gaston895/Aegisecon1"
 
-        logger.info(f"Loading tokenizer from {MODEL_NAME}...")
-        # Load tokenizer with optimizations
+        logger.info(f"Loading tokenizer from {model_repo}...")
         tokenizer = AutoTokenizer.from_pretrained(
-            MODEL_NAME,
+            model_repo,
             trust_remote_code=True,
-            use_fast=True  # Use fast tokenizer for speed
+            use_auth_token=False
         )
 
-        # Set pad token if not exists
-        if tokenizer.pad_token is None:
-            tokenizer.pad_token = tokenizer.eos_token
-
-        logger.info(f"Loading model from {MODEL_NAME}...")
-        # Load model with memory-optimized settings
+        logger.info(f"Loading model from {model_repo}...")
         model = AutoModelForCausalLM.from_pretrained(
-            MODEL_NAME,
-            torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,  # Use float16 for memory efficiency
-            device_map="auto" if DEVICE == "cuda" else None,
+            model_repo,
+            torch_dtype=torch.float16,  # Use float16 for better compatibility
+            device_map="cpu",  # Force CPU for HF Spaces compatibility
             trust_remote_code=True,
-            low_cpu_mem_usage=True,
-            use_cache=True,  # Enable KV cache for faster generation
-            attn_implementation="flash_attention_2" if DEVICE == "cuda" else None  # Use flash attention if available
+            use_auth_token=False,
+            low_cpu_mem_usage=True
         )
 
-        # Optimize model for inference
-        model.eval()
-        if hasattr(model, 'gradient_checkpointing_disable'):
-            model.gradient_checkpointing_disable()
-
-        # Create pipeline with optimized settings
-        chat_pipeline = pipeline(
-            "text-generation",
-            model=model,
-            tokenizer=tokenizer,
-            device=0 if DEVICE == "cuda" else -1,
-            torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
-            model_kwargs={
-                "use_cache": True,
-                "do_sample": True,
-                "pad_token_id": tokenizer.eos_token_id
-            }
-        )
+        # Don't create pipeline - use direct model generation like the working version
+        chat_pipeline = None  # Set to None to indicate we're using direct generation
 
         logger.info("Model loaded successfully from HF repository!")
-        logger.info(f"Device: {DEVICE}")
+        logger.info(f"Model device: {next(model.parameters()).device}")
         logger.info(f"Model dtype: {next(model.parameters()).dtype}")
 
-        # Clear any initialization memory
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-
         return True
 
     except Exception as e:
-        logger.error(f"Failed to load model: {str(e)}")
-        # Clear memory on failure
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-        return False
+        logger.error(f"Error loading model from HF: {str(e)}")
+        # Try alternative loading method
+        try:
+            logger.info("Trying alternative loading method...")
+            tokenizer = AutoTokenizer.from_pretrained(
+                "Qwen/Qwen2-1.5B",  # Fallback to base model
+                trust_remote_code=True
+            )
+            model = AutoModelForCausalLM.from_pretrained(
+                "Qwen/Qwen2-1.5B",
+                torch_dtype=torch.float16,
+                device_map="cpu",
+                trust_remote_code=True,
+                low_cpu_mem_usage=True
+            )
+            chat_pipeline = None
+            logger.info("Fallback model loaded successfully!")
+            return True
+        except Exception as e2:
+            logger.error(f"Fallback loading also failed: {str(e2)}")
+            return False
 
 def generate_response(prompt, temperature=0.7):
-    """Generate response using the loaded model pipeline with optimizations"""
+    """Generate response using direct model generation (like the working app_gunicorn.py)"""
     try:
-        if not chat_pipeline:
+        if model is None or tokenizer is None:
             return "Model is still loading, please wait a moment and try again..."
 
-        # Truncate input prompt if too long to save memory
-        max_input_length = 800
-        if len(prompt) > max_input_length:
-            prompt = prompt[:max_input_length] + "..."
-
-        # Format the prompt efficiently
-        formatted_prompt = f"User: {prompt}\nAssistant:"
-
-        # Optimized generation parameters for speed and memory
-        response = chat_pipeline(
-            formatted_prompt,
-            max_new_tokens=128,  # Reduced for faster generation
-            temperature=temperature,
-            do_sample=True,
-            top_p=0.9,  # Nucleus sampling for better quality
-            top_k=50,  # Limit vocabulary for speed
-            pad_token_id=tokenizer.eos_token_id,
-            eos_token_id=tokenizer.eos_token_id,
-            truncation=True,
-            return_full_text=False,  # Only return new tokens
-            clean_up_tokenization_spaces=True
-        )
-
-        # Extract response efficiently
-        if response and len(response) > 0:
-            generated_text = response[0].get('generated_text', '')
-
-            # Clean up the response
-            if "Assistant:" in generated_text:
-                assistant_response = generated_text.split("Assistant:")[-1].strip()
-            else:
-                assistant_response = generated_text.replace(formatted_prompt, "").strip()
-
-            # Remove any remaining prompt artifacts
-            if assistant_response.startswith("User:"):
-                lines = assistant_response.split('\n')
-                assistant_response = '\n'.join([line for line in lines if not line.strip().startswith("User:")])
-
-            # Ensure response isn't empty
-            if not assistant_response.strip():
-                assistant_response = "I understand your question. Let me provide an economic analysis based on the available data."
-
-            return assistant_response.strip()
-        else:
-            return "I'm processing your request. Please try again in a moment."
-
-    except torch.cuda.OutOfMemoryError:
-        # Handle CUDA OOM gracefully
-        logger.error("CUDA out of memory during generation")
-        torch.cuda.empty_cache() if torch.cuda.is_available() else None
-        return "System is under high load. Please try a shorter question."
+        # Economics-focused system prompt (like the working version)
+        system_prompt = """You are AEGIS Economics AI, an expert economic analyst and policy advisor.
+Provide clear, accurate, and insightful responses about economics, finance, markets, and policy.
+Focus on practical analysis and actionable insights."""
+
+        full_prompt = f"{system_prompt}\n\nUser: {prompt}\nAssistant:"
+
+        # Tokenize input (like the working version)
+        inputs = tokenizer(full_prompt, return_tensors="pt", truncation=True, max_length=1024)
+
+        # Generate response (like the working version)
+        with torch.no_grad():
+            outputs = model.generate(
+                inputs.input_ids,
+                max_new_tokens=256,  # Same as working version
+                temperature=temperature,
+                do_sample=True,
+                pad_token_id=tokenizer.eos_token_id,
+                repetition_penalty=1.1,
+                no_repeat_ngram_size=3
+            )
+
+        # Decode response (like the working version)
+        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+        # Extract only the assistant's response (like the working version)
+        if "Assistant:" in response:
+            response = response.split("Assistant:")[-1].strip()
+
+        return response
 
     except Exception as e:
         logger.error(f"Error generating response: {str(e)}")
-        # Clear any potential memory issues
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-        return "I'm experiencing technical difficulties. Please try again shortly."
+        return "I apologize, but I'm having trouble processing your request right now. Please try again in a moment."
 
 # HTML template (same as before)
 HTML_TEMPLATE = """
@@ -1324,7 +1287,7 @@ def load_model_manual():
 
     return jsonify({
         'success': success,
-        'model_loaded': chat_pipeline is not None,
+        'model_loaded': model is not None,
         'tokenizer_loaded': tokenizer is not None,
         'message': 'Model loaded successfully' if success else 'Model loading failed'
     })
@@ -1360,7 +1323,7 @@ def health():
     """Health check endpoint"""
     return jsonify({
         'status': 'healthy',
-        'model_loaded': chat_pipeline is not None,
+        'model_loaded': model is not None,
         'tokenizer_loaded': tokenizer is not None,
         'langgraph_available': LANGGRAPH_AVAILABLE,
         'processing_mode': 'langgraph' if LANGGRAPH_AVAILABLE else 'simplified'
@@ -1395,11 +1358,11 @@ else:
     logger.info("Production mode: Loading model during module import...")
     logger.info(f"LangGraph available: {LANGGRAPH_AVAILABLE}")
 
-    # Load model immediately
-    logger.info("Loading model from Gaston895/Aegisecon1...")
+    # Try to load model, but don't fail if it doesn't work (like the working version)
+    logger.info("Attempting to load model...")
     model_loaded = load_model()
 
     if model_loaded:
         logger.info("✅ Model loaded successfully for production!")
     else:
-        logger.error(" Model failed to load in production mode!")
+        logger.warning("⚠️ Model failed to load, but server will start anyway. Model can be loaded via /load_model_manual endpoint.")