santhoshv6 committed on
Commit
6f86b6f
·
1 Parent(s): 743a7b6

Optimize model loading to fix storage limit issues - use streaming and memory cleanup

Browse files
Files changed (1)
  1. app.py +24 -7
app.py CHANGED
@@ -8,6 +8,7 @@ import numpy as np
8
  import requests
9
  from io import BytesIO
10
  import os
 
11
 
12
  # CIFAR-100 class names
13
  CIFAR100_CLASSES = [
@@ -91,10 +92,10 @@ model_loaded = False
91
  model_status = "Not loaded"
92
 
93
  def load_model_with_fallbacks():
94
- """Try multiple methods to load the model"""
95
  global model, model_loaded, model_status
96
 
97
- # Method 1: Try GitHub releases
98
  try:
99
  print("πŸ”„ Attempting to load model from GitHub releases...")
100
  model_url = "https://github.com/santhoshv6/era_v4_s8_assignment/releases/download/v1.0/model_best.pth"
@@ -104,19 +105,35 @@ def load_model_with_fallbacks():
104
  'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
105
  }
106
 
107
- response = requests.get(model_url, headers=headers, timeout=30)
 
108
  response.raise_for_status()
109
 
110
- print(f"βœ… Downloaded model: {len(response.content)} bytes")
 
 
 
111
 
112
- # Load the model state dict
113
- checkpoint = torch.load(BytesIO(response.content), map_location='cpu')
 
 
 
 
 
 
 
114
 
115
  if 'state_dict' in checkpoint:
116
  model.load_state_dict(checkpoint['state_dict'])
117
  model.eval()
 
 
 
 
 
 
118
  model_loaded = True
119
- accuracy = checkpoint.get('test_acc', 77.45) # Default to known accuracy
120
  model_status = f"βœ… Loaded from GitHub (Accuracy: {accuracy:.2f}%)"
121
  print(f"βœ… Model loaded successfully! Accuracy: {accuracy:.2f}%")
122
  return True
 
8
  import requests
9
  from io import BytesIO
10
  import os
11
+ import gc # For garbage collection
12
 
13
  # CIFAR-100 class names
14
  CIFAR100_CLASSES = [
 
92
  model_status = "Not loaded"
93
 
94
  def load_model_with_fallbacks():
95
+ """Try multiple methods to load the model with optimized memory usage"""
96
  global model, model_loaded, model_status
97
 
98
+ # Method 1: Try GitHub releases with streaming
99
  try:
100
  print("πŸ”„ Attempting to load model from GitHub releases...")
101
  model_url = "https://github.com/santhoshv6/era_v4_s8_assignment/releases/download/v1.0/model_best.pth"
 
105
  'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
106
  }
107
 
108
+ # Stream the download to avoid memory issues
109
+ response = requests.get(model_url, headers=headers, timeout=60, stream=True)
110
  response.raise_for_status()
111
 
112
+ # Load directly from stream to minimize memory usage
113
+ model_data = BytesIO()
114
+ for chunk in response.iter_content(chunk_size=8192):
115
+ model_data.write(chunk)
116
 
117
+ model_data.seek(0)
118
+ print(f"βœ… Downloaded model: {model_data.getbuffer().nbytes} bytes")
119
+
120
+ # Load the model state dict with memory optimization
121
+ checkpoint = torch.load(model_data, map_location='cpu')
122
+
123
+ # Clear the downloaded data immediately
124
+ model_data.close()
125
+ del model_data
126
 
127
  if 'state_dict' in checkpoint:
128
  model.load_state_dict(checkpoint['state_dict'])
129
  model.eval()
130
+
131
+ # Clear checkpoint data to free memory
132
+ accuracy = checkpoint.get('test_acc', 77.45)
133
+ del checkpoint
134
+ gc.collect() # Force garbage collection
135
+
136
  model_loaded = True
 
137
  model_status = f"βœ… Loaded from GitHub (Accuracy: {accuracy:.2f}%)"
138
  print(f"βœ… Model loaded successfully! Accuracy: {accuracy:.2f}%")
139
  return True