Shalmoni committed on
Commit
a4f9434
·
verified ·
1 Parent(s): 710b69a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -1
app.py CHANGED
@@ -62,15 +62,25 @@ def _lazy_model_tok():
62
  global _tokenizer, _model
63
  if _tokenizer is not None and _model is not None:
64
  return _model, _tokenizer
 
65
  _tokenizer = AutoTokenizer.from_pretrained(STORYBOARD_MODEL, trust_remote_code=True)
 
 
 
 
 
66
  _model = AutoModelForCausalLM.from_pretrained(
67
  STORYBOARD_MODEL,
68
  device_map="auto",
69
- dtype="auto",
70
  trust_remote_code=True,
 
71
  )
 
 
72
  if _tokenizer.pad_token_id is None and _tokenizer.eos_token_id is not None:
73
  _tokenizer.pad_token_id = _tokenizer.eos_token_id
 
74
  return _model, _tokenizer
75
 
76
  def _prompt_with_tags(user_prompt: str, n_shots: int, default_fps: int, default_len: int) -> str:
 
62
  global _tokenizer, _model
63
  if _tokenizer is not None and _model is not None:
64
  return _model, _tokenizer
65
+
66
  _tokenizer = AutoTokenizer.from_pretrained(STORYBOARD_MODEL, trust_remote_code=True)
67
+
68
+ # Choose a dtype that works both locally and on ZeroGPU
69
+ use_cuda = torch.cuda.is_available()
70
+ preferred_dtype = torch.float16 if use_cuda else torch.float32 # torch.bfloat16 is also fine if supported
71
+
72
  _model = AutoModelForCausalLM.from_pretrained(
73
  STORYBOARD_MODEL,
74
  device_map="auto",
75
+ torch_dtype=preferred_dtype, # ✅ FIXED: use torch_dtype
76
  trust_remote_code=True,
77
+ use_safetensors=True
78
  )
79
+
80
+ # Ensure pad token to avoid warnings
81
  if _tokenizer.pad_token_id is None and _tokenizer.eos_token_id is not None:
82
  _tokenizer.pad_token_id = _tokenizer.eos_token_id
83
+
84
  return _model, _tokenizer
85
 
86
  def _prompt_with_tags(user_prompt: str, n_shots: int, default_fps: int, default_len: int) -> str: