yukee1992 committed on
Commit
80aadbe
·
verified ·
1 Parent(s): 8ae56b5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +113 -46
app.py CHANGED
@@ -3,7 +3,7 @@ import uuid
3
  import httpx
4
  import torch
5
  import logging
6
- import time
7
  from typing import Dict, Optional, List, Union
8
  from fastapi import FastAPI, Request, BackgroundTasks, HTTPException, Depends
9
  from fastapi.responses import JSONResponse
@@ -16,7 +16,7 @@ from contextlib import asynccontextmanager
16
  MODEL_ID = "google/gemma-1.1-2b-it"
17
  HF_TOKEN = os.getenv("HF_TOKEN", "")
18
  API_KEY = os.getenv("API_KEY", "default-key-123")
19
- MAX_TOKENS = 150
20
  DEVICE = "cpu"
21
  PORT = int(os.getenv("PORT", 7860))
22
 
@@ -72,39 +72,41 @@ class ScriptGenerator:
72
  generator = ScriptGenerator()
73
 
74
  async def verify_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)):
75
- """Verify API key - but allow Hugging Face monitoring"""
76
- # Allow internal Hugging Face IPs without API key for health checks
77
- # This prevents the constant model generation from their monitoring
78
  if credentials.credentials != API_KEY:
79
- # Check if this is likely Hugging Face internal monitoring
80
- # (you can add more sophisticated checks here if needed)
81
  raise HTTPException(status_code=401, detail="Invalid API key")
82
  return True
83
 
84
- def is_huggingface_monitoring(request: Request) -> bool:
85
- """Check if request is from Hugging Face monitoring"""
86
- client_host = request.client.host
87
- # Hugging Face internal IP ranges
88
- hf_ips = ["10.16.", "10.20.", "10.24."]
89
- return any(client_host.startswith(ip) for ip in hf_ips)
90
-
91
  @asynccontextmanager
92
  async def lifespan(app: FastAPI):
93
- # Load model but don't block startup
94
- # Model will load on first real request
95
  logger.info("πŸš€ API Server starting up...")
96
  yield
97
 
98
  app = FastAPI(lifespan=lifespan)
99
 
100
  def extract_topic(topic_input: Union[str, List[str]]) -> str:
 
101
  if isinstance(topic_input, list):
102
  if topic_input:
103
  return str(topic_input[0])
104
  return "No topic provided"
105
  return str(topic_input)
106
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
  def generate_script(topic: str) -> str:
 
108
  try:
109
  if not generator.loaded:
110
  if not generator.load_model():
@@ -114,15 +116,30 @@ def generate_script(topic: str) -> str:
114
  logger.info(f"🎯 Generating script for: '{clean_topic}'")
115
 
116
  prompt = (
117
- f"Create a 60-second video script about: {clean_topic[:50]}\n\n"
118
- "1) Hook (10s)\n2) Content (40s)\n3) CTA (10s)\n\nScript:"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
  )
120
 
121
  inputs = generator.tokenizer(
122
  prompt,
123
  return_tensors="pt",
124
  truncation=True,
125
- max_length=256
126
  )
127
 
128
  inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
@@ -133,52 +150,64 @@ def generate_script(topic: str) -> str:
133
  max_new_tokens=MAX_TOKENS,
134
  do_sample=True,
135
  top_p=0.9,
136
- temperature=0.7,
137
  pad_token_id=generator.tokenizer.eos_token_id,
 
138
  )
139
 
140
  script = generator.tokenizer.decode(outputs[0], skip_special_tokens=True)
141
  clean_script = script.replace(prompt, "").strip()
142
 
143
- if not clean_script:
144
- clean_script = "Script generation completed but returned empty content."
145
-
146
- logger.info(f"πŸ“ Generated {len(clean_script)} characters")
147
- return clean_script
148
 
149
  except Exception as e:
150
  logger.error(f"❌ Script generation failed: {str(e)}")
151
  raise
152
 
153
  async def process_job(job_id: str, topic_input: Union[str, List[str]], callback_url: str = None):
 
154
  try:
155
  topic = extract_topic(topic_input)
156
  logger.info(f"🎯 Processing: '{topic}'")
157
 
158
  script = generate_script(topic)
 
 
159
  jobs[job_id] = {
160
  "status": "complete",
161
  "result": script,
162
  "topic": topic,
163
- "script_length": len(script)
 
164
  }
165
 
166
  logger.info(f"βœ… Completed job {job_id}")
167
 
 
168
  if callback_url:
169
  try:
170
  async with httpx.AsyncClient(timeout=30.0) as client:
 
 
 
 
 
 
 
 
 
