🔧 Fix model loading using proven app_gunicorn.py approach - no pipeline, direct generation
Browse files
app.py
CHANGED
|
@@ -7,7 +7,7 @@ CPU-optimized version for Modal deployment
|
|
| 7 |
|
| 8 |
from flask import Flask, request, jsonify, render_template_string
|
| 9 |
import torch
|
| 10 |
-
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 11 |
import os
|
| 12 |
import logging
|
| 13 |
import json
|
|
@@ -38,7 +38,6 @@ app = Flask(__name__)
|
|
| 38 |
# Global variables
|
| 39 |
model = None
|
| 40 |
tokenizer = None
|
| 41 |
-
chat_pipeline = None
|
| 42 |
executor = ThreadPoolExecutor(max_workers=2)
|
| 43 |
|
| 44 |
def cleanup_memory():
|
|
@@ -848,140 +847,104 @@ Focus on quantitative metrics and actionable insights."""
|
|
| 848 |
langgraph_processor = LangGraphProcessor()
|
| 849 |
|
| 850 |
def load_model():
|
| 851 |
-
"""Load the model and tokenizer from Gaston895/Aegisecon1 repository using
|
| 852 |
global model, tokenizer, chat_pipeline
|
| 853 |
|
| 854 |
try:
|
| 855 |
logger.info("Loading model and tokenizer from Hugging Face...")
|
| 856 |
|
| 857 |
-
#
|
| 858 |
-
|
| 859 |
-
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 860 |
|
| 861 |
-
logger.info(f"Loading tokenizer from {
|
| 862 |
-
# Load tokenizer with optimizations
|
| 863 |
tokenizer = AutoTokenizer.from_pretrained(
|
| 864 |
-
|
| 865 |
trust_remote_code=True,
|
| 866 |
-
|
| 867 |
)
|
| 868 |
|
| 869 |
-
|
| 870 |
-
if tokenizer.pad_token is None:
|
| 871 |
-
tokenizer.pad_token = tokenizer.eos_token
|
| 872 |
-
|
| 873 |
-
logger.info(f"Loading model from {MODEL_NAME}...")
|
| 874 |
-
# Load model with memory-optimized settings
|
| 875 |
model = AutoModelForCausalLM.from_pretrained(
|
| 876 |
-
|
| 877 |
-
torch_dtype=torch.float16
|
| 878 |
-
device_map="
|
| 879 |
trust_remote_code=True,
|
| 880 |
-
|
| 881 |
-
|
| 882 |
-
attn_implementation="flash_attention_2" if DEVICE == "cuda" else None # Use flash attention if available
|
| 883 |
)
|
| 884 |
|
| 885 |
-
#
|
| 886 |
-
|
| 887 |
-
if hasattr(model, 'gradient_checkpointing_disable'):
|
| 888 |
-
model.gradient_checkpointing_disable()
|
| 889 |
-
|
| 890 |
-
# Create pipeline with optimized settings
|
| 891 |
-
chat_pipeline = pipeline(
|
| 892 |
-
"text-generation",
|
| 893 |
-
model=model,
|
| 894 |
-
tokenizer=tokenizer,
|
| 895 |
-
device=0 if DEVICE == "cuda" else -1,
|
| 896 |
-
torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
|
| 897 |
-
model_kwargs={
|
| 898 |
-
"use_cache": True,
|
| 899 |
-
"do_sample": True,
|
| 900 |
-
"pad_token_id": tokenizer.eos_token_id
|
| 901 |
-
}
|
| 902 |
-
)
|
| 903 |
|
| 904 |
logger.info("Model loaded successfully from HF repository!")
|
| 905 |
-
logger.info(f"
|
| 906 |
logger.info(f"Model dtype: {next(model.parameters()).dtype}")
|
| 907 |
|
| 908 |
-
# Clear any initialization memory
|
| 909 |
-
if torch.cuda.is_available():
|
| 910 |
-
torch.cuda.empty_cache()
|
| 911 |
-
|
| 912 |
return True
|
| 913 |
|
| 914 |
except Exception as e:
|
| 915 |
-
logger.error(f"
|
| 916 |
-
#
|
| 917 |
-
|
| 918 |
-
|
| 919 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 920 |
|
| 921 |
def generate_response(prompt, temperature=0.7):
|
| 922 |
-
"""Generate response using
|
| 923 |
try:
|
| 924 |
-
if
|
| 925 |
return "Model is still loading, please wait a moment and try again..."
|
| 926 |
|
| 927 |
-
#
|
| 928 |
-
|
| 929 |
-
|
| 930 |
-
|
| 931 |
-
|
| 932 |
-
# Format the prompt efficiently
|
| 933 |
-
formatted_prompt = f"User: {prompt}\nAssistant:"
|
| 934 |
-
|
| 935 |
-
# Optimized generation parameters for speed and memory
|
| 936 |
-
response = chat_pipeline(
|
| 937 |
-
formatted_prompt,
|
| 938 |
-
max_new_tokens=128, # Reduced for faster generation
|
| 939 |
-
temperature=temperature,
|
| 940 |
-
do_sample=True,
|
| 941 |
-
top_p=0.9, # Nucleus sampling for better quality
|
| 942 |
-
top_k=50, # Limit vocabulary for speed
|
| 943 |
-
pad_token_id=tokenizer.eos_token_id,
|
| 944 |
-
eos_token_id=tokenizer.eos_token_id,
|
| 945 |
-
truncation=True,
|
| 946 |
-
return_full_text=False, # Only return new tokens
|
| 947 |
-
clean_up_tokenization_spaces=True
|
| 948 |
-
)
|
| 949 |
|
| 950 |
-
|
| 951 |
-
|
| 952 |
-
|
| 953 |
-
|
| 954 |
-
|
| 955 |
-
|
| 956 |
-
|
| 957 |
-
|
| 958 |
-
|
| 959 |
-
|
| 960 |
-
|
| 961 |
-
|
| 962 |
-
|
| 963 |
-
|
| 964 |
-
|
| 965 |
-
|
| 966 |
-
if not assistant_response.strip():
|
| 967 |
-
assistant_response = "I understand your question. Let me provide an economic analysis based on the available data."
|
| 968 |
-
|
| 969 |
-
return assistant_response.strip()
|
| 970 |
-
else:
|
| 971 |
-
return "I'm processing your request. Please try again in a moment."
|
| 972 |
|
| 973 |
-
|
| 974 |
-
|
| 975 |
-
|
| 976 |
-
|
| 977 |
-
|
|
|
|
|
|
|
|
|
|
| 978 |
|
| 979 |
except Exception as e:
|
| 980 |
logger.error(f"Error generating response: {str(e)}")
|
| 981 |
-
|
| 982 |
-
if torch.cuda.is_available():
|
| 983 |
-
torch.cuda.empty_cache()
|
| 984 |
-
return "I'm experiencing technical difficulties. Please try again shortly."
|
| 985 |
|
| 986 |
# HTML template (same as before)
|
| 987 |
HTML_TEMPLATE = """
|
|
@@ -1324,7 +1287,7 @@ def load_model_manual():
|
|
| 1324 |
|
| 1325 |
return jsonify({
|
| 1326 |
'success': success,
|
| 1327 |
-
'model_loaded':
|
| 1328 |
'tokenizer_loaded': tokenizer is not None,
|
| 1329 |
'message': 'Model loaded successfully' if success else 'Model loading failed'
|
| 1330 |
})
|
|
@@ -1360,7 +1323,7 @@ def health():
|
|
| 1360 |
"""Health check endpoint"""
|
| 1361 |
return jsonify({
|
| 1362 |
'status': 'healthy',
|
| 1363 |
-
'model_loaded':
|
| 1364 |
'tokenizer_loaded': tokenizer is not None,
|
| 1365 |
'langgraph_available': LANGGRAPH_AVAILABLE,
|
| 1366 |
'processing_mode': 'langgraph' if LANGGRAPH_AVAILABLE else 'simplified'
|
|
@@ -1395,11 +1358,11 @@ else:
|
|
| 1395 |
logger.info("Production mode: Loading model during module import...")
|
| 1396 |
logger.info(f"LangGraph available: {LANGGRAPH_AVAILABLE}")
|
| 1397 |
|
| 1398 |
-
#
|
| 1399 |
-
logger.info("
|
| 1400 |
model_loaded = load_model()
|
| 1401 |
|
| 1402 |
if model_loaded:
|
| 1403 |
logger.info("✅ Model loaded successfully for production!")
|
| 1404 |
else:
|
| 1405 |
-
logger.
|
|
|
|
| 7 |
|
| 8 |
from flask import Flask, request, jsonify, render_template_string
|
| 9 |
import torch
|
| 10 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 11 |
import os
|
| 12 |
import logging
|
| 13 |
import json
|
|
|
|
| 38 |
# Global variables
|
| 39 |
model = None
|
| 40 |
tokenizer = None
|
|
|
|
| 41 |
executor = ThreadPoolExecutor(max_workers=2)
|
| 42 |
|
| 43 |
def cleanup_memory():
|
|
|
|
| 847 |
langgraph_processor = LangGraphProcessor()
|
| 848 |
|
| 849 |
def load_model():
|
| 850 |
+
"""Load the model and tokenizer from Gaston895/Aegisecon1 repository using the working approach"""
|
| 851 |
global model, tokenizer, chat_pipeline
|
| 852 |
|
| 853 |
try:
|
| 854 |
logger.info("Loading model and tokenizer from Hugging Face...")
|
| 855 |
|
| 856 |
+
# Load from the deployed model repository
|
| 857 |
+
model_repo = "Gaston895/Aegisecon1"
|
|
|
|
| 858 |
|
| 859 |
+
logger.info(f"Loading tokenizer from {model_repo}...")
|
|
|
|
| 860 |
tokenizer = AutoTokenizer.from_pretrained(
|
| 861 |
+
model_repo,
|
| 862 |
trust_remote_code=True,
|
| 863 |
+
use_auth_token=False
|
| 864 |
)
|
| 865 |
|
| 866 |
+
logger.info(f"Loading model from {model_repo}...")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 867 |
model = AutoModelForCausalLM.from_pretrained(
|
| 868 |
+
model_repo,
|
| 869 |
+
torch_dtype=torch.float16, # Use float16 for better compatibility
|
| 870 |
+
device_map="cpu", # Force CPU for HF Spaces compatibility
|
| 871 |
trust_remote_code=True,
|
| 872 |
+
use_auth_token=False,
|
| 873 |
+
low_cpu_mem_usage=True
|
|
|
|
| 874 |
)
|
| 875 |
|
| 876 |
+
# Don't create pipeline - use direct model generation like the working version
|
| 877 |
+
chat_pipeline = None # Set to None to indicate we're using direct generation
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 878 |
|
| 879 |
logger.info("Model loaded successfully from HF repository!")
|
| 880 |
+
logger.info(f"Model device: {next(model.parameters()).device}")
|
| 881 |
logger.info(f"Model dtype: {next(model.parameters()).dtype}")
|
| 882 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 883 |
return True
|
| 884 |
|
| 885 |
except Exception as e:
|
| 886 |
+
logger.error(f"Error loading model from HF: {str(e)}")
|
| 887 |
+
# Try alternative loading method
|
| 888 |
+
try:
|
| 889 |
+
logger.info("Trying alternative loading method...")
|
| 890 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
| 891 |
+
"Qwen/Qwen2-1.5B", # Fallback to base model
|
| 892 |
+
trust_remote_code=True
|
| 893 |
+
)
|
| 894 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 895 |
+
"Qwen/Qwen2-1.5B",
|
| 896 |
+
torch_dtype=torch.float16,
|
| 897 |
+
device_map="cpu",
|
| 898 |
+
trust_remote_code=True,
|
| 899 |
+
low_cpu_mem_usage=True
|
| 900 |
+
)
|
| 901 |
+
chat_pipeline = None
|
| 902 |
+
logger.info("Fallback model loaded successfully!")
|
| 903 |
+
return True
|
| 904 |
+
except Exception as e2:
|
| 905 |
+
logger.error(f"Fallback loading also failed: {str(e2)}")
|
| 906 |
+
return False
|
| 907 |
|
| 908 |
def generate_response(prompt, temperature=0.7):
|
| 909 |
+
"""Generate response using direct model generation (like the working app_gunicorn.py)"""
|
| 910 |
try:
|
| 911 |
+
if model is None or tokenizer is None:
|
| 912 |
return "Model is still loading, please wait a moment and try again..."
|
| 913 |
|
| 914 |
+
# Economics-focused system prompt (like the working version)
|
| 915 |
+
system_prompt = """You are AEGIS Economics AI, an expert economic analyst and policy advisor.
|
| 916 |
+
Provide clear, accurate, and insightful responses about economics, finance, markets, and policy.
|
| 917 |
+
Focus on practical analysis and actionable insights."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 918 |
|
| 919 |
+
full_prompt = f"{system_prompt}\n\nUser: {prompt}\nAssistant:"
|
| 920 |
+
|
| 921 |
+
# Tokenize input (like the working version)
|
| 922 |
+
inputs = tokenizer(full_prompt, return_tensors="pt", truncation=True, max_length=1024)
|
| 923 |
+
|
| 924 |
+
# Generate response (like the working version)
|
| 925 |
+
with torch.no_grad():
|
| 926 |
+
outputs = model.generate(
|
| 927 |
+
inputs.input_ids,
|
| 928 |
+
max_new_tokens=256, # Same as working version
|
| 929 |
+
temperature=temperature,
|
| 930 |
+
do_sample=True,
|
| 931 |
+
pad_token_id=tokenizer.eos_token_id,
|
| 932 |
+
repetition_penalty=1.1,
|
| 933 |
+
no_repeat_ngram_size=3
|
| 934 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 935 |
|
| 936 |
+
# Decode response (like the working version)
|
| 937 |
+
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
| 938 |
+
|
| 939 |
+
# Extract only the assistant's response (like the working version)
|
| 940 |
+
if "Assistant:" in response:
|
| 941 |
+
response = response.split("Assistant:")[-1].strip()
|
| 942 |
+
|
| 943 |
+
return response
|
| 944 |
|
| 945 |
except Exception as e:
|
| 946 |
logger.error(f"Error generating response: {str(e)}")
|
| 947 |
+
return "I apologize, but I'm having trouble processing your request right now. Please try again in a moment."
|
|
|
|
|
|
|
|
|
|
| 948 |
|
| 949 |
# HTML template (same as before)
|
| 950 |
HTML_TEMPLATE = """
|
|
|
|
| 1287 |
|
| 1288 |
return jsonify({
|
| 1289 |
'success': success,
|
| 1290 |
+
'model_loaded': model is not None,
|
| 1291 |
'tokenizer_loaded': tokenizer is not None,
|
| 1292 |
'message': 'Model loaded successfully' if success else 'Model loading failed'
|
| 1293 |
})
|
|
|
|
| 1323 |
"""Health check endpoint"""
|
| 1324 |
return jsonify({
|
| 1325 |
'status': 'healthy',
|
| 1326 |
+
'model_loaded': model is not None,
|
| 1327 |
'tokenizer_loaded': tokenizer is not None,
|
| 1328 |
'langgraph_available': LANGGRAPH_AVAILABLE,
|
| 1329 |
'processing_mode': 'langgraph' if LANGGRAPH_AVAILABLE else 'simplified'
|
|
|
|
| 1358 |
logger.info("Production mode: Loading model during module import...")
|
| 1359 |
logger.info(f"LangGraph available: {LANGGRAPH_AVAILABLE}")
|
| 1360 |
|
| 1361 |
+
# Try to load model, but don't fail if it doesn't work (like the working version)
|
| 1362 |
+
logger.info("Attempting to load model...")
|
| 1363 |
model_loaded = load_model()
|
| 1364 |
|
| 1365 |
if model_loaded:
|
| 1366 |
logger.info("✅ Model loaded successfully for production!")
|
| 1367 |
else:
|
| 1368 |
+
logger.warning("⚠️ Model failed to load, but server will start anyway. Model can be loaded via /load_model_manual endpoint.")
|