yukee1992 committed on
Commit
a70d906
·
verified ·
1 Parent(s): 06085f7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +102 -49
app.py CHANGED
@@ -13,8 +13,8 @@ from contextlib import asynccontextmanager
13
  # Configuration
14
  MODEL_ID = "google/gemma-1.1-2b-it"
15
  HF_TOKEN = os.getenv("HF_TOKEN", "")
16
- MAX_TOKENS = 200
17
- DEVICE = "cpu" # Force CPU to avoid device_map issues
18
  PORT = int(os.getenv("PORT", 7860))
19
 
20
  # Setup logging
@@ -32,36 +32,54 @@ class ScriptGenerator:
32
  self.tokenizer = None
33
  self.model = None
34
  self.loaded = False
 
35
 
36
  def load_model(self):
37
  if self.loaded:
38
- return
39
-
40
  logger.info("Loading model...")
41
  try:
42
- self.tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
 
 
 
 
 
 
 
43
  self.model = AutoModelForCausalLM.from_pretrained(
44
  MODEL_ID,
45
  torch_dtype=torch.float32,
46
  token=HF_TOKEN,
47
- low_cpu_mem_usage=True
48
  )
49
- # Simple device assignment without device_map
 
50
  self.model = self.model.to(DEVICE)
 
 
51
  self.loaded = True
52
  logger.info("✅ Model loaded successfully")
 
 
53
  except Exception as e:
54
- logger.error(f"❌ Model loading failed: {str(e)}")
55
- raise
 
 
 
 
56
 
57
  @asynccontextmanager
58
  async def lifespan(app: FastAPI):
59
- generator = ScriptGenerator()
60
- generator.load_model()
 
 
61
  yield
62
 
63
  app = FastAPI(lifespan=lifespan)
64
- generator = ScriptGenerator()
65
 
66
  def extract_topic(topic_input: Union[str, List[str]]) -> str:
67
  """Extract topic from string or array input"""
@@ -74,40 +92,52 @@ def extract_topic(topic_input: Union[str, List[str]]) -> str:
74
  def generate_script(topic: str) -> str:
75
  """Generate script with error handling"""
76
  try:
 
 
 
 
 
77
  clean_topic = topic.strip().strip("['").strip("']").strip('"').strip("'")
78
  logger.info(f"🎯 Generating script for: '{clean_topic}'")
79
 
80
  prompt = (
81
- f"Create a short 1-minute video script about: {clean_topic[:80]}\n\n"
82
- "Structure:\n"
83
- "1) Hook (5-10 seconds)\n"
84
- "2) Main Content (40 seconds)\n"
85
- "3) CTA (5-10 seconds)\n\n"
86
  "Script:"
87
  )
88
 
 
89
  inputs = generator.tokenizer(
90
  prompt,
91
  return_tensors="pt",
92
- padding=True,
93
  truncation=True,
94
- max_length=512
95
- ).to(DEVICE)
 
 
 
96
 
97
- # Generate with safer parameters
98
  with torch.no_grad():
99
  outputs = generator.model.generate(
100
  **inputs,
101
  max_new_tokens=MAX_TOKENS,
102
  do_sample=True,
103
- top_p=0.8,
104
  temperature=0.7,
105
- pad_token_id=generator.tokenizer.eos_token_id
 
106
  )
107
 
 
108
  script = generator.tokenizer.decode(outputs[0], skip_special_tokens=True)
109
  clean_script = script.replace(prompt, "").strip()
110
 
 
 
 
111
  logger.info(f"📝 Generated {len(clean_script)} characters")
112
  return clean_script
113
 
@@ -126,7 +156,6 @@ async def process_job(job_id: str, topic_input: Union[str, List[str]], callback_
126
  "status": "complete",
127
  "result": script,
128
  "topic": topic,
129
- "original_input": topic_input,
130
  "script_length": len(script)
131
  }
132
 
@@ -135,35 +164,28 @@ async def process_job(job_id: str, topic_input: Union[str, List[str]], callback_
135
  if callback_url:
136
  try:
137
  async with httpx.AsyncClient(timeout=30.0) as client:
138
- webhook_data = {
139
- "job_id": job_id,
140
- "status": "complete",
141
- "result": script,
142
- "topic": topic,
143
- "original_input": topic_input
144
- }
145
-
146
  response = await client.post(
147
  callback_url,
148
- json=webhook_data,
 
 
 
 
 
149
  headers={"Content-Type": "application/json"}
150
  )
151
-
152
  logger.info(f"📨 Webhook status: {response.status_code}")
153
-
154
  except Exception as e:
155
  logger.error(f"❌ Webhook failed: {str(e)}")
156
 
157
  except Exception as e:
158
  error_msg = f"Job failed: {str(e)}"
159
- logger.error(f"❌ Job {job_id} failed: {error_msg}", exc_info=True)
160
 
161
  jobs[job_id] = {
162
  "status": "failed",
163
  "error": error_msg,
164
- "topic": extract_topic(topic_input) if topic_input else "unknown",
165
- "original_input": topic_input,
166
- "script_length": 0
167
  }
168
 
169
  @app.post("/api/submit")
@@ -184,10 +206,8 @@ async def submit_job(request: Request, background_tasks: BackgroundTasks):
184
 
185
  jobs[job_id] = {
186
  "status": "processing",
187
- "result": None,
188
  "callback_url": callback_url,
189
- "topic": topic,
190
- "original_input": topic_input
191
  }
192
 
193
  background_tasks.add_task(
@@ -200,12 +220,11 @@ async def submit_job(request: Request, background_tasks: BackgroundTasks):
200
  return JSONResponse({
201
  "job_id": job_id,
202
  "status": "queued",
203
- "received_topic": topic,
204
- "callback_url": callback_url
205
  })
206
 
207
  except Exception as e:
208
- logger.error(f"❌ Submission error: {str(e)}", exc_info=True)
209
  raise HTTPException(status_code=400, detail=str(e))
210
 
211
  @app.get("/api/status/{job_id}")
@@ -235,8 +254,9 @@ async def debug_jobs():
235
  async def health_check():
236
  """Health check endpoint"""
237
  return {
238
- "status": "healthy",
239
  "model_loaded": generator.loaded,
 
240
  "total_jobs": len(jobs)
241
  }
242
 
@@ -244,11 +264,44 @@ async def health_check():
244
  async def test_generation():
245
  """Test script generation"""
246
  try:
247
- test_topic = "healthy lifestyle tips"
 
 
 
 
 
 
 
 
 
 
248
  script = generate_script(test_topic)
249
- return {"status": "success", "topic": test_topic, "script": script}
 
 
 
 
 
 
 
250
  except Exception as e:
251
- return {"status": "error", "error": str(e)}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
252
 
253
  if __name__ == "__main__":
254
  uvicorn.run(
 
13
  # Configuration
14
  MODEL_ID = "google/gemma-1.1-2b-it"
15
  HF_TOKEN = os.getenv("HF_TOKEN", "")
16
+ MAX_TOKENS = 150
17
+ DEVICE = "cpu"
18
  PORT = int(os.getenv("PORT", 7860))
19
 
20
  # Setup logging
 
32
  self.tokenizer = None
33
  self.model = None
34
  self.loaded = False
35
+ self.load_error = None
36
 
37
  def load_model(self):
38
  if self.loaded:
39
+ return True
40
+
41
  logger.info("Loading model...")
42
  try:
43
+ # Load tokenizer first
44
+ self.tokenizer = AutoTokenizer.from_pretrained(
45
+ MODEL_ID,
46
+ token=HF_TOKEN
47
+ )
48
+ logger.info("✅ Tokenizer loaded")
49
+
50
+ # Load model with simple configuration
51
  self.model = AutoModelForCausalLM.from_pretrained(
52
  MODEL_ID,
53
  torch_dtype=torch.float32,
54
  token=HF_TOKEN,
55
+ device_map=None # Explicitly set to None
56
  )
57
+
58
+ # Move to device
59
  self.model = self.model.to(DEVICE)
60
+ self.model.eval() # Set to evaluation mode
61
+
62
  self.loaded = True
63
  logger.info("✅ Model loaded successfully")
64
+ return True
65
+
66
  except Exception as e:
67
+ self.load_error = str(e)
68
+ logger.error(f"❌ Model loading failed: {str(e)}", exc_info=True)
69
+ return False
70
+
71
+ # Global generator instance
72
+ generator = ScriptGenerator()
73
 
74
  @asynccontextmanager
75
  async def lifespan(app: FastAPI):
76
+ # Load model during startup
77
+ success = generator.load_model()
78
+ if not success:
79
+ logger.critical("❌ Failed to load model during startup!")
80
  yield
81
 
82
  app = FastAPI(lifespan=lifespan)
 
83
 
84
  def extract_topic(topic_input: Union[str, List[str]]) -> str:
85
  """Extract topic from string or array input"""
 
92
  def generate_script(topic: str) -> str:
93
  """Generate script with error handling"""
94
  try:
95
+ # Check if model is loaded
96
+ if not generator.loaded:
97
+ if not generator.load_model():
98
+ raise Exception(f"Model failed to load: {generator.load_error}")
99
+
100
  clean_topic = topic.strip().strip("['").strip("']").strip('"').strip("'")
101
  logger.info(f"🎯 Generating script for: '{clean_topic}'")
102
 
103
  prompt = (
104
+ f"Create a 60-second video script about: {clean_topic[:50]}\n\n"
105
+ "1) Hook (10s)\n"
106
+ "2) Content (40s)\n"
107
+ "3) CTA (10s)\n\n"
 
108
  "Script:"
109
  )
110
 
111
+ # Tokenize input
112
  inputs = generator.tokenizer(
113
  prompt,
114
  return_tensors="pt",
 
115
  truncation=True,
116
+ max_length=256
117
+ )
118
+
119
+ # Move to device
120
+ inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
121
 
122
+ # Generate text
123
  with torch.no_grad():
124
  outputs = generator.model.generate(
125
  **inputs,
126
  max_new_tokens=MAX_TOKENS,
127
  do_sample=True,
128
+ top_p=0.9,
129
  temperature=0.7,
130
+ pad_token_id=generator.tokenizer.eos_token_id,
131
+ num_return_sequences=1
132
  )
133
 
134
+ # Decode output
135
  script = generator.tokenizer.decode(outputs[0], skip_special_tokens=True)
136
  clean_script = script.replace(prompt, "").strip()
137
 
138
+ if not clean_script:
139
+ clean_script = "Script generation completed but returned empty content."
140
+
141
  logger.info(f"📝 Generated {len(clean_script)} characters")
142
  return clean_script
143
 
 
156
  "status": "complete",
157
  "result": script,
158
  "topic": topic,
 
159
  "script_length": len(script)
160
  }
161
 
 
164
  if callback_url:
165
  try:
166
  async with httpx.AsyncClient(timeout=30.0) as client:
 
 
 
 
 
 
 
 
167
  response = await client.post(
168
  callback_url,
169
+ json={
170
+ "job_id": job_id,
171
+ "status": "complete",
172
+ "result": script,
173
+ "topic": topic
174
+ },
175
  headers={"Content-Type": "application/json"}
176
  )
 
177
  logger.info(f"📨 Webhook status: {response.status_code}")
 
178
  except Exception as e:
179
  logger.error(f"❌ Webhook failed: {str(e)}")
180
 
181
  except Exception as e:
182
  error_msg = f"Job failed: {str(e)}"
183
+ logger.error(f"❌ Job {job_id} failed: {error_msg}")
184
 
185
  jobs[job_id] = {
186
  "status": "failed",
187
  "error": error_msg,
188
+ "topic": extract_topic(topic_input) if topic_input else "unknown"
 
 
189
  }
190
 
191
  @app.post("/api/submit")
 
206
 
207
  jobs[job_id] = {
208
  "status": "processing",
 
209
  "callback_url": callback_url,
210
+ "topic": topic
 
211
  }
212
 
213
  background_tasks.add_task(
 
220
  return JSONResponse({
221
  "job_id": job_id,
222
  "status": "queued",
223
+ "topic": topic
 
224
  })
225
 
226
  except Exception as e:
227
+ logger.error(f"❌ Submission error: {str(e)}")
228
  raise HTTPException(status_code=400, detail=str(e))
229
 
230
  @app.get("/api/status/{job_id}")
 
254
  async def health_check():
255
  """Health check endpoint"""
256
  return {
257
+ "status": "healthy" if generator.loaded else "unhealthy",
258
  "model_loaded": generator.loaded,
259
+ "model_error": generator.load_error,
260
  "total_jobs": len(jobs)
261
  }
262
 
 
264
  async def test_generation():
265
  """Test script generation"""
266
  try:
267
+ # Check if model is loaded first
268
+ if not generator.loaded:
269
+ if not generator.load_model():
270
+ return {
271
+ "status": "error",
272
+ "error": f"Model failed to load: {generator.load_error}"
273
+ }
274
+
275
+ test_topic = "healthy lifestyle"
276
+ logger.info(f"🧪 Testing generation with: {test_topic}")
277
+
278
  script = generate_script(test_topic)
279
+
280
+ return {
281
+ "status": "success",
282
+ "topic": test_topic,
283
+ "script_length": len(script),
284
+ "script_preview": script[:200] + "..." if len(script) > 200 else script
285
+ }
286
+
287
  except Exception as e:
288
+ logger.error(f"❌ Test generation failed: {str(e)}", exc_info=True)
289
+ return {
290
+ "status": "error",
291
+ "error": str(e),
292
+ "model_loaded": generator.loaded,
293
+ "model_error": generator.load_error
294
+ }
295
+
296
+ @app.get("/test/model")
297
+ async def test_model():
298
+ """Test if model loads correctly"""
299
+ return {
300
+ "model_loaded": generator.loaded,
301
+ "model_error": generator.load_error,
302
+ "has_tokenizer": generator.tokenizer is not None,
303
+ "has_model": generator.model is not None
304
+ }
305
 
306
  if __name__ == "__main__":
307
  uvicorn.run(