yukee1992 committed on
Commit
75f4ca1
·
verified ·
1 Parent(s): 4b7f7c0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -41
app.py CHANGED
@@ -1,24 +1,28 @@
 
1
  import os
2
  import torch
3
  from transformers import AutoTokenizer, AutoModelForCausalLM
 
4
  import gradio as gr
5
 
6
  # Configuration
7
  MODEL_ID = "google/gemma-1.1-7b-it"
8
- HF_TOKEN = os.getenv("HF_TOKEN") # Make sure this is set in Space secrets
9
- MAX_TOKENS = 300
10
 
11
- # Initialize model
12
- tokenizer = AutoTokenizer.from_pretrained(
13
- MODEL_ID,
14
- token=HF_TOKEN
 
15
  )
16
 
17
- # Load model with CPU fallback
 
18
  model = AutoModelForCausalLM.from_pretrained(
19
  MODEL_ID,
20
  device_map="auto",
21
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
22
  token=HF_TOKEN
23
  )
24
 
@@ -30,9 +34,7 @@ def generate_script(topic):
30
 
31
  Script:"""
32
 
33
- # Use current device (automatically handles CPU/GPU)
34
- device = "cuda" if torch.cuda.is_available() else "cpu"
35
- inputs = tokenizer(prompt, return_tensors="pt").to(device)
36
 
37
  outputs = model.generate(
38
  **inputs,
@@ -45,39 +47,13 @@ def generate_script(topic):
45
 
46
  # Gradio interface
47
  with gr.Blocks() as demo:
48
- gr.Markdown("## 🚀 Gemma-7B Script Generator")
49
  with gr.Row():
50
- topic_input = gr.Textbox(
51
- label="Enter your topic",
52
- placeholder="e.g., 'intermittent fasting benefits'"
53
- )
54
  generate_btn = gr.Button("Generate", variant="primary")
55
-
56
- script_output = gr.Textbox(
57
- label="Generated Script",
58
- interactive=False,
59
- lines=10
60
- )
61
-
62
- # Examples
63
- gr.Examples(
64
- examples=[
65
- ["Why cold showers boost metabolism"],
66
- ["3 workout myths debunked by science"],
67
- ["The truth about protein timing"]
68
- ],
69
- inputs=topic_input
70
- )
71
 
72
  generate_btn.click(
73
  fn=generate_script,
74
  inputs=topic_input,
75
- outputs=script_output
76
- )
77
-
78
- # Launch with production settings
79
- demo.launch(
80
- server_name="0.0.0.0",
81
- server_port=7860,
82
- share=False
83
- )
 
1
+
2
  import os
3
  import torch
4
  from transformers import AutoTokenizer, AutoModelForCausalLM
5
+ from transformers import BitsAndBytesConfig
6
  import gradio as gr
7
 
8
  # Configuration
9
  MODEL_ID = "google/gemma-1.1-7b-it"
10
+ HF_TOKEN = os.getenv("HF_TOKEN")
11
+ MAX_TOKENS = 250 # Reduced for stability
12
 
13
+ # 4-bit quantization config
14
+ quant_config = BitsAndBytesConfig(
15
+ load_in_4bit=True,
16
+ bnb_4bit_compute_dtype=torch.float16,
17
+ bnb_4bit_quant_type="nf4"
18
  )
19
 
20
+ # Load tokenizer and model
21
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
22
  model = AutoModelForCausalLM.from_pretrained(
23
  MODEL_ID,
24
  device_map="auto",
25
+ quantization_config=quant_config,
26
  token=HF_TOKEN
27
  )
28
 
 
34
 
35
  Script:"""
36
 
37
+ inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
 
 
38
 
39
  outputs = model.generate(
40
  **inputs,
 
47
 
48
  # Gradio interface
49
  with gr.Blocks() as demo:
50
+ gr.Markdown("## 🎥 Optimized Gemma-7B Generator")
51
  with gr.Row():
52
+ topic_input = gr.Textbox(label="Topic", placeholder="e.g., 'cold shower benefits'")
 
 
 
53
  generate_btn = gr.Button("Generate", variant="primary")
54
+ script_output = gr.Textbox(label="Script", lines=8)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
  generate_btn.click(
57
  fn=generate_script,
58
  inputs=topic_input,
59
+ outputs=script_output