Gaston895 committed on
Commit
756e842
·
verified ·
1 Parent(s): 77bf462

🚀 Memory and speed optimizations: faster generation, better memory management

Browse files
Files changed (1) hide show
  1. app.py +114 -25
app.py CHANGED
@@ -41,6 +41,18 @@ tokenizer = None
41
  chat_pipeline = None
42
  executor = ThreadPoolExecutor(max_workers=2)
43
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  @dataclass
45
  class TechScores:
46
  """Technology threat scores structure"""
@@ -836,7 +848,7 @@ Focus on quantitative metrics and actionable insights."""
836
  langgraph_processor = LangGraphProcessor()
837
 
838
  def load_model():
839
- """Load the model and tokenizer from Gaston895/Aegisecon1 repository using pipeline approach"""
840
  global model, tokenizer, chat_pipeline
841
 
842
  try:
@@ -847,68 +859,129 @@ def load_model():
847
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
848
 
849
  logger.info(f"Loading tokenizer from {MODEL_NAME}...")
850
- # Load tokenizer
851
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
 
 
 
 
 
 
 
 
852
 
853
  logger.info(f"Loading model from {MODEL_NAME}...")
854
- # Load model with appropriate settings
855
  model = AutoModelForCausalLM.from_pretrained(
856
  MODEL_NAME,
857
- torch_dtype=torch.bfloat16 if DEVICE == "cuda" else torch.float32,
858
  device_map="auto" if DEVICE == "cuda" else None,
859
  trust_remote_code=True,
860
- low_cpu_mem_usage=True
 
 
861
  )
862
 
863
- # Create pipeline
 
 
 
 
 
864
  chat_pipeline = pipeline(
865
  "text-generation",
866
  model=model,
867
  tokenizer=tokenizer,
868
  device=0 if DEVICE == "cuda" else -1,
869
- torch_dtype=torch.bfloat16 if DEVICE == "cuda" else torch.float32
 
 
 
 
 
870
  )
871
 
872
  logger.info("Model loaded successfully from HF repository!")
 
 
 
 
 
 
 
873
  return True
874
 
875
  except Exception as e:
876
  logger.error(f"Failed to load model: {str(e)}")
 
 
 
877
  return False
878
 
879
  def generate_response(prompt, temperature=0.7):
880
- """Generate response using the loaded model pipeline"""
881
  try:
882
  if not chat_pipeline:
883
  return "Model is still loading, please wait a moment and try again..."
884
 
885
- # Format the prompt
 
 
 
 
 
886
  formatted_prompt = f"User: {prompt}\nAssistant:"
887
 
888
- # Generate response
889
  response = chat_pipeline(
890
  formatted_prompt,
891
- max_new_tokens=256, # Use only max_new_tokens to avoid conflict
892
  temperature=temperature,
893
  do_sample=True,
 
 
894
  pad_token_id=tokenizer.eos_token_id,
895
- truncation=True
 
 
 
896
  )
897
 
898
- # Extract the generated text
899
- generated_text = response[0]['generated_text']
900
-
901
- # Extract only the assistant's response
902
- if "Assistant:" in generated_text:
903
- assistant_response = generated_text.split("Assistant:")[-1].strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
904
  else:
905
- assistant_response = generated_text.replace(formatted_prompt, "").strip()
906
 
907
- return assistant_response
 
 
 
 
908
 
909
  except Exception as e:
910
  logger.error(f"Error generating response: {str(e)}")
911
- return f"Error: {str(e)}"
 
 
 
912
 
913
  # HTML template (same as before)
914
  HTML_TEMPLATE = """
@@ -1157,7 +1230,7 @@ def home():
1157
 
1158
  @app.route('/process_tech_scores', methods=['POST'])
1159
  def process_tech_scores():
1160
- """Process technology scores through LangGraph pipeline"""
1161
  try:
1162
  data = request.get_json()
1163
 
@@ -1174,8 +1247,16 @@ def process_tech_scores():
1174
 
1175
  logger.info(f"Processing tech scores: {tech_scores.to_dict()}")
1176
 
1177
- # Process through LangGraph
1178
- langgraph_result = langgraph_processor.process_tech_scores(tech_scores)
 
 
 
 
 
 
 
 
1179
 
1180
  if not langgraph_result['success']:
1181
  return jsonify({'success': False, 'error': 'LangGraph processing failed'})
@@ -1183,10 +1264,17 @@ def process_tech_scores():
1183
  # Get the optimized prompt from LangGraph
1184
  final_prompt = langgraph_result['final_prompt']
