Gaston895 committed on
Commit
a9b927e
·
verified ·
1 Parent(s): bd9c577

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +68 -46
app.py CHANGED
@@ -2,36 +2,50 @@ import gradio as gr
2
  import torch
3
  from transformers import AutoTokenizer, AutoModelForCausalLM
4
  import re
 
 
5
 
6
  # Global variables for model and tokenizer
7
  model = None
8
  tokenizer = None
 
9
 
10
  def load_model():
11
- """Load the model and tokenizer"""
12
- global model, tokenizer
13
 
14
  try:
15
  print("Loading AEGIS Conduct Economic Analysis Model...")
16
 
17
- # Load tokenizer and model directly from root (no subfolder needed)
18
  tokenizer = AutoTokenizer.from_pretrained(
19
  "Gaston895/aegisconduct",
20
  trust_remote_code=True
21
  )
22
 
 
23
  model = AutoModelForCausalLM.from_pretrained(
24
  "Gaston895/aegisconduct",
25
- torch_dtype=torch.bfloat16,
26
  device_map="auto",
27
- trust_remote_code=True
 
 
 
28
  )
29
 
 
 
 
 
 
30
  print("Model loaded successfully!")
 
31
  return True
32
 
33
  except Exception as e:
34
  print(f"Error loading model: {e}")
 
35
  return False
36
 
37
  def format_response(text):
@@ -45,41 +59,45 @@ def format_response(text):
45
 
46
  return text
47
 
48
- def generate_response(message, history, temperature=0.7, max_tokens=512):
49
- """Generate response from the model"""
50
- global model, tokenizer
51
 
52
- if model is None or tokenizer is None:
53
- return "Model not loaded. Please wait for initialization to complete."
54
 
55
  try:
56
- # Build conversation context
57
  conversation = ""
58
- for user_msg, assistant_msg in history:
 
 
 
59
  conversation += f"User: {user_msg}\nAssistant: {assistant_msg}\n\n"
60
 
61
  # Add current message
62
  conversation += f"User: {message}\nAssistant:"
63
 
64
- # Tokenize input
65
- inputs = tokenizer(conversation, return_tensors="pt", truncation=True, max_length=2048)
66
 
67
  # Move to device
68
  if torch.cuda.is_available():
69
  inputs = {k: v.to(model.device) for k, v in inputs.items()}
70
 
71
- # Generate response
72
  with torch.no_grad():
73
  outputs = model.generate(
74
  **inputs,
75
  max_new_tokens=max_tokens,
76
  temperature=temperature,
77
  do_sample=True,
78
- top_p=0.95,
79
- top_k=40,
80
- repetition_penalty=1.05,
81
  pad_token_id=tokenizer.eos_token_id,
82
- eos_token_id=tokenizer.eos_token_id
 
83
  )
84
 
85
  # Decode response
@@ -91,6 +109,10 @@ def generate_response(message, history, temperature=0.7, max_tokens=512):
91
  # Format and clean response
92
  response = format_response(response)
93
 
 
 
 
 
94
  return response
95
 
96
  except Exception as e:
@@ -109,14 +131,8 @@ def chat_interface(message, history, temperature, max_tokens):
109
 
110
  return history, ""
111
 
112
- # Load model on startup
113
- print("Initializing AEGIS Conduct Chat Interface...")
114
- model_loaded = load_model()
115
-
116
  # Create Gradio interface
117
- with gr.Blocks(
118
- title="AEGIS Conduct - Economic Analysis Chat"
119
- ) as demo:
120
 
121
  gr.Markdown("""
122
  # 🤖 AEGIS Conduct - Economic Analysis Chat
@@ -127,49 +143,43 @@ with gr.Blocks(
127
  - **128k Context**: Extended memory for detailed conversations
128
 
129
  Ask questions about economics, finance, market analysis, policy impacts, and more!
130
- """)
131
 
132
- if not model_loaded:
133
- gr.Markdown("⚠️ **Model Loading Error**: Please refresh the page or contact support.")
134
 
135
  with gr.Row():
136
  with gr.Column(scale=4):
137
  chatbot = gr.Chatbot(
138
- height=500,
139
- show_label=False,
140
- container=True
141
  )
142
 
143
  msg = gr.Textbox(
144
  placeholder="Ask me about economics, finance, markets, or any analytical question...",
145
- show_label=False,
146
- container=False,
147
- scale=7
148
  )
149
 
150
  with gr.Row():
