Juna190825 committed
Commit d39dd11 · verified · 1 Parent(s): 11dc29d

Update Dockerfile

Files changed (1): app.py +92 -32
app.py CHANGED
@@ -41,56 +41,116 @@
 # demo.launch(server_name="0.0.0.0", server_port=7860)


+# import gradio as gr
+# from transformers import AutoModelForCausalLM, AutoTokenizer
+# from huggingface_hub import login, hf_hub_download
+# from tenacity import retry, stop_after_attempt, wait_exponential
+# import torch
+# import os
+
+# # Authentication
+# login(token=os.getenv('HF_TOKEN'))
+
+# # Configuration
+# CACHE_REPO = "Juna190825/cacheRepo"  # Your dataset repo for cached models
+# MODEL_ID = "meta-llama/Llama-2-7b-chat-hf"  # Original model ID
+# DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+
+# @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10))
+# def load_model():
+#     retries = 3
+#     for attempt in range(retries):
+#         try:
+#             # First try loading from cache repo
+#             model = AutoModelForCausalLM.from_pretrained(
+#                 CACHE_REPO,
+#                 cache_dir="/cache/models",
+#                 local_files_only=True
+#             ).to(DEVICE)
+#             tokenizer = AutoTokenizer.from_pretrained(
+#                 CACHE_REPO,
+#                 cache_dir="/cache/models"
+#             )
+#             print("Loaded model from cache repo")
+#             return model, tokenizer
+#         except Exception as e:
+#             if attempt == retries - 1:  # Final attempt
+#                 print(f"Cache load failed: {str(e)}. Falling back to original repo")
+#                 # Fallback to original repo
+#                 model = AutoModelForCausalLM.from_pretrained(
+#                     MODEL_ID,
+#                     cache_dir="/cache/models"
+#                 ).to(DEVICE)
+#                 tokenizer = AutoTokenizer.from_pretrained(
+#                     MODEL_ID,
+#                     cache_dir="/cache/models"
+#                 )
+#                 return model, tokenizer
+#             print(f"Attempt {attempt + 1} failed, retrying...")
+#             time.sleep(2 ** attempt)  # Exponential backoff
+
+# # Load model and tokenizer
+# model, tokenizer = load_model()
+
+# def generate_text(prompt, max_length=200):
+#     inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
+#     outputs = model.generate(
+#         **inputs,
+#         max_new_tokens=max_length,
+#         temperature=0.7,
+#         do_sample=True
+#     )
+#     return tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+# # Gradio interface
+# with gr.Blocks() as demo:
+#     gr.Markdown("# LLaMA 2 7B Chat Demo")
+#     with gr.Row():
+#         input_text = gr.Textbox(label="Input Prompt", lines=3)
+#         output_text = gr.Textbox(label="Generated Response", lines=3)
+#     generate_btn = gr.Button("Generate")
+#     generate_btn.click(fn=generate_text, inputs=input_text, outputs=output_text)
+
+# demo.launch(server_name="0.0.0.0", server_port=7860)
+
 import gradio as gr
 from transformers import AutoModelForCausalLM, AutoTokenizer
-from huggingface_hub import login, hf_hub_download
-from tenacity import retry, stop_after_attempt, wait_exponential
+from huggingface_hub import login
 import torch
 import os
+import time  # For manual retries

 # Authentication
 login(token=os.getenv('HF_TOKEN'))

 # Configuration
-CACHE_REPO = "Juna190825/cacheRepo"  # Your dataset repo for cached models
-MODEL_ID = "meta-llama/Llama-2-7b-chat-hf"  # Original model ID
+MODEL_ID = "meta-llama/Llama-2-7b-chat-hf"
+CACHE_DIR = "/cache/models"
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

-@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10))
-def load_model():
-    retries = 3
-    for attempt in range(retries):
+def load_model_with_retry(max_retries=3):
+    for attempt in range(max_retries):
         try:
-            # First try loading from cache repo
+            # Try loading from cache first
             model = AutoModelForCausalLM.from_pretrained(
-                CACHE_REPO,
-                cache_dir="/cache/models",
-                local_files_only=True
+                MODEL_ID,
+                cache_dir=CACHE_DIR,
+                local_files_only=(attempt > 0)  # Only check cache after first fail
             ).to(DEVICE)
             tokenizer = AutoTokenizer.from_pretrained(
-                CACHE_REPO,
-                cache_dir="/cache/models"
+                MODEL_ID,
+                cache_dir=CACHE_DIR
             )
-            print("Loaded model from cache repo")
             return model, tokenizer
         except Exception as e:
-            if attempt == retries - 1:  # Final attempt
-                print(f"Cache load failed: {str(e)}. Falling back to original repo")
-                # Fallback to original repo
-                model = AutoModelForCausalLM.from_pretrained(
-                    MODEL_ID,
-                    cache_dir="/cache/models"
-                ).to(DEVICE)
-                tokenizer = AutoTokenizer.from_pretrained(
-                    MODEL_ID,
-                    cache_dir="/cache/models"
-                )
-                return model, tokenizer
-            print(f"Attempt {attempt + 1} failed, retrying...")
-            time.sleep(2 ** attempt)  # Exponential backoff
-
-# Load model and tokenizer
-model, tokenizer = load_model()
+            if attempt == max_retries - 1:
+                raise
+            wait_time = 2 ** (attempt + 1)  # Exponential backoff (2s, 4s, 8s)
+            print(f"Attempt {attempt + 1} failed, retrying in {wait_time}s...")
+            time.sleep(wait_time)
+
+# Load model
+model, tokenizer = load_model_with_retry()

 def generate_text(prompt, max_length=200):
     inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
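
In short, the commit swaps the tenacity-decorated, cache-repo-first loader for a hand-rolled retry loop around from_pretrained: the first attempt may download from the Hub, later attempts read only the local cache (local_files_only=(attempt > 0)), and failed attempts back off exponentially (2 s, then 4 s) before the final error is re-raised. Below is a minimal, self-contained sketch of that retry pattern with a stand-in loader, so the backoff can be observed without downloading any weights; the helper names (load_with_retry, fake_loader) are illustrative and not part of the commit.

import time

def load_with_retry(load_fn, max_retries=3):
    # Same pattern as load_model_with_retry: try, back off 2 s then 4 s, re-raise the last failure.
    for attempt in range(max_retries):
        try:
            # Only the first attempt is allowed to hit the network; later ones are cache-only.
            return load_fn(local_files_only=(attempt > 0))
        except Exception:
            if attempt == max_retries - 1:
                raise
            wait_time = 2 ** (attempt + 1)
            print(f"Attempt {attempt + 1} failed, retrying in {wait_time}s...")
            time.sleep(wait_time)

# Stand-in loader: fails twice, then succeeds, so the retry/backoff behaviour is visible.
state = {"calls": 0}
def fake_loader(local_files_only=False):
    state["calls"] += 1
    if state["calls"] < 3:
        raise RuntimeError("simulated download error")
    return f"loaded (local_files_only={local_files_only})"

print(load_with_retry(fake_loader))  # waits 2 s, then 4 s, then prints: loaded (local_files_only=True)

Dropping tenacity in favour of this explicit loop removes a dependency and keeps the backoff schedule visible in the code; the trade-off is that the retry behaviour now lives in application code rather than a decorator.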