outsu committed on
Commit
5b43647
·
verified ·
1 Parent(s): 2ddd289

Added a feature that checks for and downloads missing models in the Space from the actual model repository.

Browse files
Files changed (2) hide show
  1. app.py +33 -0
  2. requirements.txt +2 -2
app.py CHANGED
@@ -3,10 +3,12 @@ import gradio as gr
3
  import torch
4
  from TeLVE.imagine import ImageCaptioningModel, load_model, generate_caption
5
  from transformers import BertTokenizerFast
 
6
 
7
  # Constants
8
  MODELS_DIR = "./TeLVE/models"
9
  TOKENIZER_PATH = "./TeLVE/tokenizer"
 
10
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
 
12
  def list_available_models():
@@ -15,6 +17,35 @@ def list_available_models():
15
  return []
16
  return [f for f in os.listdir(MODELS_DIR) if f.endswith('.pth')]
17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  def generate_description(image, model_name):
19
  """Generate image caption using selected model"""
20
  try:
@@ -65,5 +96,7 @@ def create_interface():
65
  return interface
66
 
67
  if __name__ == "__main__":
 
 
68
  demo = create_interface()
69
  demo.launch(share=True, server_name="0.0.0.0")
 
3
  import torch
4
  from TeLVE.imagine import ImageCaptioningModel, load_model, generate_caption
5
  from transformers import BertTokenizerFast
6
+ from huggingface_hub import hf_hub_download, list_repo_files
7
 
8
  # Constants
9
  MODELS_DIR = "./TeLVE/models"
10
  TOKENIZER_PATH = "./TeLVE/tokenizer"
11
+ HF_REPO_ID = "outsu/TeLVE"
12
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
 
14
  def list_available_models():
 
17
  return []
18
  return [f for f in os.listdir(MODELS_DIR) if f.endswith('.pth')]
19
 
20
def get_hf_model_list():
    """Return the .pth model filenames stored under models/ in the HF repo.

    Queries the Hub for the repository's file listing and keeps only the
    basenames of checkpoint files. Returns an empty list if the Hub is
    unreachable, so callers can proceed with local models only.
    """
    try:
        model_names = []
        for path in list_repo_files(HF_REPO_ID):
            if path.startswith('models/') and path.endswith('.pth'):
                # keep just the filename, dropping the "models/" prefix
                model_names.append(path.split('/')[-1])
        return model_names
    except Exception as e:
        print(f"Error fetching models from HuggingFace: {str(e)}")
        return []
28
+
29
def download_missing_models():
    """Download any model files present on the Hub but absent locally.

    Compares the local models directory against the HF repository's
    models/ folder and fetches each missing .pth checkpoint. A failure
    on one file is reported and does not abort the remaining downloads.
    """
    # exist_ok=True replaces the check-then-create pattern, which was
    # race-prone and more verbose than the idiomatic form.
    os.makedirs(MODELS_DIR, exist_ok=True)

    local_models = set(list_available_models())
    hf_models = set(get_hf_model_list())

    for model in hf_models - local_models:
        try:
            print(f"Downloading missing model: {model}")
            hf_hub_download(
                repo_id=HF_REPO_ID,
                # filename is repo-relative; downloading "models/<name>" into
                # the parent of MODELS_DIR lands the file inside MODELS_DIR.
                filename=f"models/{model}",
                local_dir=os.path.dirname(MODELS_DIR),
                local_dir_use_symlinks=False
            )
        except Exception as e:
            print(f"Error downloading {model}: {str(e)}")
48
+
49
  def generate_description(image, model_name):
50
  """Generate image caption using selected model"""
51
  try:
 
96
  return interface
97
 
98
if __name__ == "__main__":
    # Sync any checkpoints missing locally from the Hub before serving the UI.
    print("Checking for missing models...")
    download_missing_models()
    app = create_interface()
    app.launch(share=True, server_name="0.0.0.0")
requirements.txt CHANGED
@@ -1,8 +1,8 @@
1
-
2
  torch>=1.9.0
3
  torchvision>=0.10.0
4
  transformers>=4.11.0
5
  gradio>=3.0.0
6
  pandas>=1.3.0
7
  Pillow>=8.0.0
8
- tqdm>=4.62.0
 
 
 
1
  torch>=1.9.0
2
  torchvision>=0.10.0
3
  transformers>=4.11.0
4
  gradio>=3.0.0
5
  pandas>=1.3.0
6
  Pillow>=8.0.0
7
+ tqdm>=4.62.0
8
+ huggingface-hub>=0.16.4