Pasipid791 commited on
Commit
bcd6ea3
·
verified ·
1 Parent(s): f314614

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -10
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import gradio as gr
2
  import torch
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
 
4
  import os
5
  import json
6
  import logging
@@ -11,23 +12,40 @@ logger = logging.getLogger(__name__)
11
 
12
  # Define model and checkpoint paths
13
  MODEL_REPO = "microsoft/CADFusion"
14
- CHECKPOINT_REVISION = "main" # Use 'main' branch, which contains v1_1 checkpoint
15
  CHECKPOINT_SUBFOLDER = "exp/model_ckpt/v1_1"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
  # Load model and tokenizer
18
  try:
19
  logger.info("Loading tokenizer...")
 
20
  tokenizer = AutoTokenizer.from_pretrained(
21
- MODEL_REPO,
22
- revision=CHECKPOINT_REVISION,
23
- subfolder=CHECKPOINT_SUBFOLDER,
24
  trust_remote_code=True
25
  )
26
  logger.info("Loading model...")
 
27
  model = AutoModelForCausalLM.from_pretrained(
28
- MODEL_REPO,
29
- revision=CHECKPOINT_REVISION,
30
- subfolder=CHECKPOINT_SUBFOLDER,
31
  torch_dtype=torch.float16,
32
  device_map="auto",
33
  trust_remote_code=True
@@ -35,7 +53,15 @@ try:
35
  logger.info("Model and tokenizer loaded successfully.")
36
  except Exception as e:
37
  logger.error(f"Error loading model or tokenizer: {str(e)}")
38
- raise e
 
 
 
 
 
 
 
 
39
 
40
  # Function to generate CAD model from text description
41
  def generate_cad_model(text_description):
@@ -100,8 +126,8 @@ def create_gradio_interface():
100
  gr.Markdown("""
101
  **Note**:
102
  - CADFusion is for research purposes only. Generated models may not be technically accurate and require validation.
103
- - Ensure descriptions are clear and specific for best results.
104
- - For more details, visit the [CADFusion GitHub repo](https://github.com/microsoft/CADFusion).
105
  """)
106
 
107
  return demo
 
1
  import gradio as gr
2
  import torch
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
+ from huggingface_hub import hf_hub_download, snapshot_download
5
  import os
6
  import json
7
  import logging
 
12
 
13
  # Define model and checkpoint paths
14
  MODEL_REPO = "microsoft/CADFusion"
15
+ CHECKPOINT_REVISION = "main"
16
  CHECKPOINT_SUBFOLDER = "exp/model_ckpt/v1_1"
17
+ LOCAL_CHECKPOINT_DIR = "./model_ckpt/v1_1"
18
+
19
+ # Ensure local checkpoint directory exists
20
+ os.makedirs(LOCAL_CHECKPOINT_DIR, exist_ok=True)
21
+
22
+ # Download checkpoint files
23
+ try:
24
+ logger.info("Downloading checkpoint files...")
25
+ snapshot_download(
26
+ repo_id=MODEL_REPO,
27
+ revision=CHECKPOINT_REVISION,
28
+ allow_patterns=f"{CHECKPOINT_SUBFOLDER}/*",
29
+ local_dir=LOCAL_CHECKPOINT_DIR,
30
+ local_dir_use_symlinks=False
31
+ )
32
+ logger.info("Checkpoint files downloaded successfully.")
33
+ except Exception as e:
34
+ logger.error(f"Error downloading checkpoint files: {str(e)}")
35
+ raise e
36
 
37
  # Load model and tokenizer
38
  try:
39
  logger.info("Loading tokenizer...")
40
+ # Fallback to base Llama-3-8B tokenizer if CADFusion-specific config is missing
41
  tokenizer = AutoTokenizer.from_pretrained(
42
+ "meta-llama/Meta-Llama-3-8B",
 
 
43
  trust_remote_code=True
44
  )
45
  logger.info("Loading model...")
46
+ # Attempt to load model from local checkpoint
47
  model = AutoModelForCausalLM.from_pretrained(
48
+ LOCAL_CHECKPOINT_DIR,
 
 
49
  torch_dtype=torch.float16,
50
  device_map="auto",
51
  trust_remote_code=True
 
53
  logger.info("Model and tokenizer loaded successfully.")
54
  except Exception as e:
55
  logger.error(f"Error loading model or tokenizer: {str(e)}")
56
+ # Fallback to base Llama-3-8B model if local checkpoint fails
57
+ logger.info("Falling back to base Meta-Llama-3-8B model...")
58
+ model = AutoModelForCausalLM.from_pretrained(
59
+ "meta-llama/Meta-Llama-3-8B",
60
+ torch_dtype=torch.float16,
61
+ device_map="auto",
62
+ trust_remote_code=True
63
+ )
64
+ logger.info("Fallback model loaded successfully.")
65
 
66
  # Function to generate CAD model from text description
67
  def generate_cad_model(text_description):
 
126
  gr.Markdown("""
127
  **Note**:
128
  - CADFusion is for research purposes only. Generated models may not be technically accurate and require validation.
129
+ - This deployment uses a fallback to Meta-Llama-3-8B due to potential issues with the v1_1 checkpoint.
130
+ - For full functionality, refer to the [CADFusion GitHub repo](https://github.com/microsoft/CADFusion) for custom setup instructions.
131
  """)
132
 
133
  return demo