yukee1992 committed on
Commit
628eb7f
·
verified ·
1 Parent(s): 29b7918

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +98 -39
app.py CHANGED
import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr

# Configuration
MODEL_ID = "google/gemma-1.1-2b-it"  # Using smaller 2B version
HF_TOKEN = os.getenv("HF_TOKEN")
MAX_TOKENS = 80  # Conservative limit


def load_model():
    """Simplified, CPU-only model loading that works in Spaces.

    Returns:
        (tokenizer, model) tuple, with the model placed on CPU.
    """
    print("🔄 Loading model...")

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)

    # Explicit CPU-only loading; float32 is required on CPU.
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.float32,
        token=HF_TOKEN
    ).to('cpu')

    print("✅ Model loaded successfully!")
    return tokenizer, model


tokenizer, model = load_model()


def predict(topic):
    """Memory-safe generation.

    Returns the decoded script text, or an "Error: ..." string on failure.
    """
    try:
        prompt = f"Create a short script about {topic}:\n1) Hook\n2) Point\n3) CTA\n\nScript:"
        inputs = tokenizer(prompt, return_tensors="pt").to('cpu')

        # FIX: `temperature` is silently ignored by generate() unless
        # sampling is enabled, and inference should not track gradients
        # (gradient buffers waste memory on CPU).
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=MAX_TOKENS,
                temperature=0.7,
                do_sample=True
            )

        return tokenizer.decode(outputs[0], skip_special_tokens=True)

    except Exception as e:
        return f"Error: {str(e)}"


# Minimal interface
gr.Interface(
    fn=predict,
    inputs=gr.Textbox(label="Topic"),
    outputs=gr.Textbox(label="Script", lines=4),
    api_name="predict"
).launch(server_name="0.0.0.0")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import torch
3
  from transformers import AutoTokenizer, AutoModelForCausalLM
4
  import gradio as gr
5
+ from typing import Dict, Any
6
 
7
# Configuration
MODEL_ID = "google/gemma-1.1-2b-it"  # Gemma 1.1 instruction-tuned, 2B params
# FIX: treat an unset *or empty* HF_TOKEN as anonymous access. The previous
# default ("") would be forwarded as an invalid bearer token to the Hub
# instead of falling back to cached credentials / anonymous download.
HF_TOKEN = os.getenv("HF_TOKEN") or None
MAX_TOKENS = 80          # cap on newly generated tokens per request
MAX_INPUT_LENGTH = 100   # max characters accepted for a topic
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
14
class ScriptGenerator:
    """Holds the tokenizer/model pair and loads them lazily, exactly once."""

    def __init__(self):
        self.tokenizer = None  # set by load_model()
        self.model = None      # set by load_model()
        self.loaded = False    # guards against redundant reloads

    def load_model(self):
        """Safe model loading with progress logging; idempotent.

        Raises:
            Exception: re-raises any download/loading failure after logging.
        """
        if self.loaded:
            return

        print("🔄 Loading model...")
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)

            use_cuda = DEVICE == "cuda"
            self.model = AutoModelForCausalLM.from_pretrained(
                MODEL_ID,
                torch_dtype=torch.float16 if use_cuda else torch.float32,
                device_map="auto" if use_cuda else None,
                token=HF_TOKEN,
            )
            # FIX: when accelerate has dispatched the model via
            # device_map="auto", calling .to() afterwards is unsupported and
            # raises in recent transformers. Only move explicitly on CPU.
            if not use_cuda:
                self.model = self.model.to(DEVICE)

            self.loaded = True
            print("✅ Model loaded successfully!")
        except Exception as e:
            print(f"❌ Model loading failed: {str(e)}")
            raise


generator = ScriptGenerator()
43
+
44
def predict(topic: str) -> str:
    """Generate a short marketing script for *topic*.

    Returns:
        The decoded model output, or a human-readable failure message
        (runtime failures are prefixed with "Error").
    """
    try:
        # FIX: strip first so whitespace-only (or None) topics are rejected
        # instead of reaching the model / raising TypeError on len(None).
        topic = (topic or "").strip()
        if not topic or len(topic) > MAX_INPUT_LENGTH:
            return f"Topic must be 1-{MAX_INPUT_LENGTH} characters"

        # Lazy-load so module import stays cheap; first request pays the cost.
        if not generator.loaded:
            generator.load_model()

        prompt = f"Create a short script about {topic}:\n1) Hook\n2) Point\n3) CTA\n\nScript:"

        # Inference only: no autograd bookkeeping.
        with torch.no_grad():
            inputs = generator.tokenizer(prompt, return_tensors="pt").to(DEVICE)
            outputs = generator.model.generate(
                **inputs,
                max_new_tokens=MAX_TOKENS,
                temperature=0.7,
                do_sample=True,
            )

        return generator.tokenizer.decode(outputs[0], skip_special_tokens=True)

    except torch.cuda.OutOfMemoryError:
        return "Error: GPU out of memory - try a shorter input"
    except Exception as e:
        return f"Error: {str(e)}"
71
 
72
def api_predict(data: Dict[str, Any]) -> Dict[str, Any]:
    """Dedicated API endpoint with a standardized response envelope.

    Args:
        data: request payload; the topic is read from data["topic"].

    Returns:
        dict with keys "success" (bool), "result" (str | None),
        "error" (str | None).
    """
    try:
        topic = data.get("topic", "")
        result = predict(topic)
        # FIX: input-validation messages ("Topic must be ...") previously
        # came back with success=True; treat them as failures too.
        failed = result.startswith("Error") or result.startswith("Topic must")
        return {
            "success": not failed,
            "result": result,
            "error": result if failed else None
        }
    except Exception as e:
        return {
            "success": False,
            "result": None,
            "error": str(e)
        }
88
+
89
# Gradio Interface with explicit API
# UI definition: a single topic textbox in, the generated script text out.
interface = gr.Interface(
    fn=predict,  # same function backs both the UI and Gradio's JSON API
    inputs=gr.Textbox(label="Topic", placeholder="Enter your script topic..."),
    outputs=gr.Textbox(label="Generated Script", lines=5),
    title="Gemma Script Generator",
    description="Generate marketing scripts using Gemma 2B"
)
97
+
98
# FIX: the previous mounting code called gr.App(), which does not exist in
# Gradio (mount_gradio_app expects a FastAPI application), and then invoked
# .launch() on the mounted ASGI app — FastAPI apps have no launch() method —
# so startup always crashed. Launch the Interface directly instead: Gradio
# already exposes a JSON API for `fn` (see the "Use via API" link / /api
# routes), which covers the programmatic-access use case.
app = interface

if __name__ == "__main__":
    # Warm the model before serving so the first request doesn't pay the
    # full download/load latency.
    generator.load_model()
    app.launch(
        server_name="0.0.0.0",
        server_port=int(os.getenv("PORT", "7860")),
        share=False
    )