151
- submit_btn = gr.Button("Send", variant="primary", scale=1)
152
- clear_btn = gr.Button("Clear Chat", scale=1)
153
 
154
  with gr.Column(scale=1):
155
  gr.Markdown("### Settings")
156
 
157
  temperature = gr.Slider(
158
  minimum=0.1,
159
- maximum=2.0,
160
  value=0.7,
161
  step=0.1,
162
- label="Temperature",
163
- info="Controls randomness (0.1=focused, 2.0=creative)"
164
  )
165
 
166
  max_tokens = gr.Slider(
167
  minimum=50,
168
- maximum=1024,
169
- value=512,
170
  step=50,
171
- label="Max Response Length",
172
- info="Maximum tokens in response"
173
  )
174
 
175
  gr.Markdown("""
@@ -179,6 +189,11 @@ with gr.Blocks(
179
  - What are the risks of high national debt?
180
  - How do interest rates affect the stock market?
181
  - Think deeply: What causes economic recessions?
 
 
 
 
 
182
  """)
183
 
184
  # Event handlers
@@ -186,6 +201,10 @@ with gr.Blocks(
186
  return chat_interface(message, history, temp, max_tok)
187
 
188
  def clear_chat():
 
 
 
 
189
  return [], ""
190
 
191
  # Bind events
@@ -206,11 +225,14 @@ with gr.Blocks(
206
  outputs=[chatbot, msg]
207
  )
208
 
 
 
 
 
209
  # Launch configuration
210
  if __name__ == "__main__":
211
  demo.launch(
212
  server_name="0.0.0.0",
213
  server_port=7860,
214
- share=False,
215
- theme=gr.themes.Soft()
216
  )
 
2
  import torch
3
  from transformers import AutoTokenizer, AutoModelForCausalLM
4
  import re
5
+ import gc
6
+ import os
7
 
8
  # Global variables for model and tokenizer
9
  model = None
10
  tokenizer = None
11
+ model_loaded = False
12
 
13
def load_model():
    """Load the tokenizer and model, preferring an 8-bit quantized load.

    Falls back to a plain fp16 load when 8-bit loading is unavailable
    (e.g. bitsandbytes missing, or no CUDA device — `load_in_8bit` and the
    GPU key in `max_memory` both require one).  Sets the module-level
    `model`, `tokenizer`, and `model_loaded` globals.

    Returns:
        bool: True when the model and tokenizer loaded successfully,
        False otherwise (the error is printed, not raised).
    """
    global model, tokenizer, model_loaded

    try:
        print("Loading AEGIS Conduct Economic Analysis Model...")

        tokenizer = AutoTokenizer.from_pretrained(
            "Gaston895/aegisconduct",
            trust_remote_code=True,
        )

        # Keyword arguments shared by both load attempts.
        common_kwargs = {
            "device_map": "auto",
            "trust_remote_code": True,
            "low_cpu_mem_usage": True,
        }

        try:
            # First attempt: 8-bit quantized load to roughly halve memory.
            # NOTE(review): requires bitsandbytes and a CUDA device; the
            # `0` key in max_memory refers to GPU 0.
            model = AutoModelForCausalLM.from_pretrained(
                "Gaston895/aegisconduct",
                torch_dtype=torch.float16,
                load_in_8bit=True,
                max_memory={0: "6GB", "cpu": "8GB"},
                **common_kwargs,
            )
        except Exception as quant_err:
            # Fallback: unquantized half-precision load, so CPU-only hosts
            # (or hosts without bitsandbytes) still come up instead of
            # failing the whole app.
            print(f"8-bit load failed ({quant_err}); retrying without quantization...")
            model = AutoModelForCausalLM.from_pretrained(
                "Gaston895/aegisconduct",
                torch_dtype=torch.float16,
                **common_kwargs,
            )

        # Release transient allocations made during loading.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        print("Model loaded successfully!")
        model_loaded = True
        return True

    except Exception as e:
        # Best-effort startup: report the failure and let the UI surface
        # a "model is loading" message instead of crashing.
        print(f"Error loading model: {e}")
        model_loaded = False
        return False
50
 
51
  def format_response(text):
 
59
 
60
  return text
61
 
62
+ def generate_response(message, history, temperature=0.7, max_tokens=256):
63
+ """Generate response from the model with memory optimization"""
64
+ global model, tokenizer, model_loaded
65
 
66
+ if not model_loaded or model is None or tokenizer is None:
67
+ return "Model is loading... Please wait a moment and try again."
68
 
69
  try:
70
+ # Build conversation context (keep it shorter for memory)
71
  conversation = ""
72
+ # Only use last 3 exchanges to save memory
73
+ recent_history = history[-3:] if len(history) > 3 else history
74
+
75
+ for user_msg, assistant_msg in recent_history:
76
  conversation += f"User: {user_msg}\nAssistant: {assistant_msg}\n\n"
77
 
78
  # Add current message
79
  conversation += f"User: {message}\nAssistant:"
80
 
81
+ # Tokenize input with length limit
82
+ inputs = tokenizer(conversation, return_tensors="pt", truncation=True, max_length=1024)
83
 
84
  # Move to device
85
  if torch.cuda.is_available():
86
  inputs = {k: v.to(model.device) for k, v in inputs.items()}
87
 
88
+ # Generate response with memory-efficient settings
89
  with torch.no_grad():
90
  outputs = model.generate(
91
  **inputs,
92
  max_new_tokens=max_tokens,
93
  temperature=temperature,
94
  do_sample=True,
95
+ top_p=0.9,
96
+ top_k=50,
97
+ repetition_penalty=1.1,
98
  pad_token_id=tokenizer.eos_token_id,
99
+ eos_token_id=tokenizer.eos_token_id,
100
+ use_cache=True
101
  )
102
 
103
  # Decode response
 
109
  # Format and clean response
110
  response = format_response(response)
111
 
112
+ # Clean up GPU memory after generation
113
+ if torch.cuda.is_available():
114
+ torch.cuda.empty_cache()
115
+
116
  return response
117
 
118
  except Exception as e:
 
131
 
132
  return history, ""
133
 
 
 
 
 
134
  # Create Gradio interface
135
+ with gr.Blocks(title="AEGIS Conduct - Economic Analysis Chat") as demo:
 
 
136
 
137
  gr.Markdown("""
138
  # 🤖 AEGIS Conduct - Economic Analysis Chat
 
143
  - **128k Context**: Extended memory for detailed conversations
144
 
145
  Ask questions about economics, finance, market analysis, policy impacts, and more!
 
146
 
147
+ **Note**: This is a memory-optimized version for better performance.
148
+ """)
149
 
150
  with gr.Row():
151
  with gr.Column(scale=4):
152
  chatbot = gr.Chatbot(
153
+ height=400,
154
+ show_label=False
 
155
  )
156
 
157
  msg = gr.Textbox(
158
  placeholder="Ask me about economics, finance, markets, or any analytical question...",
159
+ show_label=False
 
 
160
  )
161
 
162
  with gr.Row():
163
+ submit_btn = gr.Button("Send", variant="primary")
164
+ clear_btn = gr.Button("Clear Chat")
165
 
166
  with gr.Column(scale=1):
167
  gr.Markdown("### Settings")
168
 
169
  temperature = gr.Slider(
170
  minimum=0.1,
171
+ maximum=1.5,
172
  value=0.7,
173
  step=0.1,
174
+ label="Temperature"
 
175
  )
176
 
177
  max_tokens = gr.Slider(
178
  minimum=50,
179
+ maximum=512,
180
+ value=256,
181
  step=50,
182
+ label="Max Response Length"
 
183
  )
184
 
185
  gr.Markdown("""
 
189
  - What are the risks of high national debt?
190
  - How do interest rates affect the stock market?
191
  - Think deeply: What causes economic recessions?
192
+
193
+ ### Memory Optimization
194
+ - Responses are limited to 256 tokens by default
195
+ - Only recent conversation history is used
196
+ - Model uses 8-bit quantization for efficiency
197
  """)
198
 
199
  # Event handlers
 
201
  return chat_interface(message, history, temp, max_tok)
202
 
203
def clear_chat():
    """Reset the conversation, reclaiming memory while we're at it.

    Returns an empty history list and an empty textbox value, the pair
    Gradio expects for the chatbot and message components.
    """
    # Clearing a long chat is a natural point to give memory back:
    # collect unreachable Python objects, then drop any cached CUDA
    # allocations if a GPU is present.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    empty_history, empty_message = [], ""
    return empty_history, empty_message
209
 
210
  # Bind events
 
225
  outputs=[chatbot, msg]
226
  )
227
 
228
# Initialize the model once at import time, before any request arrives.
print("Initializing AEGIS Conduct Chat Interface...")
load_model()

if __name__ == "__main__":
    # Serve on all interfaces at the port the hosting environment expects;
    # no public share link is created.
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)