🚀 Memory and speed optimizations: faster generation, better memory management
Browse files
app.py
CHANGED
|
@@ -41,6 +41,18 @@ tokenizer = None
|
|
| 41 |
chat_pipeline = None
|
| 42 |
executor = ThreadPoolExecutor(max_workers=2)
|
| 43 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
@dataclass
|
| 45 |
class TechScores:
|
| 46 |
"""Technology threat scores structure"""
|
|
@@ -836,7 +848,7 @@ Focus on quantitative metrics and actionable insights."""
|
|
| 836 |
langgraph_processor = LangGraphProcessor()
|
| 837 |
|
| 838 |
def load_model():
|
| 839 |
-
"""Load the model and tokenizer from Gaston895/Aegisecon1 repository using pipeline approach"""
|
| 840 |
global model, tokenizer, chat_pipeline
|
| 841 |
|
| 842 |
try:
|
|
@@ -847,68 +859,129 @@ def load_model():
|
|
| 847 |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 848 |
|
| 849 |
logger.info(f"Loading tokenizer from {MODEL_NAME}...")
|
| 850 |
-
# Load tokenizer
|
| 851 |
-
tokenizer = AutoTokenizer.from_pretrained(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 852 |
|
| 853 |
logger.info(f"Loading model from {MODEL_NAME}...")
|
| 854 |
-
# Load model with
|
| 855 |
model = AutoModelForCausalLM.from_pretrained(
|
| 856 |
MODEL_NAME,
|
| 857 |
-
torch_dtype=torch.
|
| 858 |
device_map="auto" if DEVICE == "cuda" else None,
|
| 859 |
trust_remote_code=True,
|
| 860 |
-
low_cpu_mem_usage=True
|
|
|
|
|
|
|
| 861 |
)
|
| 862 |
|
| 863 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 864 |
chat_pipeline = pipeline(
|
| 865 |
"text-generation",
|
| 866 |
model=model,
|
| 867 |
tokenizer=tokenizer,
|
| 868 |
device=0 if DEVICE == "cuda" else -1,
|
| 869 |
-
torch_dtype=torch.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 870 |
)
|
| 871 |
|
| 872 |
logger.info("Model loaded successfully from HF repository!")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 873 |
return True
|
| 874 |
|
| 875 |
except Exception as e:
|
| 876 |
logger.error(f"Failed to load model: {str(e)}")
|
|
|
|
|
|
|
|
|
|
| 877 |
return False
|
| 878 |
|
| 879 |
def generate_response(prompt, temperature=0.7):
|
| 880 |
-
"""Generate response using the loaded model pipeline"""
|
| 881 |
try:
|
| 882 |
if not chat_pipeline:
|
| 883 |
return "Model is still loading, please wait a moment and try again..."
|
| 884 |
|
| 885 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 886 |
formatted_prompt = f"User: {prompt}\nAssistant:"
|
| 887 |
|
| 888 |
-
#
|
| 889 |
response = chat_pipeline(
|
| 890 |
formatted_prompt,
|
| 891 |
-
max_new_tokens=
|
| 892 |
temperature=temperature,
|
| 893 |
do_sample=True,
|
|
|
|
|
|
|
| 894 |
pad_token_id=tokenizer.eos_token_id,
|
| 895 |
-
|
|
|
|
|
|
|
|
|
|
| 896 |
)
|
| 897 |
|
| 898 |
-
# Extract
|
| 899 |
-
|
| 900 |
-
|
| 901 |
-
|
| 902 |
-
|
| 903 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 904 |
else:
|
| 905 |
-
|
| 906 |
|
| 907 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 908 |
|
| 909 |
except Exception as e:
|
| 910 |
logger.error(f"Error generating response: {str(e)}")
|
| 911 |
-
|
|
|
|
|
|
|
|
|
|
| 912 |
|
| 913 |
# HTML template (same as before)
|
| 914 |
HTML_TEMPLATE = """
|
|
@@ -1157,7 +1230,7 @@ def home():
|
|
| 1157 |
|
| 1158 |
@app.route('/process_tech_scores', methods=['POST'])
|
| 1159 |
def process_tech_scores():
|
| 1160 |
-
"""Process technology scores through LangGraph pipeline"""
|
| 1161 |
try:
|
| 1162 |
data = request.get_json()
|
| 1163 |
|
|
@@ -1174,8 +1247,16 @@ def process_tech_scores():
|
|
| 1174 |
|
| 1175 |
logger.info(f"Processing tech scores: {tech_scores.to_dict()}")
|
| 1176 |
|
| 1177 |
-
#
|
| 1178 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1179 |
|
| 1180 |
if not langgraph_result['success']:
|
| 1181 |
return jsonify({'success': False, 'error': 'LangGraph processing failed'})
|
|
@@ -1183,10 +1264,17 @@ def process_tech_scores():
|
|
| 1183 |
# Get the optimized prompt from LangGraph
|
| 1184 |
final_prompt = langgraph_result['final_prompt']
|
| 1185 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1186 |
# Generate final analysis using AEGIS Economics AI
|
| 1187 |
logger.info("Generating final analysis with AEGIS Economics AI...")
|
| 1188 |
final_analysis = generate_response(final_prompt)
|
| 1189 |
|
|
|
|
|
|
|
|
|
|
| 1190 |
return jsonify({
|
| 1191 |
'success': True,
|
| 1192 |
'processing_steps': langgraph_result.get('processing_steps', []),
|
|
@@ -1197,6 +1285,7 @@ def process_tech_scores():
|
|
| 1197 |
|
| 1198 |
except Exception as e:
|
| 1199 |
logger.error(f"Error in tech score processing: {str(e)}")
|
|
|
|
| 1200 |
return jsonify({'success': False, 'error': str(e)}), 500
|
| 1201 |
|
| 1202 |
@app.route('/chat', methods=['POST'])
|
|
|
|
| 41 |
chat_pipeline = None
|
| 42 |
executor = ThreadPoolExecutor(max_workers=2)
|
| 43 |
|
| 44 |
+
def cleanup_memory():
|
| 45 |
+
"""Clean up GPU/CPU memory"""
|
| 46 |
+
try:
|
| 47 |
+
if torch.cuda.is_available():
|
| 48 |
+
torch.cuda.empty_cache()
|
| 49 |
+
torch.cuda.synchronize()
|
| 50 |
+
# Force garbage collection
|
| 51 |
+
import gc
|
| 52 |
+
gc.collect()
|
| 53 |
+
except Exception as e:
|
| 54 |
+
logger.warning(f"Memory cleanup warning: {e}")
|
| 55 |
+
|
| 56 |
@dataclass
|
| 57 |
class TechScores:
|
| 58 |
"""Technology threat scores structure"""
|
|
|
|
| 848 |
langgraph_processor = LangGraphProcessor()
|
| 849 |
|
| 850 |
def load_model():
|
| 851 |
+
"""Load the model and tokenizer from Gaston895/Aegisecon1 repository using optimized pipeline approach"""
|
| 852 |
global model, tokenizer, chat_pipeline
|
| 853 |
|
| 854 |
try:
|
|
|
|
| 859 |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 860 |
|
| 861 |
logger.info(f"Loading tokenizer from {MODEL_NAME}...")
|
| 862 |
+
# Load tokenizer with optimizations
|
| 863 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
| 864 |
+
MODEL_NAME,
|
| 865 |
+
trust_remote_code=True,
|
| 866 |
+
use_fast=True # Use fast tokenizer for speed
|
| 867 |
+
)
|
| 868 |
+
|
| 869 |
+
# Set pad token if not exists
|
| 870 |
+
if tokenizer.pad_token is None:
|
| 871 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 872 |
|
| 873 |
logger.info(f"Loading model from {MODEL_NAME}...")
|
| 874 |
+
# Load model with memory-optimized settings
|
| 875 |
model = AutoModelForCausalLM.from_pretrained(
|
| 876 |
MODEL_NAME,
|
| 877 |
+
torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32, # Use float16 for memory efficiency
|
| 878 |
device_map="auto" if DEVICE == "cuda" else None,
|
| 879 |
trust_remote_code=True,
|
| 880 |
+
low_cpu_mem_usage=True,
|
| 881 |
+
use_cache=True, # Enable KV cache for faster generation
|
| 882 |
+
attn_implementation="flash_attention_2" if DEVICE == "cuda" else None # Use flash attention if available
|
| 883 |
)
|
| 884 |
|
| 885 |
+
# Optimize model for inference
|
| 886 |
+
model.eval()
|
| 887 |
+
if hasattr(model, 'gradient_checkpointing_disable'):
|
| 888 |
+
model.gradient_checkpointing_disable()
|
| 889 |
+
|
| 890 |
+
# Create pipeline with optimized settings
|
| 891 |
chat_pipeline = pipeline(
|
| 892 |
"text-generation",
|
| 893 |
model=model,
|
| 894 |
tokenizer=tokenizer,
|
| 895 |
device=0 if DEVICE == "cuda" else -1,
|
| 896 |
+
torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
|
| 897 |
+
model_kwargs={
|
| 898 |
+
"use_cache": True,
|
| 899 |
+
"do_sample": True,
|
| 900 |
+
"pad_token_id": tokenizer.eos_token_id
|
| 901 |
+
}
|
| 902 |
)
|
| 903 |
|
| 904 |
logger.info("Model loaded successfully from HF repository!")
|
| 905 |
+
logger.info(f"Device: {DEVICE}")
|
| 906 |
+
logger.info(f"Model dtype: {next(model.parameters()).dtype}")
|
| 907 |
+
|
| 908 |
+
# Clear any initialization memory
|
| 909 |
+
if torch.cuda.is_available():
|
| 910 |
+
torch.cuda.empty_cache()
|
| 911 |
+
|
| 912 |
return True
|
| 913 |
|
| 914 |
except Exception as e:
|
| 915 |
logger.error(f"Failed to load model: {str(e)}")
|
| 916 |
+
# Clear memory on failure
|
| 917 |
+
if torch.cuda.is_available():
|
| 918 |
+
torch.cuda.empty_cache()
|
| 919 |
return False
|
| 920 |
|
| 921 |
def generate_response(prompt, temperature=0.7):
|
| 922 |
+
"""Generate response using the loaded model pipeline with optimizations"""
|
| 923 |
try:
|
| 924 |
if not chat_pipeline:
|
| 925 |
return "Model is still loading, please wait a moment and try again..."
|
| 926 |
|
| 927 |
+
# Truncate input prompt if too long to save memory
|
| 928 |
+
max_input_length = 800
|
| 929 |
+
if len(prompt) > max_input_length:
|
| 930 |
+
prompt = prompt[:max_input_length] + "..."
|
| 931 |
+
|
| 932 |
+
# Format the prompt efficiently
|
| 933 |
formatted_prompt = f"User: {prompt}\nAssistant:"
|
| 934 |
|
| 935 |
+
# Optimized generation parameters for speed and memory
|
| 936 |
response = chat_pipeline(
|
| 937 |
formatted_prompt,
|
| 938 |
+
max_new_tokens=128, # Reduced for faster generation
|
| 939 |
temperature=temperature,
|
| 940 |
do_sample=True,
|
| 941 |
+
top_p=0.9, # Nucleus sampling for better quality
|
| 942 |
+
top_k=50, # Limit vocabulary for speed
|
| 943 |
pad_token_id=tokenizer.eos_token_id,
|
| 944 |
+
eos_token_id=tokenizer.eos_token_id,
|
| 945 |
+
truncation=True,
|
| 946 |
+
return_full_text=False, # Only return new tokens
|
| 947 |
+
clean_up_tokenization_spaces=True
|
| 948 |
)
|
| 949 |
|
| 950 |
+
# Extract response efficiently
|
| 951 |
+
if response and len(response) > 0:
|
| 952 |
+
generated_text = response[0].get('generated_text', '')
|
| 953 |
+
|
| 954 |
+
# Clean up the response
|
| 955 |
+
if "Assistant:" in generated_text:
|
| 956 |
+
assistant_response = generated_text.split("Assistant:")[-1].strip()
|
| 957 |
+
else:
|
| 958 |
+
assistant_response = generated_text.replace(formatted_prompt, "").strip()
|
| 959 |
+
|
| 960 |
+
# Remove any remaining prompt artifacts
|
| 961 |
+
if assistant_response.startswith("User:"):
|
| 962 |
+
lines = assistant_response.split('\n')
|
| 963 |
+
assistant_response = '\n'.join([line for line in lines if not line.strip().startswith("User:")])
|
| 964 |
+
|
| 965 |
+
# Ensure response isn't empty
|
| 966 |
+
if not assistant_response.strip():
|
| 967 |
+
assistant_response = "I understand your question. Let me provide an economic analysis based on the available data."
|
| 968 |
+
|
| 969 |
+
return assistant_response.strip()
|
| 970 |
else:
|
| 971 |
+
return "I'm processing your request. Please try again in a moment."
|
| 972 |
|
| 973 |
+
except torch.cuda.OutOfMemoryError:
|
| 974 |
+
# Handle CUDA OOM gracefully
|
| 975 |
+
logger.error("CUDA out of memory during generation")
|
| 976 |
+
torch.cuda.empty_cache() if torch.cuda.is_available() else None
|
| 977 |
+
return "System is under high load. Please try a shorter question."
|
| 978 |
|
| 979 |
except Exception as e:
|
| 980 |
logger.error(f"Error generating response: {str(e)}")
|
| 981 |
+
# Clear any potential memory issues
|
| 982 |
+
if torch.cuda.is_available():
|
| 983 |
+
torch.cuda.empty_cache()
|
| 984 |
+
return "I'm experiencing technical difficulties. Please try again shortly."
|
| 985 |
|
| 986 |
# HTML template (same as before)
|
| 987 |
HTML_TEMPLATE = """
|
|
|
|
| 1230 |
|
| 1231 |
@app.route('/process_tech_scores', methods=['POST'])
|
| 1232 |
def process_tech_scores():
|
| 1233 |
+
"""Process technology scores through LangGraph pipeline with memory optimization"""
|
| 1234 |
try:
|
| 1235 |
data = request.get_json()
|
| 1236 |
|
|
|
|
| 1247 |
|
| 1248 |
logger.info(f"Processing tech scores: {tech_scores.to_dict()}")
|
| 1249 |
|
| 1250 |
+
# Clean memory before processing
|
| 1251 |
+
cleanup_memory()
|
| 1252 |
+
|
| 1253 |
+
# Process through LangGraph with timeout
|
| 1254 |
+
try:
|
| 1255 |
+
langgraph_result = langgraph_processor.process_tech_scores(tech_scores)
|
| 1256 |
+
except Exception as e:
|
| 1257 |
+
logger.error(f"LangGraph processing failed: {e}")
|
| 1258 |
+
# Fallback to simplified processing
|
| 1259 |
+
langgraph_result = langgraph_processor._simplified_processing(tech_scores)
|
| 1260 |
|
| 1261 |
if not langgraph_result['success']:
|
| 1262 |
return jsonify({'success': False, 'error': 'LangGraph processing failed'})
|
|
|
|
| 1264 |
# Get the optimized prompt from LangGraph
|
| 1265 |
final_prompt = langgraph_result['final_prompt']
|
| 1266 |
|
| 1267 |
+
# Truncate prompt if too long to save memory
|
| 1268 |
+
if len(final_prompt) > 1000:
|
| 1269 |
+
final_prompt = final_prompt[:1000] + "... [truncated for efficiency]"
|
| 1270 |
+
|
| 1271 |
# Generate final analysis using AEGIS Economics AI
|
| 1272 |
logger.info("Generating final analysis with AEGIS Economics AI...")
|
| 1273 |
final_analysis = generate_response(final_prompt)
|
| 1274 |
|
| 1275 |
+
# Clean memory after processing
|
| 1276 |
+
cleanup_memory()
|
| 1277 |
+
|
| 1278 |
return jsonify({
|
| 1279 |
'success': True,
|
| 1280 |
'processing_steps': langgraph_result.get('processing_steps', []),
|
|
|
|
| 1285 |
|
| 1286 |
except Exception as e:
|
| 1287 |
logger.error(f"Error in tech score processing: {str(e)}")
|
| 1288 |
+
cleanup_memory() # Clean memory on error
|
| 1289 |
return jsonify({'success': False, 'error': str(e)}), 500
|
| 1290 |
|
| 1291 |
@app.route('/chat', methods=['POST'])
|