dejanseo committed
Commit 95a0434 · verified · 1 Parent(s): f53584e

Update app.py

Files changed (1)
  1. app.py +71 -14
app.py CHANGED
@@ -16,23 +16,79 @@ st.set_page_config(
 
 MODEL_ID = "dejanseo/query-grounding"
 HF_TOKEN = os.getenv("HF_TOKEN")
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-# --- FIX: avoid meta tensors, force CPU load first with full weights ---
-model = AutoModelForSequenceClassification.from_pretrained(
-    MODEL_ID,
-    token=HF_TOKEN,
-    low_cpu_mem_usage=False,   # ensure full materialization
-    torch_dtype=torch.float32  # avoid meta tensors
-)
-tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
-
-# safe device move
-model.to(device)
-model.eval()
+
+PREFERRED_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+def _has_meta_params(m: torch.nn.Module) -> bool:
+    for p in m.parameters():
+        if getattr(p, "is_meta", False):
+            return True
+    return False
+
+
+def _first_real_param_device(m: torch.nn.Module) -> torch.device:
+    for p in m.parameters():
+        if not getattr(p, "is_meta", False):
+            return p.device
+    return torch.device("cpu")
+
+
+@st.cache_resource(show_spinner=False)
+def load_model_and_tokenizer():
+    # Attempt 1: normal full load (no meta), then move to preferred device
+    model = AutoModelForSequenceClassification.from_pretrained(
+        MODEL_ID,
+        token=HF_TOKEN,
+        low_cpu_mem_usage=False,
+        torch_dtype="auto",
+    )
+    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
+
+    # If anything is still meta, fall back to device_map loading (do NOT call .to() after that)
+    if _has_meta_params(model):
+        if torch.cuda.is_available():
+            model = AutoModelForSequenceClassification.from_pretrained(
+                MODEL_ID,
+                token=HF_TOKEN,
+                torch_dtype="auto",
+                device_map="auto",
+            )
+        else:
+            # CPU fallback: retry without dtype hint
+            model = AutoModelForSequenceClassification.from_pretrained(
+                MODEL_ID,
+                token=HF_TOKEN,
+                low_cpu_mem_usage=False,
+            )
+
+    # Only call .to() if the model is not dispatched by Accelerate/device_map
+    if not hasattr(model, "hf_device_map"):
+        if _has_meta_params(model):
+            raise RuntimeError(
+                "Model parameters are still on the meta device after loading. "
+                "This is usually a torch/transformers/accelerate version or memory/offload issue."
+            )
+        model.to(PREFERRED_DEVICE)
+
+    model.eval()
+    return model, tokenizer
+
+
+model, tokenizer = load_model_and_tokenizer()
+
 
 def classify(prompt: str):
-    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, padding=True, max_length=512).to(device)
+    exec_device = _first_real_param_device(model)
+    inputs = tokenizer(
+        prompt,
+        return_tensors="pt",
+        truncation=True,
+        padding=True,
+        max_length=512
+    )
+    inputs = {k: v.to(exec_device) for k, v in inputs.items()}
+
     with torch.no_grad():
         logits = model(**inputs).logits
         probs = torch.softmax(logits, dim=-1).squeeze().cpu()
@@ -40,6 +96,7 @@ def classify(prompt: str):
     confidence = probs[pred].item()
     return pred, confidence
 
+
 # Font and style overrides
 st.markdown("""
 <style>
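
Note: the failure mode this commit guards against is PyTorch's meta device, where parameters have shape and dtype but no storage, so a later .to() cannot materialize them. A minimal standalone sketch of the check (hypothetical, not part of the commit; assumes PyTorch >= 2.0 for the torch.device context manager):

# Hypothetical illustration: meta-device parameters are detected before any .to() move.
import torch

def _has_meta_params(m: torch.nn.Module) -> bool:
    # Same check as in app.py, condensed: any parameter flagged as meta.
    return any(getattr(p, "is_meta", False) for p in m.parameters())

with torch.device("meta"):           # allocate parameters on the meta device (no real data)
    skeleton = torch.nn.Linear(4, 2)

print(_has_meta_params(skeleton))                # True  -> moving/running it would fail
print(_has_meta_params(torch.nn.Linear(4, 2)))   # False -> safe to call .to()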
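The rest of app.py is not shown in this diff, so the UI wiring below is a hypothetical usage sketch; it assumes only the classify() signature visible above, which returns a predicted class index and its softmax confidence:

# Hypothetical Streamlit usage of classify() from app.py:
prompt = st.text_input("Enter a query")
if prompt:
    pred, confidence = classify(prompt)
    st.write(f"Predicted class: {pred} (confidence {confidence:.1%})")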