astrosbd commited on
Commit
9f24e01
·
verified ·
1 Parent(s): 1ef82e5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -57
app.py CHANGED
@@ -68,6 +68,40 @@ except Exception as e:
68
  # --------------------------------------------------------------------------------------
69
  # Basics
70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
 
72
  # Add current directory to path (for local modules if any)
73
  if os.getcwd() not in sys.path:
@@ -76,44 +110,8 @@ if os.getcwd() not in sys.path:
76
  # --------------------------------------------------------------------------------------
77
  # Hugging Face download (robust)
78
  # --------------------------------------------------------------------------------------
79
- try:
80
- from huggingface_hub import hf_hub_download, HfHubHTTPError
81
- except Exception:
82
- hf_hub_download = None
83
- HfHubHTTPError = Exception
84
-
85
- HF_TOKEN = os.getenv("key") # backwards compat
86
- # --------------------------------------------------------------------------------------
87
- print("DEBUG MODEL_REPO:", os.getenv("MODEL_REPO"))
88
- print("DEBUG MODEL_REPO_REVISION:", os.getenv("MODEL_REPO_REVISION"))
89
- print("DEBUG PRIVATE_REPO:", os.getenv("PRIVATE_REPO"))
90
- print("DEBUG HF token present:", bool(os.getenv("HUGGING_FACE_HUB_TOKEN") or os.getenv("HF_TOKEN") or os.getenv("key")))
91
- import torch, timm
92
- print("DEBUG torch:", torch.__version__)
93
- print("DEBUG timm:", timm.__version__)
94
- PRIVATE_REPO = os.getenv("PRIVATE_REPO")
95
- huggingface_model_path = None
96
 
97
- def _try_hf_download(repo_id: str, filename: str, token: Optional[str]):
98
- if hf_hub_download is None:
99
- print("ℹ️ huggingface_hub not available; skipping HF download.")
100
- return None
101
- try:
102
- return hf_hub_download(repo_id=repo_id, filename=filename, token=token)
103
- except HfHubHTTPError as e:
104
- print(f"⚠️ HF HTTP error: {e}")
105
- except Exception as e:
106
- print(f"⚠️ HF generic error: {e}")
107
- return None
108
 
109
- if PRIVATE_REPO != "fallback":
110
- huggingface_model_path = _try_hf_download(PRIVATE_REPO, "V1.pkl", HF_TOKEN)
111
- if huggingface_model_path:
112
- print(f"✅ Model downloaded: {huggingface_model_path}")
113
- else:
114
- print("⚠️ Could not download V1.pkl. Will try local path or demo mode.")
115
- else:
116
- print("ℹ️ PRIVATE_REPO not set; will use local ./output/V1.pkl or demo mode.")
117
 
118
  # --------------------------------------------------------------------------------------
119
  # Paths / Devices
@@ -261,27 +259,7 @@ def run_damage_detection(pil_image: Image.Image, score_thresh: float = 0.5):
261
  # --------------------------------------------------------------------------------------
262
  # Stage 2: RADIO feature extractor + classifier
263
  # --------------------------------------------------------------------------------------
264
- def preload_models():
265
- """Preload RADIO model at startup to improve response time (idempotent)."""
266
- global image_processor, model, _preloaded
267
- if _preloaded and image_processor is not None and model is not None:
268
- return True
269
- try:
270
- hf_repo = os.getenv('MODEL_REPO')
271
- if not hf_repo:
272
- print("⚠️ MODEL_REPO not set → demo mode unless already loaded.")
273
- return False
274
- image_processor = CLIPImageProcessor.from_pretrained(hf_repo)
275
- m = AutoModel.from_pretrained(hf_repo, trust_remote_code=True)
276
- m.to(DEVICE).eval()
277
- model = m
278
- _preloaded = True
279
- print("✅ RADIO model preloaded")
280
- return True
281
- except Exception as e:
282
- print(f"❌ RADIO preload failed: {e}")
283
- traceback.print_exc()
284
- return False
285
 
286
  def load_ai_detection_classifier(model_path):
287
  """Load the AI detection classifier (joblib)."""
 
68
  # --------------------------------------------------------------------------------------
69
  # Basics
70
 
71
+ # Initialize device for model
72
+ if torch.backends.mps.is_available():
73
+ RADIO_DEVICE = torch.device("mps")
74
+ elif torch.cuda.is_available():
75
+ RADIO_DEVICE = torch.device("cuda")
76
+ else:
77
+ RADIO_DEVICE = torch.device("cpu")
78
+
79
+ # Global variables for C model
80
+ radio_l_image_processor = None
81
+ radio_l_model = None
82
+ ai_detection_classifier = None
83
+
84
+
85
+ # Preload the C model at startup
86
+ def preload_models():
87
+ """Preload models at startup to improve response time"""
88
+ global radio_l_image_processor, radio_l_model
89
+
90
+ print("🔄 Preloading C model (4GB)...")
91
+ try:
92
+ hf_repo = os.getenv('MODEL_REPO', 'fallback')
93
+ if hf_repo and hf_repo != 'fallback':
94
+ from transformers import AutoModel, CLIPImageProcessor
95
+ radio_l_image_processor = CLIPImageProcessor.from_pretrained(hf_repo)
96
+ radio_l_model = AutoModel.from_pretrained(hf_repo, trust_remote_code=True)
97
+ radio_l_model = radio_l_model.to(RADIO_DEVICE)
98
+ radio_l_model.eval()
99
+ print("✅ C model preloaded successfully!")
100
+ return True
101
+ except Exception as e:
102
+ print(f"⚠️ Could not preload C model: {e}")
103
+ return False
104
+
105
 
106
  # Add current directory to path (for local modules if any)
107
  if os.getcwd() not in sys.path:
 
110
  # --------------------------------------------------------------------------------------
111
  # Hugging Face download (robust)
112
  # --------------------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
 
 
 
 
 
 
 
 
 
 
 
 
114
 
 
 
 
 
 
 
 
 
115
 
116
  # --------------------------------------------------------------------------------------
117
  # Paths / Devices
 
259
  # --------------------------------------------------------------------------------------
260
  # Stage 2: RADIO feature extractor + classifier
261
  # --------------------------------------------------------------------------------------
262
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
263
 
264
  def load_ai_detection_classifier(model_path):
265
  """Load the AI detection classifier (joblib)."""