yukee1992 committed
Commit 6a274d8 · verified · 1 Parent(s): 4c52de3

Update app.py

Files changed (1)
app.py +15 -26
app.py CHANGED
@@ -1,34 +1,23 @@
+from transformers import AutoModelForCausalLM, AutoTokenizer
 import torch
-from transformers import AutoTokenizer, AutoModelForCausalLM
-import gradio as gr
 
-# Load model (cache on first run)
 model_id = "google/gemma-1.1-7b-it"
+
+# CPU-specific config
 tokenizer = AutoTokenizer.from_pretrained(model_id)
 model = AutoModelForCausalLM.from_pretrained(
     model_id,
-    device_map="auto",
-    torch_dtype=torch.bfloat16  # Optimized for NVIDIA
-)
-
-def generate_script(topic):
-    prompt = f"""Generate a viral YouTube Short script about {topic} with:
-1) HOOK: Controversial opening (5 words max)
-2) BODY: 3 scientific facts
-3) CTA: Actionable challenge
-
-Script:"""
-
-    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
-    outputs = model.generate(**inputs, max_new_tokens=300)
-    return tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-# Gradio UI
-demo = gr.Interface(
-    fn=generate_script,
-    inputs=gr.Textbox(label="Topic"),
-    outputs=gr.Textbox(label="Generated Script"),
-    title="Gemma-7B Script Generator"
+    device_map="cpu",
+    torch_dtype=torch.float32,  # Required for CPU
+    load_in_8bit=True  # Reduces RAM usage by 2x
 )
 
-demo.launch(server_port=7860, share=True)
+def generate(prompt):
+    inputs = tokenizer(prompt, return_tensors="pt").to("cpu")
+    outputs = model.generate(
+        **inputs,
+        max_new_tokens=150,  # Must stay under 200
+        do_sample=True,
+        temperature=0.7
+    )
+    return tokenizer.decode(outputs[0], skip_special_tokens=True)
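
A note on the new loading config: load_in_8bit=True is backed by bitsandbytes, whose 8-bit kernels require a CUDA GPU, so pairing it with device_map="cpu" will typically error out at load time rather than save RAM (and int8 versus float32 is roughly a 4x reduction, not 2x, as the comment claims). A minimal CPU-only sketch that avoids the conflict, assuming no quantization is used; low_cpu_mem_usage is a standard transformers flag, and everything else mirrors the commit:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "google/gemma-1.1-7b-it"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="cpu",
    torch_dtype=torch.float32,   # full precision; bitsandbytes 8-bit needs CUDA
    low_cpu_mem_usage=True,      # stream weights in shard by shard to cut peak RAM
)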
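
Since the commit also drops the Gradio UI, nothing calls generate() anymore. A hedged usage sketch (the example message is illustrative, not from the commit): Gemma's instruction-tuned checkpoints expect chat-turn markers, which tokenizer.apply_chat_template adds before the call.

# Wrap a plain user message in Gemma's chat-turn format (hypothetical example prompt).
messages = [{"role": "user", "content": "Write a short script about black holes."}]
prompt = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
print(generate(prompt))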