musictimer committed on
Commit
bbfa773
·
1 Parent(s): 93dbff3
Files changed (1) hide show
  1. app.py +77 -16
app.py CHANGED
@@ -95,49 +95,110 @@ class WebGameEngine:
95
 
96
  def load_model_weights():
97
  """Load model weights in thread pool to avoid blocking"""
 
 
 
98
  try:
99
- # Use torch.hub.load_state_dict_from_url which is HF Spaces compatible
100
- model_url = "https://huggingface.co/Etadingrui/diamond-1B/resolve/main/agent_epoch_00003.pt"
101
- logger.info(f"Loading model from {model_url} using torch.hub...")
102
-
103
- # Update progress
104
- self.download_progress = 10
105
  self.loading_status = "Downloading model with torch.hub..."
 
106
 
107
- # Load state dict directly from URL (handles caching automatically)
108
  state_dict = torch.hub.load_state_dict_from_url(
109
  model_url,
110
  map_location=device,
111
- progress=True # Show download progress
 
112
  )
 
113
 
114
- self.download_progress = 80
115
- self.loading_status = "Loading model weights..."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
 
117
  # Load each component of the agent using extract_state_dict (same as agent.load method)
118
  if any(k.startswith("denoiser") for k in state_dict.keys()):
119
  agent.denoiser.load_state_dict(extract_state_dict(state_dict, "denoiser"))
 
 
 
 
120
  if any(k.startswith("upsampler") for k in state_dict.keys()) and agent.upsampler is not None:
121
  agent.upsampler.load_state_dict(extract_state_dict(state_dict, "upsampler"))
 
 
 
 
122
  if any(k.startswith("rew_end_model") for k in state_dict.keys()) and agent.rew_end_model is not None:
123
  agent.rew_end_model.load_state_dict(extract_state_dict(state_dict, "rew_end_model"))
 
 
 
 
124
  if any(k.startswith("actor_critic") for k in state_dict.keys()) and agent.actor_critic is not None:
125
  agent.actor_critic.load_state_dict(extract_state_dict(state_dict, "actor_critic"))
 
126
 
127
  self.download_progress = 100
128
  self.loading_status = "Model loaded successfully!"
 
129
  return True
130
 
131
  except Exception as e:
132
- logger.error(f"Failed to load model from URL: {e}")
 
 
133
  return False
134
 
135
- # Run in thread pool to avoid blocking
136
  loop = asyncio.get_event_loop()
137
- with concurrent.futures.ThreadPoolExecutor() as executor:
138
- success = await loop.run_in_executor(executor, load_model_weights)
139
-
140
- return success
 
 
 
 
 
 
 
 
 
 
141
 
142
  async def initialize_models(self):
143
  """Initialize the AI models and environment"""
 
95
 
96
  def load_model_weights():
97
  """Load model weights in thread pool to avoid blocking"""
98
+ state_dict = None
99
+
100
+ # Try torch.hub method first
101
  try:
102
+ logger.info("Trying to load model using torch.hub...")
 
 
 
 
 
103
  self.loading_status = "Downloading model with torch.hub..."
104
+ self.download_progress = 10
105
 
106
+ model_url = "https://huggingface.co/Etadingrui/diamond-1B/resolve/main/agent_epoch_00003.pt"
107
  state_dict = torch.hub.load_state_dict_from_url(
108
  model_url,
109
  map_location=device,
110
+ progress=False,
111
+ check_hash=False
112
  )
113
+ logger.info("Successfully loaded model using torch.hub")
114
 
115
+ except Exception as e:
116
+ logger.warning(f"Failed to load model with torch.hub: {e}")
117
+
118
+ # Try huggingface_hub method as fallback
119
+ try:
120
+ logger.info("Trying to load model using huggingface_hub...")
121
+ self.loading_status = "Downloading model with huggingface_hub..."
122
+ self.download_progress = 10
123
+
124
+ from huggingface_hub import hf_hub_download
125
+
126
+ # Download the file
127
+ model_path = hf_hub_download(
128
+ repo_id="Etadingrui/diamond-1B",
129
+ filename="agent_epoch_00003.pt",
130
+ cache_dir=None # Use default cache
131
+ )
132
+ self.download_progress = 40
133
+ self.loading_status = "Loading downloaded model..."
134
+
135
+ # Load the state dict
136
+ state_dict = torch.load(model_path, map_location=device)
137
+ logger.info("Successfully loaded model using huggingface_hub")
138
+
139
+ except Exception as e2:
140
+ logger.error(f"Failed to load model with huggingface_hub: {e2}")
141
+ raise Exception("All model loading methods failed")
142
+
143
+ if state_dict is None:
144
+ raise Exception("Failed to load model state dict")
145
+
146
+ # Load state dict into agent
147
+ try:
148
+ logger.info("Model download completed, loading weights...")
149
+ self.download_progress = 60
150
+ self.loading_status = "Model downloaded, loading weights..."
151
 
152
  # Load each component of the agent using extract_state_dict (same as agent.load method)
153
  if any(k.startswith("denoiser") for k in state_dict.keys()):
154
  agent.denoiser.load_state_dict(extract_state_dict(state_dict, "denoiser"))
155
+ logger.info("Loaded denoiser weights")
156
+
157
+ self.download_progress = 70
158
+ self.loading_status = "Loading upsampler..."
159
  if any(k.startswith("upsampler") for k in state_dict.keys()) and agent.upsampler is not None:
160
  agent.upsampler.load_state_dict(extract_state_dict(state_dict, "upsampler"))
161
+ logger.info("Loaded upsampler weights")
162
+
163
+ self.download_progress = 80
164
+ self.loading_status = "Loading reward model..."
165
  if any(k.startswith("rew_end_model") for k in state_dict.keys()) and agent.rew_end_model is not None:
166
  agent.rew_end_model.load_state_dict(extract_state_dict(state_dict, "rew_end_model"))
167
+ logger.info("Loaded reward model weights")
168
+
169
+ self.download_progress = 90
170
+ self.loading_status = "Loading actor critic..."
171
  if any(k.startswith("actor_critic") for k in state_dict.keys()) and agent.actor_critic is not None:
172
  agent.actor_critic.load_state_dict(extract_state_dict(state_dict, "actor_critic"))
173
+ logger.info("Loaded actor critic weights")
174
 
175
  self.download_progress = 100
176
  self.loading_status = "Model loaded successfully!"
177
+ logger.info("All model weights loaded successfully!")
178
  return True
179
 
180
  except Exception as e:
181
+ logger.error(f"Failed to load state dict into agent: {e}")
182
+ import traceback
183
+ traceback.print_exc()
184
  return False
185
 
186
+ # Run in thread pool to avoid blocking with timeout
187
  loop = asyncio.get_event_loop()
188
+ try:
189
+ with concurrent.futures.ThreadPoolExecutor() as executor:
190
+ # Add timeout for model loading (5 minutes max)
191
+ future = loop.run_in_executor(executor, load_model_weights)
192
+ success = await asyncio.wait_for(future, timeout=300.0) # 5 minute timeout
193
+ return success
194
+ except asyncio.TimeoutError:
195
+ logger.error("Model loading timed out after 5 minutes")
196
+ self.loading_status = "Model loading timed out - using dummy mode"
197
+ return False
198
+ except Exception as e:
199
+ logger.error(f"Error in model loading executor: {e}")
200
+ self.loading_status = f"Model loading error: {str(e)[:50]}..."
201
+ return False
202
 
203
  async def initialize_models(self):
204
  """Initialize the AI models and environment"""