yukee1992 committed on
Commit
cf246af
·
verified ·
1 Parent(s): 1102b25

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -46
app.py CHANGED
@@ -6,54 +6,26 @@ import gradio as gr
6
  # Configuration
7
  MODEL_ID = "google/gemma-1.1-7b-it"
8
  HF_TOKEN = os.getenv("HF_TOKEN")
9
- MAX_TOKENS = 200 # Reduced for stability
10
 
11
- # Initialize components
12
- tokenizer = None
13
- model = None
14
-
15
- def load_model():
16
- """Lazy-load the model to avoid immediate memory issues"""
17
- global tokenizer, model
18
- if model is None:
19
- print("🚀 Loading model...")
20
- tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
21
- model = AutoModelForCausalLM.from_pretrained(
22
- MODEL_ID,
23
- device_map="auto",
24
- torch_dtype=torch.float16,
25
- token=HF_TOKEN
26
- )
27
- print("✅ Model loaded!")
28
 
29
- def generate_script(topic):
30
- """The prediction function that handles requests"""
31
- try:
32
- load_model()
33
- prompt = f"Generate a short YouTube script about {topic}:"
34
- inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
35
- outputs = model.generate(
36
- **inputs,
37
- max_new_tokens=MAX_TOKENS,
38
- temperature=0.7
39
- )
40
- return tokenizer.decode(outputs[0], skip_special_tokens=True)
41
- except Exception as e:
42
- print(f"⚠️ Error: {str(e)}")
43
- return f"Generation failed: {str(e)}"
44
 
45
- # Create Gradio app
46
- app = gr.Interface(
47
- fn=generate_script,
48
  inputs=gr.Textbox(label="Topic"),
49
  outputs=gr.Textbox(label="Script"),
50
- title="Gemma-7B Script Generator"
51
- )
52
-
53
- # Launch with explicit API configuration
54
- app.launch(
55
- server_name="0.0.0.0",
56
- server_port=7860,
57
- enable_api=True, # CRITICAL FOR API
58
- share=False
59
- )
 
6
# ── Configuration ─────────────────────────────────────────────
MODEL_ID = "google/gemma-1.1-7b-it"
HF_TOKEN = os.getenv("HF_TOKEN")  # auth token for the gated Gemma weights
MAX_TOKENS = 150  # kept small: CPU-only generation is slow

# ── Load tokenizer and model (CPU-only) ───────────────────────
# NOTE(review): float32 weights for a 7B model need ~28 GB RAM — confirm
# the host has that much memory before deploying.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="cpu",           # pin every layer to CPU
    torch_dtype=torch.float32,  # full precision; fp16 is not reliable on CPU
    token=HF_TOKEN,
)
 
 
 
 
 
 
 
 
 
19
 
20
def predict(topic):
    """Generate a short script about *topic* with the module-level Gemma model.

    Returns the full decoded generation as a string (the prompt text is
    included, since the decode is not sliced past the input tokens).
    """
    query = f"Create a short script about {topic}:"
    encoded = tokenizer(query, return_tensors="pt")
    generated = model.generate(**encoded, max_new_tokens=MAX_TOKENS)
    return tokenizer.decode(generated[0], skip_special_tokens=True)
 
 
 
 
 
 
 
 
 
 
25
 
26
# Build the Gradio UI; api_name="predict" exposes the /predict API route.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(label="Topic"),
    outputs=gr.Textbox(label="Script"),
    api_name="predict",
)
# Bind to all interfaces so the server is reachable inside a container/Space.
demo.launch(server_name="0.0.0.0")