Juna190825 committed on
Commit f4e52ec · verified · 1 Parent(s): f2cbc81

Update Dockerfile

Files changed (1)
app.py +86 -22
app.py CHANGED
@@ -1,45 +1,109 @@

  import gradio as gr
- from transformers import AutoTokenizer, AutoModelForCausalLM
  import torch

- # Load model (will use cached version if available)
- model_id = "meta-llama/Llama-2-7b-chat-hf"

- # Check for GPU
- device = "cuda" if torch.cuda.is_available() else "cpu"

- # Load tokenizer and model
- tokenizer = AutoTokenizer.from_pretrained(model_id)
- model = AutoModelForCausalLM.from_pretrained(model_id).to(device)

  def generate_text(prompt, max_length=200):
-     inputs = tokenizer(prompt, return_tensors="pt").to(device)
-
-     # Generate response
      outputs = model.generate(
          **inputs,
          max_new_tokens=max_length,
          temperature=0.7,
          do_sample=True
      )
-
-     response = tokenizer.decode(outputs[0], skip_special_tokens=True)
-     return response

- # Create Gradio interface
  with gr.Blocks() as demo:
      gr.Markdown("# LLaMA 2 7B Chat Demo")
      with gr.Row():
          input_text = gr.Textbox(label="Input Prompt", lines=3)
          output_text = gr.Textbox(label="Generated Response", lines=3)
-
      generate_btn = gr.Button("Generate")
-     generate_btn.click(
-         fn=generate_text,
-         inputs=input_text,
-         outputs=output_text
-     )

  demo.launch(server_name="0.0.0.0", server_port=7860)
-

+ # import gradio as gr
+ # from transformers import AutoTokenizer, AutoModelForCausalLM
+ # from huggingface_hub import login
+ # import torch
+ # import os
+
+ # # Authenticate using environment variable
+ # login(token=os.getenv('HF_TOKEN'))
+
+ # # Load model (will use cached version if available)
+ # model_id = "meta-llama/Llama-2-7b-chat-hf"
+ # device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ # def load_model():
+ #     tokenizer = AutoTokenizer.from_pretrained(model_id)
+ #     model = AutoModelForCausalLM.from_pretrained(model_id).to(device)
+ #     return tokenizer, model
+
+ # tokenizer, model = load_model()
+
+ # def generate_text(prompt, max_length=200):
+ #     inputs = tokenizer(prompt, return_tensors="pt").to(device)
+ #     outputs = model.generate(
+ #         **inputs,
+ #         max_new_tokens=max_length,
+ #         temperature=0.7,
+ #         do_sample=True
+ #     )
+ #     return tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+ # # Gradio interface
+ # with gr.Blocks() as demo:
+ #     gr.Markdown("# LLaMA 2 7B Chat Demo")
+ #     with gr.Row():
+ #         input_text = gr.Textbox(label="Input Prompt", lines=3)
+ #         output_text = gr.Textbox(label="Generated Response", lines=3)
+ #     generate_btn = gr.Button("Generate")
+ #     generate_btn.click(fn=generate_text, inputs=input_text, outputs=output_text)
+
+ # demo.launch(server_name="0.0.0.0", server_port=7860)
+
+
  import gradio as gr
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from huggingface_hub import login, hf_hub_download
+ from tenacity import retry, stop_after_attempt, wait_exponential
  import torch
+ import os

+ # Authentication
+ login(token=os.getenv('HF_TOKEN'))

+ # Configuration
+ CACHE_REPO = "Juna190825/cacheRepo"  # Your dataset repo for cached models
+ MODEL_ID = "meta-llama/Llama-2-7b-chat-hf"  # Original model ID
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

+ @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10))
+ def load_model():
+     try:
+         # First try loading from cache repo
+         model = AutoModelForCausalLM.from_pretrained(
+             CACHE_REPO,
+             cache_dir="/cache/models",
+             local_files_only=True
+         ).to(DEVICE)
+         tokenizer = AutoTokenizer.from_pretrained(
+             CACHE_REPO,
+             cache_dir="/cache/models"
+         )
+         print("Loaded model from cache repo")
+         return model, tokenizer
+     except Exception as e:
+         print(f"Cache load failed: {str(e)}. Falling back to original repo")
+         # Fallback to original repo
+         model = AutoModelForCausalLM.from_pretrained(
+             MODEL_ID,
+             cache_dir="/cache/models"
+         ).to(DEVICE)
+         tokenizer = AutoTokenizer.from_pretrained(
+             MODEL_ID,
+             cache_dir="/cache/models"
+         )
+         return model, tokenizer
+
+ # Load model and tokenizer
+ model, tokenizer = load_model()

  def generate_text(prompt, max_length=200):
+     inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
      outputs = model.generate(
          **inputs,
          max_new_tokens=max_length,
          temperature=0.7,
          do_sample=True
      )
+     return tokenizer.decode(outputs[0], skip_special_tokens=True)

+ # Gradio interface
  with gr.Blocks() as demo:
      gr.Markdown("# LLaMA 2 7B Chat Demo")
      with gr.Row():
          input_text = gr.Textbox(label="Input Prompt", lines=3)
          output_text = gr.Textbox(label="Generated Response", lines=3)
      generate_btn = gr.Button("Generate")
+     generate_btn.click(fn=generate_text, inputs=input_text, outputs=output_text)

  demo.launch(server_name="0.0.0.0", server_port=7860)
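
Note (not part of the commit): the local_files_only=True branch in load_model() can only succeed once model files already exist under /cache/models, so the fast path implies some prior seeding step that this diff does not show. Below is a minimal seeding sketch, under assumptions: it uses standard huggingface_hub helpers (snapshot_download, create_repo, upload_folder), reuses the same HF_TOKEN environment variable, and treats Juna190825/cacheRepo as a dataset repo, as the inline comment suggests.

# Hypothetical one-off seeding script (an assumption, not part of app.py).
# It fills /cache/models and mirrors the downloaded snapshot into the
# CACHE_REPO dataset so later runs can find files locally.
import os
from huggingface_hub import login, snapshot_download, create_repo, upload_folder

login(token=os.getenv("HF_TOKEN"))

# Download the original weights into the same cache_dir that app.py uses.
local_path = snapshot_download(
    repo_id="meta-llama/Llama-2-7b-chat-hf",
    cache_dir="/cache/models",
)

# Mirror the snapshot into the dataset repo referenced by CACHE_REPO.
create_repo("Juna190825/cacheRepo", repo_type="dataset", exist_ok=True)
upload_folder(
    repo_id="Juna190825/cacheRepo",
    repo_type="dataset",
    folder_path=local_path,
)

If the cached copy is missing or unreadable, the except branch in load_model() falls back to downloading MODEL_ID into the same cache_dir, so the app still starts. The tenacity decorator on load_model() allows up to 3 attempts with exponential backoff clamped between 4 and 10 seconds, so a transient Hub or network error during startup is retried rather than failing the Space immediately.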