Enferlain commited on
Commit
faec259
·
verified ·
1 Parent(s): ffa73e4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -9
app.py CHANGED
@@ -137,17 +137,22 @@ class ModelManager:
137
  base_vision_model_name = self.config.get("base_vision_model")
138
  print(f"Loading vision model: {base_vision_model_name}")
139
 
140
- # --- UPDATED LOADING LOGIC ---
141
  is_dinov3_8bit = "dinov3" in base_vision_model_name and "8bit" in base_vision_model_name
142
 
 
 
 
 
 
 
143
  if is_dinov3_8bit:
144
- # Use your 8-bit model from the Hub
145
- self.hf_processor = AutoProcessor.from_pretrained("facebook/dinov3-base") # Processor is usually from the base model
146
  self.vision_model = AutoModel.from_pretrained(
147
- base_vision_model_name,
148
- load_in_8bit=True,
149
- trust_remote_code=True
150
  ).eval()
 
 
 
 
151
  else: # For SigLIP or other non-8bit models
152
  self.hf_processor = AutoProcessor.from_pretrained(base_vision_model_name)
153
  self.vision_model = AutoModel.from_pretrained(
@@ -239,10 +244,9 @@ def predict_anatomy_v3(image: Image.Image, model_name: str):
239
  return {"Error": str(e)}
240
 
241
  # --- Gradio Interface ---
242
- # (Unchanged)
243
  DESCRIPTION = """
244
  ## Lumi's Anatomy Flaw Classifier Demo ✨
245
- Select a model from the dropdown, then upload an image to classify its anatomy/structure.
246
  """
247
  EXAMPLE_DIR = "examples"
248
  examples = []
@@ -250,7 +254,20 @@ if os.path.isdir(EXAMPLE_DIR):
250
  examples = [os.path.join(EXAMPLE_DIR, fname) for fname in sorted(os.listdir(EXAMPLE_DIR)) if fname.lower().endswith(('.png', '.jpg', '.jpeg', '.webp'))]
251
 
252
  default_model = list(MODEL_CATALOG.keys())[0]
253
- interface = gr.Interface(fn=predict_anatomy_v3, inputs=[gr.Image(type="pil", label="Input Image"), gr.Dropdown(choices=list(MODEL_CATALOG.keys()), value=default_model, label="Classifier Model")], outputs=gr.Label(label="Class Probabilities", num_top_classes=2), title="Lumi's Anatomy Classifier", description=DESCRIPTION, examples=examples if examples else None, allow_flagging="never", cache_examples=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
254
 
255
  if __name__ == "__main__":
256
  try:
@@ -258,4 +275,5 @@ if __name__ == "__main__":
258
  model_manager.load_model(default_model)
259
  except Exception as e:
260
  print(f"WARNING: Could not pre-load default model. Error: {e}")
 
261
  interface.launch()
 
137
  base_vision_model_name = self.config.get("base_vision_model")
138
  print(f"Loading vision model: {base_vision_model_name}")
139
 
 
140
  is_dinov3_8bit = "dinov3" in base_vision_model_name and "8bit" in base_vision_model_name
141
 
142
+ # --- UPDATED LOGIC v5 ---
143
+ # For 8-bit, the repo contains everything, including the processor.
144
+ # For others, we load from their base name.
145
+ processor_name = base_vision_model_name if is_dinov3_8bit else self.config.get("base_vision_model")
146
+ self.hf_processor = AutoProcessor.from_pretrained(processor_name, trust_remote_code=True) # <-- THE ONLY CHANGE IS HERE
147
+
148
  if is_dinov3_8bit:
 
 
149
  self.vision_model = AutoModel.from_pretrained(
150
+ base_vision_model_name, load_in_8bit=True, trust_remote_code=True
 
 
151
  ).eval()
152
+ else:
153
+ self.vision_model = AutoModel.from_pretrained(
154
+ base_vision_model_name, torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32
155
+ ).to(DEVICE).eval()
156
  else: # For SigLIP or other non-8bit models
157
  self.hf_processor = AutoProcessor.from_pretrained(base_vision_model_name)
158
  self.vision_model = AutoModel.from_pretrained(
 
244
  return {"Error": str(e)}
245
 
246
  # --- Gradio Interface ---
 
247
  DESCRIPTION = """
248
  ## Lumi's Anatomy Flaw Classifier Demo ✨
249
+ Select a model from the dropdown, then upload an image to classify its anatomy/structural correctness.
250
  """
251
  EXAMPLE_DIR = "examples"
252
  examples = []
 
254
  examples = [os.path.join(EXAMPLE_DIR, fname) for fname in sorted(os.listdir(EXAMPLE_DIR)) if fname.lower().endswith(('.png', '.jpg', '.jpeg', '.webp'))]
255
 
256
  default_model = list(MODEL_CATALOG.keys())[0]
257
+
258
+ interface = gr.Interface(
259
+ fn=predict_anatomy_v3,
260
+ inputs=[
261
+ gr.Image(type="pil", label="Input Image"),
262
+ gr.Dropdown(choices=list(MODEL_CATALOG.keys()), value=default_model, label="Classifier Model")
263
+ ],
264
+ outputs=gr.Label(label="Class Probabilities", num_top_classes=2),
265
+ title="Lumi's Anatomy Classifier",
266
+ description=DESCRIPTION,
267
+ examples=examples if examples else None,
268
+ allow_flagging="never",
269
+ cache_examples=True
270
+ )
271
 
272
  if __name__ == "__main__":
273
  try:
 
275
  model_manager.load_model(default_model)
276
  except Exception as e:
277
  print(f"WARNING: Could not pre-load default model. Error: {e}")
278
+
279
  interface.launch()