Add progress indicators and model caching
Browse files
app.py
CHANGED
|
@@ -103,11 +103,12 @@ def on_person_change(person_id):
|
|
| 103 |
return context_info, phrases_text, topics
|
| 104 |
|
| 105 |
|
| 106 |
-
def change_model(model_name):
|
| 107 |
"""Change the language model used for generation.
|
| 108 |
|
| 109 |
Args:
|
| 110 |
model_name: The name of the model to use
|
|
|
|
| 111 |
|
| 112 |
Returns:
|
| 113 |
A status message about the model change
|
|
@@ -120,12 +121,17 @@ def change_model(model_name):
|
|
| 120 |
if model_name == suggestion_generator.model_name:
|
| 121 |
return f"Already using model: {model_name}"
|
| 122 |
|
|
|
|
|
|
|
|
|
|
| 123 |
# Try to load the new model
|
| 124 |
success = suggestion_generator.load_model(model_name)
|
| 125 |
|
| 126 |
if success:
|
|
|
|
| 127 |
return f"Successfully switched to model: {model_name}"
|
| 128 |
else:
|
|
|
|
| 129 |
return f"Failed to load model: {model_name}. Using fallback responses instead."
|
| 130 |
|
| 131 |
|
|
@@ -136,6 +142,7 @@ def generate_suggestions(
|
|
| 136 |
selected_topic=None,
|
| 137 |
model_name="distilgpt2",
|
| 138 |
temperature=0.7,
|
|
|
|
| 139 |
):
|
| 140 |
"""Generate suggestions based on the selected person and user input."""
|
| 141 |
print(
|
|
@@ -144,13 +151,17 @@ def generate_suggestions(
|
|
| 144 |
f"model={model_name}, temperature={temperature}"
|
| 145 |
)
|
| 146 |
|
|
|
|
|
|
|
|
|
|
| 147 |
if not person_id:
|
| 148 |
print("No person_id provided")
|
| 149 |
return "Please select who you're talking to first."
|
| 150 |
|
| 151 |
# Make sure we're using the right model
|
| 152 |
if model_name != suggestion_generator.model_name:
|
| 153 |
-
|
|
|
|
| 154 |
|
| 155 |
person_context = social_graph.get_person_context(person_id)
|
| 156 |
print(f"Person context: {person_context}")
|
|
@@ -206,9 +217,13 @@ def generate_suggestions(
|
|
| 206 |
# If suggestion type is "model", use the language model for multiple suggestions
|
| 207 |
if suggestion_type == "model":
|
| 208 |
print("Using model for suggestions")
|
|
|
|
|
|
|
| 209 |
# Generate 3 different suggestions
|
| 210 |
suggestions = []
|
| 211 |
for i in range(3):
|
|
|
|
|
|
|
| 212 |
print(f"Generating suggestion {i+1}/3")
|
| 213 |
try:
|
| 214 |
suggestion = suggestion_generator.generate_suggestion(
|
|
@@ -247,9 +262,14 @@ def generate_suggestions(
|
|
| 247 |
else:
|
| 248 |
print("No category inferred, falling back to model")
|
| 249 |
# Fall back to model if we couldn't infer a category
|
|
|
|
| 250 |
try:
|
| 251 |
suggestions = []
|
| 252 |
for i in range(3):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 253 |
suggestion = suggestion_generator.generate_suggestion(
|
| 254 |
person_context, user_input, temperature=temperature
|
| 255 |
)
|
|
@@ -278,6 +298,10 @@ def generate_suggestions(
|
|
| 278 |
result = "No suggestions available. Please try a different option."
|
| 279 |
|
| 280 |
print(f"Returning result: {result[:100]}...")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 281 |
return result
|
| 282 |
|
| 283 |
|
|
|
|
| 103 |
return context_info, phrases_text, topics
|
| 104 |
|
| 105 |
|
| 106 |
+
def change_model(model_name, progress=gr.Progress()):
|
| 107 |
"""Change the language model used for generation.
|
| 108 |
|
| 109 |
Args:
|
| 110 |
model_name: The name of the model to use
|
| 111 |
+
progress: Gradio progress indicator
|
| 112 |
|
| 113 |
Returns:
|
| 114 |
A status message about the model change
|
|
|
|
| 121 |
if model_name == suggestion_generator.model_name:
|
| 122 |
return f"Already using model: {model_name}"
|
| 123 |
|
| 124 |
+
# Show progress indicator
|
| 125 |
+
progress(0, desc=f"Loading model: {model_name}")
|
| 126 |
+
|
| 127 |
# Try to load the new model
|
| 128 |
success = suggestion_generator.load_model(model_name)
|
| 129 |
|
| 130 |
if success:
|
| 131 |
+
progress(1.0, desc=f"Model loaded: {model_name}")
|
| 132 |
return f"Successfully switched to model: {model_name}"
|
| 133 |
else:
|
| 134 |
+
progress(1.0, desc="Model loading failed")
|
| 135 |
return f"Failed to load model: {model_name}. Using fallback responses instead."
|
| 136 |
|
| 137 |
|
|
|
|
| 142 |
selected_topic=None,
|
| 143 |
model_name="distilgpt2",
|
| 144 |
temperature=0.7,
|
| 145 |
+
progress=gr.Progress(),
|
| 146 |
):
|
| 147 |
"""Generate suggestions based on the selected person and user input."""
|
| 148 |
print(
|
|
|
|
| 151 |
f"model={model_name}, temperature={temperature}"
|
| 152 |
)
|
| 153 |
|
| 154 |
+
# Initialize progress
|
| 155 |
+
progress(0, desc="Starting...")
|
| 156 |
+
|
| 157 |
if not person_id:
|
| 158 |
print("No person_id provided")
|
| 159 |
return "Please select who you're talking to first."
|
| 160 |
|
| 161 |
# Make sure we're using the right model
|
| 162 |
if model_name != suggestion_generator.model_name:
|
| 163 |
+
progress(0.1, desc=f"Switching to model: {model_name}")
|
| 164 |
+
change_model(model_name, progress)
|
| 165 |
|
| 166 |
person_context = social_graph.get_person_context(person_id)
|
| 167 |
print(f"Person context: {person_context}")
|
|
|
|
| 217 |
# If suggestion type is "model", use the language model for multiple suggestions
|
| 218 |
if suggestion_type == "model":
|
| 219 |
print("Using model for suggestions")
|
| 220 |
+
progress(0.2, desc="Preparing to generate suggestions...")
|
| 221 |
+
|
| 222 |
# Generate 3 different suggestions
|
| 223 |
suggestions = []
|
| 224 |
for i in range(3):
|
| 225 |
+
progress_value = 0.3 + (i * 0.2) # Progress from 30% to 70%
|
| 226 |
+
progress(progress_value, desc=f"Generating suggestion {i+1}/3")
|
| 227 |
print(f"Generating suggestion {i+1}/3")
|
| 228 |
try:
|
| 229 |
suggestion = suggestion_generator.generate_suggestion(
|
|
|
|
| 262 |
else:
|
| 263 |
print("No category inferred, falling back to model")
|
| 264 |
# Fall back to model if we couldn't infer a category
|
| 265 |
+
progress(0.3, desc="No category detected, using model instead...")
|
| 266 |
try:
|
| 267 |
suggestions = []
|
| 268 |
for i in range(3):
|
| 269 |
+
progress_value = 0.4 + (i * 0.15) # Progress from 40% to 70%
|
| 270 |
+
progress(
|
| 271 |
+
progress_value, desc=f"Generating fallback suggestion {i+1}/3"
|
| 272 |
+
)
|
| 273 |
suggestion = suggestion_generator.generate_suggestion(
|
| 274 |
person_context, user_input, temperature=temperature
|
| 275 |
)
|
|
|
|
| 298 |
result = "No suggestions available. Please try a different option."
|
| 299 |
|
| 300 |
print(f"Returning result: {result[:100]}...")
|
| 301 |
+
|
| 302 |
+
# Complete the progress
|
| 303 |
+
progress(1.0, desc="Completed!")
|
| 304 |
+
|
| 305 |
return result
|
| 306 |
|
| 307 |
|
utils.py
CHANGED
|
@@ -161,6 +161,7 @@ class SuggestionGenerator:
|
|
| 161 |
self.model_loaded = False
|
| 162 |
self.generator = None
|
| 163 |
self.aac_user_info = None
|
|
|
|
| 164 |
|
| 165 |
# Load AAC user information from social graph
|
| 166 |
try:
|
|
@@ -196,6 +197,13 @@ class SuggestionGenerator:
|
|
| 196 |
self.model_name = model_name
|
| 197 |
self.model_loaded = False
|
| 198 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 199 |
try:
|
| 200 |
print(f"Loading model: {model_name}")
|
| 201 |
|
|
@@ -258,6 +266,9 @@ class SuggestionGenerator:
|
|
| 258 |
# For non-gated models, use the standard pipeline
|
| 259 |
self.generator = pipeline("text-generation", model=model_name)
|
| 260 |
|
|
|
|
|
|
|
|
|
|
| 261 |
self.model_loaded = True
|
| 262 |
print(f"Model loaded successfully: {model_name}")
|
| 263 |
return True
|
|
|
|
| 161 |
self.model_loaded = False
|
| 162 |
self.generator = None
|
| 163 |
self.aac_user_info = None
|
| 164 |
+
self.loaded_models = {} # Cache for loaded models
|
| 165 |
|
| 166 |
# Load AAC user information from social graph
|
| 167 |
try:
|
|
|
|
| 197 |
self.model_name = model_name
|
| 198 |
self.model_loaded = False
|
| 199 |
|
| 200 |
+
# Check if model is already loaded in cache
|
| 201 |
+
if model_name in self.loaded_models:
|
| 202 |
+
print(f"Using cached model: {model_name}")
|
| 203 |
+
self.generator = self.loaded_models[model_name]
|
| 204 |
+
self.model_loaded = True
|
| 205 |
+
return True
|
| 206 |
+
|
| 207 |
try:
|
| 208 |
print(f"Loading model: {model_name}")
|
| 209 |
|
|
|
|
| 266 |
# For non-gated models, use the standard pipeline
|
| 267 |
self.generator = pipeline("text-generation", model=model_name)
|
| 268 |
|
| 269 |
+
# Cache the loaded model
|
| 270 |
+
self.loaded_models[model_name] = self.generator
|
| 271 |
+
|
| 272 |
self.model_loaded = True
|
| 273 |
print(f"Model loaded successfully: {model_name}")
|
| 274 |
return True
|