171
  response = await client.post(
172
  callback_url,
173
- json={
174
- "job_id": job_id,
175
- "status": "complete",
176
- "result": script,
177
- "topic": topic
178
- },
179
  headers={"Content-Type": "application/json"}
180
  )
 
181
  logger.info(f"πŸ“¨ Webhook status: {response.status_code}")
 
182
  except Exception as e:
183
  logger.error(f"❌ Webhook failed: {str(e)}")
184
 
@@ -186,11 +215,28 @@ async def process_job(job_id: str, topic_input: Union[str, List[str]], callback_
186
  error_msg = f"Job failed: {str(e)}"
187
  logger.error(f"❌ Job {job_id} failed: {error_msg}")
188
 
 
189
  jobs[job_id] = {
190
  "status": "failed",
191
  "error": error_msg,
192
  "topic": extract_topic(topic_input) if topic_input else "unknown"
193
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
194
 
195
  @app.post("/api/submit")
196
  async def submit_job(
@@ -198,11 +244,12 @@ async def submit_job(
198
  background_tasks: BackgroundTasks,
199
  auth: bool = Depends(verify_api_key)
200
  ):
201
- """Main endpoint for script generation"""
202
  try:
203
  data = await request.json()
204
  job_id = str(uuid.uuid4())
205
 
 
206
  if not data.get("topic"):
207
  raise HTTPException(status_code=400, detail="Topic is required")
208
 
@@ -212,12 +259,14 @@ async def submit_job(
212
 
213
  logger.info(f"πŸ“₯ Received job {job_id}: '{topic}'")
214
 
 
215
  jobs[job_id] = {
216
  "status": "processing",
217
  "callback_url": callback_url,
218
  "topic": topic
219
  }
220
 
 
221
  background_tasks.add_task(
222
  process_job,
223
  job_id,
@@ -228,7 +277,9 @@ async def submit_job(
228
  return JSONResponse({
229
  "job_id": job_id,
230
  "status": "queued",
231
- "topic": topic
 
 
232
  })
233
 
234
  except Exception as e:
@@ -243,26 +294,29 @@ async def get_status(job_id: str, auth: bool = Depends(verify_api_key)):
243
  return jobs[job_id]
244
 
245
  @app.get("/health")
246
- async def health_check(request: Request):
247
- """Health check endpoint - lightweight for monitoring"""
248
- # Return immediate response without model loading for monitoring
 
 
249
  return {
250
  "status": "healthy",
251
  "model_loaded": generator.loaded,
252
  "total_jobs": len(jobs),
253
- "monitoring": is_huggingface_monitoring(request)
 
 
254
  }
255
 
256
  @app.get("/test/generation")
257
- async def test_generation(request: Request, auth: bool = Depends(verify_api_key)):
258
- """Test endpoint - only works with API key"""
259
- # This won't be triggered by HF monitoring because it requires API key
260
  try:
261
  if not generator.loaded:
262
  if not generator.load_model():
263
  return {"status": "error", "error": "Model failed to load"}
264
 
265
- test_topic = "healthy lifestyle"
266
  logger.info(f"πŸ§ͺ Testing generation with: {test_topic}")
267
 
268
  script = generate_script(test_topic)
@@ -271,16 +325,29 @@ async def test_generation(request: Request, auth: bool = Depends(verify_api_key)
271
  "status": "success",
272
  "topic": test_topic,
273
  "script_length": len(script),
274
- "script_preview": script[:200] + "..." if len(script) > 200 else script
 
 
275
  }
276
 
277
  except Exception as e:
278
  logger.error(f"❌ Test generation failed: {str(e)}")
279
  return {"status": "error", "error": str(e)}
280
 
281
- # Remove public debug endpoints that were causing the issue
282
- # @app.get("/debug/jobs") - REMOVED
283
- # @app.get("/test/model") - REMOVED
 
 
 
 
 
 
 
 
 
 
 
284
 
285
  if __name__ == "__main__":
286
  uvicorn.run(
 
3
  import httpx
4
  import torch
5
  import logging
6
+ import re
7
  from typing import Dict, Optional, List, Union
8
  from fastapi import FastAPI, Request, BackgroundTasks, HTTPException, Depends
9
  from fastapi.responses import JSONResponse
 
16
  MODEL_ID = "google/gemma-1.1-2b-it"
17
  HF_TOKEN = os.getenv("HF_TOKEN", "")
18
  API_KEY = os.getenv("API_KEY", "default-key-123")
19
+ MAX_TOKENS = 450
20
  DEVICE = "cpu"
21
  PORT = int(os.getenv("PORT", 7860))
22
 
 
72
  generator = ScriptGenerator()
73
 
74
  async def verify_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)):
75
+ """Verify API key"""
 
 
76
  if credentials.credentials != API_KEY:
 
 
77
  raise HTTPException(status_code=401, detail="Invalid API key")
78
  return True
79
 
 
 
 
 
 
 
 
80
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: log startup, then hand control to FastAPI."""
    # Startup phase — the model is loaded lazily on first request, so
    # there is nothing to preload here.
    logger.info("πŸš€ API Server starting up...")
    yield
    # Shutdown phase — no resources to release.
84
 
85
  app = FastAPI(lifespan=lifespan)
86
 
87
def extract_topic(topic_input: Union[str, List[str]]) -> str:
    """Extract topic from string or array input.

    Non-list inputs are stringified as-is; for a list, the first element
    is used, and an empty list yields a placeholder message.
    """
    if not isinstance(topic_input, list):
        return str(topic_input)
    return str(topic_input[0]) if topic_input else "No topic provided"
94
 
95
def format_script(script: str) -> str:
    """Clean and format the generated script"""
    # Keep only the text after the last "SCRIPT:" marker, dropping any
    # echoed prompt text that precedes it.
    body = script.split("SCRIPT:")[-1].strip()

    # Start a new line before each timestamp marker such as "[0:15".
    body = re.sub(r'(\[\d+:\d+)', r'\n\1', body)

    # Collapse runs of blank lines down to a single blank line.
    body = re.sub(r'\n\s*\n', '\n\n', body)

    return body.strip()
107
+
108
  def generate_script(topic: str) -> str:
109
+ """Generate high-quality video script"""
110
  try:
111
  if not generator.loaded:
112
  if not generator.load_model():
 
116
  logger.info(f"🎯 Generating script for: '{clean_topic}'")
117
 
118
  prompt = (
119
+ f"Create a detailed 60-second YouTube/TikTok video script about: {clean_topic}\n\n"
120
+ "REQUIREMENTS:\n"
121
+ "- Total duration: 60 seconds exactly\n"
122
+ "- Engaging hook in first 5 seconds\n"
123
+ "- Clear structure with timestamps every 10-15 seconds\n"
124
+ "- Conversational, engaging tone for social media\n"
125
+ "- End with strong call-to-action\n"
126
+ "- Include both voiceover and visual descriptions\n"
127
+ "- Minimum 800 characters for proper 60-second video\n\n"
128
+ "SCRIPT FORMAT:\n"
129
+ "[0:00-0:05] HOOK: Grab attention immediately\n"
130
+ "[0:05-0:15] INTRODUCTION: Introduce topic and yourself\n"
131
+ "[0:15-0:45] MAIN CONTENT: 2-3 key points with examples\n"
132
+ "[0:45-0:55] BENEFIT: Why this matters to viewers\n"
133
+ "[0:55-1:00] CTA: Clear call to action (follow, comment, like)\n\n"
134
+ "Include both VOICEOVER and VISUAL descriptions.\n\n"
135
+ "SCRIPT:"
136
  )
137
 
138
  inputs = generator.tokenizer(
139
  prompt,
140
  return_tensors="pt",
141
  truncation=True,
142
+ max_length=512
143
  )
144
 
145
  inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
 
150
  max_new_tokens=MAX_TOKENS,
151
  do_sample=True,
152
  top_p=0.9,
153
+ temperature=0.8,
154
  pad_token_id=generator.tokenizer.eos_token_id,
155
+ repetition_penalty=1.1
156
  )
157
 
158
  script = generator.tokenizer.decode(outputs[0], skip_special_tokens=True)
159
  clean_script = script.replace(prompt, "").strip()
160
 
161
+ # Format the script
162
+ formatted_script = format_script(clean_script)
163
+
164
+ logger.info(f"πŸ“ Generated {len(formatted_script)} characters")
165
+ return formatted_script
166
 
167
  except Exception as e:
168
  logger.error(f"❌ Script generation failed: {str(e)}")
169
  raise
170
 
171
  async def process_job(job_id: str, topic_input: Union[str, List[str]], callback_url: str = None):
172
+ """Background task to process job"""
173
  try:
174
  topic = extract_topic(topic_input)
175
  logger.info(f"🎯 Processing: '{topic}'")
176
 
177
  script = generate_script(topic)
178
+
179
+ # Store job results
180
  jobs[job_id] = {
181
  "status": "complete",
182
  "result": script,
183
  "topic": topic,
184
+ "script_length": len(script),
185
+ "formatted": True
186
  }
187
 
188
  logger.info(f"βœ… Completed job {job_id}")
189
 
190
+ # Send webhook callback if URL provided
191
  if callback_url:
192
  try:
193
  async with httpx.AsyncClient(timeout=30.0) as client:
194
+ webhook_data = {
195
+ "job_id": job_id,
196
+ "status": "complete",
197
+ "result": script,
198
+ "topic": topic,
199
+ "script_length": len(script),
200
+ "formatted": True
201
+ }
202
+
203
  response = await client.post(
204
  callback_url,
205
+ json=webhook_data,
 
 
 
 
 
206
  headers={"Content-Type": "application/json"}
207
  )
208
+
209
  logger.info(f"πŸ“¨ Webhook status: {response.status_code}")
210
+
211
  except Exception as e:
212
  logger.error(f"❌ Webhook failed: {str(e)}")
213
 
 
215
  error_msg = f"Job failed: {str(e)}"
216
  logger.error(f"❌ Job {job_id} failed: {error_msg}")
217
 
218
+ # Store failure information
219
  jobs[job_id] = {
220
  "status": "failed",
221
  "error": error_msg,
222
  "topic": extract_topic(topic_input) if topic_input else "unknown"
223
  }
224
+
225
+ # Send failure webhook if callback URL exists
226
+ if callback_url:
227
+ try:
228
+ async with httpx.AsyncClient(timeout=10.0) as client:
229
+ await client.post(
230
+ callback_url,
231
+ json={
232
+ "job_id": job_id,
233
+ "status": "failed",
234
+ "error": error_msg,
235
+ "topic": extract_topic(topic_input) if topic_input else "unknown"
236
+ }
237
+ )
238
+ except Exception:
239
+ logger.error("Failed to send error webhook")
240
 
241
  @app.post("/api/submit")
242
  async def submit_job(
 
244
  background_tasks: BackgroundTasks,
245
  auth: bool = Depends(verify_api_key)
246
  ):
247
+ """Endpoint to submit new job"""
248
  try:
249
  data = await request.json()
250
  job_id = str(uuid.uuid4())
251
 
252
+ # Validate input
253
  if not data.get("topic"):
254
  raise HTTPException(status_code=400, detail="Topic is required")
255
 
 
259
 
260
  logger.info(f"πŸ“₯ Received job {job_id}: '{topic}'")
261
 
262
+ # Store initial job data
263
  jobs[job_id] = {
264
  "status": "processing",
265
  "callback_url": callback_url,
266
  "topic": topic
267
  }
268
 
269
+ # Process job in background
270
  background_tasks.add_task(
271
  process_job,
272
  job_id,
 
277
  return JSONResponse({
278
  "job_id": job_id,
279
  "status": "queued",
280
+ "topic": topic,
281
+ "estimated_time": "70-90 seconds",
282
+ "message": "Script generation started"
283
  })
284
 
285
  except Exception as e:
 
294
  return jobs[job_id]
295
 
296
@app.get("/health")
async def health_check():
    """Health check endpoint"""
    # Gather per-status statistics from the in-memory job store.
    finished = [j for j in jobs.values() if j.get("status") == "complete"]
    failure_count = sum(1 for j in jobs.values() if j.get("status") == "failed")

    # Average script length over completed jobs; max(1, ...) guards the
    # division when no job has completed yet.
    total_chars = sum(j.get("script_length", 0) for j in finished)
    mean_length = total_chars / max(1, len(finished))

    return {
        "status": "healthy",
        "model_loaded": generator.loaded,
        "total_jobs": len(jobs),
        "completed_jobs": len(finished),
        "failed_jobs": failure_count,
        "average_script_length": round(mean_length, 2)
    }
310
 
311
  @app.get("/test/generation")
312
+ async def test_generation(auth: bool = Depends(verify_api_key)):
313
+ """Test script generation"""
 
314
  try:
315
  if not generator.loaded:
316
  if not generator.load_model():
317
  return {"status": "error", "error": "Model failed to load"}
318
 
319
+ test_topic = "the future of artificial intelligence in healthcare"
320
  logger.info(f"πŸ§ͺ Testing generation with: {test_topic}")
321
 
322
  script = generate_script(test_topic)
 
325
  "status": "success",
326
  "topic": test_topic,
327
  "script_length": len(script),
328
+ "script_preview": script[:300] + "..." if len(script) > 300 else script,
329
+ "estimated_duration": "60 seconds",
330
+ "quality": "good" if len(script) >= 800 else "needs improvement"
331
  }
332
 
333
  except Exception as e:
334
  logger.error(f"❌ Test generation failed: {str(e)}")
335
  return {"status": "error", "error": str(e)}
336
 
337
@app.get("/")
async def root():
    """Root endpoint"""
    # Static service descriptor listing the available routes.
    available_endpoints = {
        "submit_job": "POST /api/submit",
        "check_status": "GET /api/status/{job_id}",
        "health": "GET /health",
        "test": "GET /test/generation"
    }
    return {
        "message": "Video Script Generator API",
        "version": "1.0",
        "endpoints": available_endpoints,
        "status": "operational"
    }
351
 
352
  if __name__ == "__main__":
353
  uvicorn.run(