Pasipid791 committed
Commit 3f7ecb6 · verified · 1 Parent(s): 71f7c56

Update app.py

Files changed (1)
  app.py (+26 -19)
app.py CHANGED
@@ -1,7 +1,7 @@
 import gradio as gr
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer
-from huggingface_hub import hf_hub_download, snapshot_download
+from huggingface_hub import snapshot_download
 import os
 import json
 import logging
@@ -15,6 +15,7 @@ MODEL_REPO = "microsoft/CADFusion"
 CHECKPOINT_REVISION = "main"
 CHECKPOINT_SUBFOLDER = "exp/model_ckpt/v1_1"
 LOCAL_CHECKPOINT_DIR = "./model_ckpt/v1_1"
+FALLBACK_MODEL = "meta-llama/Llama-2-7b"
 
 # Ensure local checkpoint directory exists
 os.makedirs(LOCAL_CHECKPOINT_DIR, exist_ok=True)
@@ -36,32 +37,37 @@ except Exception as e:
 
 # Load model and tokenizer
 try:
-    logger.info("Loading tokenizer...")
-    # Fallback to base Llama-3-8B tokenizer if CADFusion-specific config is missing
+    logger.info("Loading tokenizer from local checkpoint...")
     tokenizer = AutoTokenizer.from_pretrained(
-        "meta-llama/Meta-Llama-3-8B",
+        LOCAL_CHECKPOINT_DIR,
         trust_remote_code=True
     )
-    logger.info("Loading model...")
-    # Attempt to load model from local checkpoint
+    logger.info("Loading model from local checkpoint...")
     model = AutoModelForCausalLM.from_pretrained(
         LOCAL_CHECKPOINT_DIR,
         torch_dtype=torch.float16,
         device_map="auto",
         trust_remote_code=True
     )
-    logger.info("Model and tokenizer loaded successfully.")
+    logger.info("Model and tokenizer loaded successfully from local checkpoint.")
 except Exception as e:
-    logger.error(f"Error loading model or tokenizer: {str(e)}")
-    # Fallback to base Llama-3-8B model if local checkpoint fails
-    logger.info("Falling back to base Meta-Llama-3-8B model...")
-    model = AutoModelForCausalLM.from_pretrained(
-        "meta-llama/Meta-Llama-3-8B",
-        torch_dtype=torch.float16,
-        device_map="auto",
-        trust_remote_code=True
-    )
-    logger.info("Fallback model loaded successfully.")
+    logger.error(f"Error loading from local checkpoint: {str(e)}")
+    logger.info(f"Falling back to {FALLBACK_MODEL}...")
+    try:
+        tokenizer = AutoTokenizer.from_pretrained(
+            FALLBACK_MODEL,
+            trust_remote_code=True
+        )
+        model = AutoModelForCausalLM.from_pretrained(
+            FALLBACK_MODEL,
+            torch_dtype=torch.float16,
+            device_map="auto",
+            trust_remote_code=True
+        )
+        logger.info(f"Fallback model {FALLBACK_MODEL} loaded successfully.")
+    except Exception as fallback_e:
+        logger.error(f"Error loading fallback model: {str(fallback_e)}")
+        raise fallback_e
 
 # Function to generate CAD model from text description
 def generate_cad_model(text_description):
@@ -125,9 +131,10 @@ def create_gradio_interface():
 
     gr.Markdown("""
    **Note**:
+    - This deployment may use a fallback model (Llama-2-7b) due to issues with the CADFusion v1_1 checkpoint.
     - CADFusion is for research purposes only. Generated models may not be technically accurate and require validation.
-    - This deployment uses a fallback to Meta-Llama-3-8B due to potential issues with the v1_1 checkpoint.
-    - For full functionality, refer to the [CADFusion GitHub repo](https://github.com/microsoft/CADFusion) for custom setup instructions.
+    - For full CADFusion functionality, follow the setup instructions in the [CADFusion GitHub repo](https://github.com/microsoft/CADFusion).
+    - Contact Shizhao Sun (shizsu@microsoft.com) for checkpoint access or issues.
     """)
 
     return demo
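
The hunk headers imply a checkpoint-download step between the configuration constants and the model-loading block (the `except Exception as e:` context at old line 36), but that code is elided from this diff. Below is a minimal sketch of what it could look like with the retained snapshot_download import; the allow_patterns filter and the copy into LOCAL_CHECKPOINT_DIR are assumptions, not the app's confirmed code.

import os
import shutil
import logging

from huggingface_hub import snapshot_download

logger = logging.getLogger(__name__)

MODEL_REPO = "microsoft/CADFusion"
CHECKPOINT_REVISION = "main"
CHECKPOINT_SUBFOLDER = "exp/model_ckpt/v1_1"
LOCAL_CHECKPOINT_DIR = "./model_ckpt/v1_1"

try:
    # Fetch only the v1_1 checkpoint files from the repo snapshot.
    cache_path = snapshot_download(
        repo_id=MODEL_REPO,
        revision=CHECKPOINT_REVISION,
        allow_patterns=f"{CHECKPOINT_SUBFOLDER}/*",
    )
    # Copy the checkpoint subfolder to the path the loader expects.
    shutil.copytree(
        os.path.join(cache_path, CHECKPOINT_SUBFOLDER),
        LOCAL_CHECKPOINT_DIR,
        dirs_exist_ok=True,
    )
except Exception as e:
    logger.error(f"Checkpoint download failed: {str(e)}")

snapshot_download fetches a filtered repo snapshot in one call, which would explain dropping the now-unused hf_hub_download (one file per call) from the import line. Note that the meta-llama/Llama-2-7b fallback repo is gated on the Hub, so the fallback path would additionally need a Hugging Face access token configured in the Space.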
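
The body of generate_cad_model also falls outside the hunks shown. For orientation, here is a hypothetical sketch of a minimal implementation using the module-level tokenizer and model loaded above; the prompt handling and decoding parameters are illustrative assumptions, not CADFusion's confirmed inference settings.

import torch

def generate_cad_model(text_description: str) -> str:
    # Hypothetical sketch: tokenize the description, generate a CAD
    # command sequence, and decode only the new tokens. Decoding
    # parameters are assumptions, not CADFusion's confirmed settings.
    inputs = tokenizer(text_description, return_tensors="pt").to(model.device)
    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=512,
            do_sample=True,
            temperature=0.7,
        )
    # Strip the prompt tokens so only the generated continuation is returned.
    generated = output_ids[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(generated, skip_special_tokens=True)

Slicing off the prompt tokens before decoding keeps the returned string to the generated sequence only, which is what a downstream CAD parser would consume.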