1185
 
 
 
 
 
1186
  # Generate final analysis using AEGIS Economics AI
1187
  logger.info("Generating final analysis with AEGIS Economics AI...")
1188
  final_analysis = generate_response(final_prompt)
1189
 
 
 
 
1190
  return jsonify({
1191
  'success': True,
1192
  'processing_steps': langgraph_result.get('processing_steps', []),
@@ -1197,6 +1285,7 @@ def process_tech_scores():
1197
 
1198
  except Exception as e:
1199
  logger.error(f"Error in tech score processing: {str(e)}")
 
1200
  return jsonify({'success': False, 'error': str(e)}), 500
1201
 
1202
  @app.route('/chat', methods=['POST'])
 
41
  chat_pipeline = None
42
  executor = ThreadPoolExecutor(max_workers=2)
43
 
44
def cleanup_memory():
    """Release cached GPU memory and trigger Python garbage collection.

    Best-effort: any failure is logged as a warning rather than raised,
    because memory cleanup must never break the request path that calls it.

    Returns:
        None.
    """
    try:
        # Collect Python garbage FIRST so tensors kept alive only by
        # collectable reference cycles are actually freed before we ask
        # the CUDA caching allocator to release its cached blocks.
        # (The original collected after empty_cache, which left those
        # blocks still allocated when the cache was flushed.)
        import gc
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            # Wait for outstanding kernels so freed memory is reusable.
            torch.cuda.synchronize()
    except Exception as e:
        logger.warning(f"Memory cleanup warning: {e}")
56
  @dataclass
57
  class TechScores:
58
  """Technology threat scores structure"""
 
848
  langgraph_processor = LangGraphProcessor()
849
 
850
  def load_model():
851
+ """Load the model and tokenizer from Gaston895/Aegisecon1 repository using optimized pipeline approach"""
852
  global model, tokenizer, chat_pipeline
853
 
854
  try:
 
859
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
860
 
861
  logger.info(f"Loading tokenizer from {MODEL_NAME}...")
862
+ # Load tokenizer with optimizations
863
+ tokenizer = AutoTokenizer.from_pretrained(
864
+ MODEL_NAME,
865
+ trust_remote_code=True,
866
+ use_fast=True # Use fast tokenizer for speed
867
+ )
868
+
869
+ # Set pad token if not exists
870
+ if tokenizer.pad_token is None:
871
+ tokenizer.pad_token = tokenizer.eos_token
872
 
873
  logger.info(f"Loading model from {MODEL_NAME}...")
874
+ # Load model with memory-optimized settings
875
  model = AutoModelForCausalLM.from_pretrained(
876
  MODEL_NAME,
877
+ torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32, # Use float16 for memory efficiency
878
  device_map="auto" if DEVICE == "cuda" else None,
879
  trust_remote_code=True,
880
+ low_cpu_mem_usage=True,
881
+ use_cache=True, # Enable KV cache for faster generation
882
+ attn_implementation="flash_attention_2" if DEVICE == "cuda" else None # Use flash attention if available
883
  )
884
 
885
+ # Optimize model for inference
886
+ model.eval()
887
+ if hasattr(model, 'gradient_checkpointing_disable'):
888
+ model.gradient_checkpointing_disable()
889
+
890
+ # Create pipeline with optimized settings
891
  chat_pipeline = pipeline(
892
  "text-generation",
893
  model=model,
894
  tokenizer=tokenizer,
895
  device=0 if DEVICE == "cuda" else -1,
896
+ torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
897
+ model_kwargs={
898
+ "use_cache": True,
899
+ "do_sample": True,
900
+ "pad_token_id": tokenizer.eos_token_id
901
+ }
902
  )
903
 
904
  logger.info("Model loaded successfully from HF repository!")
905
+ logger.info(f"Device: {DEVICE}")
906
+ logger.info(f"Model dtype: {next(model.parameters()).dtype}")
907
+
908
+ # Clear any initialization memory
909
+ if torch.cuda.is_available():
910
+ torch.cuda.empty_cache()
911
+
912
  return True
913
 
914
  except Exception as e:
915
  logger.error(f"Failed to load model: {str(e)}")
916
+ # Clear memory on failure
917
+ if torch.cuda.is_available():
918
+ torch.cuda.empty_cache()
919
  return False
920
 
921
def generate_response(prompt, temperature=0.7):
    """Generate a chat completion for *prompt* with the loaded pipeline.

    Args:
        prompt: User text. Truncated to 800 characters to bound memory
            use during generation.
        temperature: Sampling temperature forwarded to the pipeline.

    Returns:
        The assistant's reply as a string, or a human-readable fallback
        message when the model is not ready or generation fails. Never
        raises: all errors are converted to fallback strings.
    """
    try:
        if not chat_pipeline:
            return "Model is still loading, please wait a moment and try again..."

        # Bound the input size so very long prompts cannot blow up memory.
        max_input_length = 800
        if len(prompt) > max_input_length:
            prompt = prompt[:max_input_length] + "..."

        formatted_prompt = f"User: {prompt}\nAssistant:"

        # Generation tuned for latency/memory: short outputs, nucleus +
        # top-k sampling, and return_full_text=False so only the newly
        # generated tokens come back (the prompt is not echoed).
        response = chat_pipeline(
            formatted_prompt,
            max_new_tokens=128,  # reduced for faster generation
            temperature=temperature,
            do_sample=True,
            top_p=0.9,
            top_k=50,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            truncation=True,
            return_full_text=False,
            clean_up_tokenization_spaces=True,
        )

        if not response:
            return "I'm processing your request. Please try again in a moment."

        generated_text = response[0].get('generated_text', '')

        # Defensive cleanup: with return_full_text=False the prompt should
        # not appear in the output, but strip a leading "Assistant:" or an
        # echoed prompt if the model emits one anyway.
        if "Assistant:" in generated_text:
            assistant_response = generated_text.split("Assistant:")[-1].strip()
        else:
            assistant_response = generated_text.replace(formatted_prompt, "").strip()

        # Drop any hallucinated follow-up "User:" turns.
        if assistant_response.startswith("User:"):
            lines = assistant_response.split('\n')
            assistant_response = '\n'.join(
                line for line in lines if not line.strip().startswith("User:")
            )

        # Never hand an empty reply back to the UI.
        if not assistant_response.strip():
            assistant_response = "I understand your question. Let me provide an economic analysis based on the available data."

        return assistant_response.strip()

    except torch.cuda.OutOfMemoryError:
        logger.error("CUDA out of memory during generation")
        # Fix: the original used a conditional *expression* as a statement
        # (`empty_cache() if cuda else None`); use a plain guard, matching
        # the generic handler below.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return "System is under high load. Please try a shorter question."

    except Exception as e:
        logger.error(f"Error generating response: {str(e)}")
        # Clear any potential memory left behind by the failed call.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return "I'm experiencing technical difficulties. Please try again shortly."
985
 
986
  # HTML template (same as before)
987
  HTML_TEMPLATE = """
 
1230
 
1231
  @app.route('/process_tech_scores', methods=['POST'])
1232
  def process_tech_scores():
1233
+ """Process technology scores through LangGraph pipeline with memory optimization"""
1234
  try:
1235
  data = request.get_json()
1236
 
 
1247
 
1248
  logger.info(f"Processing tech scores: {tech_scores.to_dict()}")
1249
 
1250
+ # Clean memory before processing
1251
+ cleanup_memory()
1252
+
1253
+ # Process through LangGraph with timeout
1254
+ try:
1255
+ langgraph_result = langgraph_processor.process_tech_scores(tech_scores)
1256
+ except Exception as e:
1257
+ logger.error(f"LangGraph processing failed: {e}")
1258
+ # Fallback to simplified processing
1259
+ langgraph_result = langgraph_processor._simplified_processing(tech_scores)
1260
 
1261
  if not langgraph_result['success']:
1262
  return jsonify({'success': False, 'error': 'LangGraph processing failed'})
 
1264
  # Get the optimized prompt from LangGraph
1265
  final_prompt = langgraph_result['final_prompt']
1266
 
1267
+ # Truncate prompt if too long to save memory
1268
+ if len(final_prompt) > 1000:
1269
+ final_prompt = final_prompt[:1000] + "... [truncated for efficiency]"
1270
+
1271
  # Generate final analysis using AEGIS Economics AI
1272
  logger.info("Generating final analysis with AEGIS Economics AI...")
1273
  final_analysis = generate_response(final_prompt)
1274
 
1275
+ # Clean memory after processing
1276
+ cleanup_memory()
1277
+
1278
  return jsonify({
1279
  'success': True,
1280
  'processing_steps': langgraph_result.get('processing_steps', []),
 
1285
 
1286
  except Exception as e:
1287
  logger.error(f"Error in tech score processing: {str(e)}")
1288
+ cleanup_memory() # Clean memory on error
1289
  return jsonify({'success': False, 'error': str(e)}), 500
1290
 
1291
  @app.route('/chat', methods=['POST'])