yukee1992 committed on
Commit
7aa1710
·
verified ·
1 Parent(s): 1105426

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +121 -32
app.py CHANGED
@@ -4,6 +4,7 @@ import httpx
4
  import torch
5
  import logging
6
  import re
 
7
  from typing import Dict, Optional, List, Union
8
  from fastapi import FastAPI, Request, BackgroundTasks, HTTPException, Depends
9
  from fastapi.responses import JSONResponse
@@ -84,13 +85,75 @@ async def lifespan(app: FastAPI):
84
 
85
  app = FastAPI(lifespan=lifespan)
86
 
87
def extract_topic(topic_input: Union[str, List[str]]) -> str:
    """Return a single topic string from either a plain string or a list.

    A non-empty list contributes its first element (stringified); an empty
    list yields the placeholder "No topic provided"; anything else is
    stringified as-is.
    """
    if not isinstance(topic_input, list):
        return str(topic_input)
    # Empty list means the caller supplied no usable topic.
    return str(topic_input[0]) if topic_input else "No topic provided"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
 
95
  def format_script(script: str) -> str:
96
  """Clean and format the generated script"""
@@ -200,19 +263,28 @@ def generate_script(topic: str) -> str:
200
  logger.error(f"❌ Script generation failed: {str(e)}")
201
  raise
202
 
203
- async def process_job(job_id: str, topic_input: Union[str, List[str]], callback_url: str = None):
204
  """Background task to process job"""
205
  try:
206
- topic = extract_topic(topic_input)
207
- logger.info(f"🎯 Processing: '{topic}'")
 
 
 
 
208
 
209
- script = generate_script(topic)
 
 
 
 
210
 
211
  # Store job results
212
  jobs[job_id] = {
213
  "status": "complete",
214
  "result": script,
215
- "topic": topic,
 
216
  "script_length": len(script),
217
  "formatted": True
218
  }
@@ -227,7 +299,8 @@ async def process_job(job_id: str, topic_input: Union[str, List[str]], callback_
227
  "job_id": job_id,
228
  "status": "complete",
229
  "result": script,
230
- "topic": topic,
 
231
  "script_length": len(script),
232
  "formatted": True
233
  }
@@ -251,7 +324,7 @@ async def process_job(job_id: str, topic_input: Union[str, List[str]], callback_
251
  jobs[job_id] = {
252
  "status": "failed",
253
  "error": error_msg,
254
- "topic": extract_topic(topic_input) if topic_input else "unknown"
255
  }
256
 
257
  # Send failure webhook if callback URL exists
@@ -264,7 +337,7 @@ async def process_job(job_id: str, topic_input: Union[str, List[str]], callback_
264
  "job_id": job_id,
265
  "status": "failed",
266
  "error": error_msg,
267
- "topic": extract_topic(topic_input) if topic_input else "unknown"
268
  }
269
  )
270
  except Exception:
@@ -282,36 +355,39 @@ async def submit_job(
282
  job_id = str(uuid.uuid4())
283
 
284
  # Validate input
285
- if not data.get("topic"):
286
- raise HTTPException(status_code=400, detail="Topic is required")
287
 
288
  callback_url = data.get("callback_url")
289
- topic_input = data["topic"]
290
- topic = extract_topic(topic_input)
291
 
292
- logger.info(f"📥 Received job {job_id}: '{topic}'")
 
 
 
293
 
294
  # Store initial job data
295
  jobs[job_id] = {
296
  "status": "processing",
297
  "callback_url": callback_url,
298
- "topic": topic
299
  }
300
 
301
  # Process job in background
302
  background_tasks.add_task(
303
  process_job,
304
  job_id,
305
- topic_input,
306
  callback_url
307
  )
308
 
309
  return JSONResponse({
310
  "job_id": job_id,
311
  "status": "queued",
312
- "topic": topic,
313
- "estimated_time": "70-90 seconds",
314
- "message": "Script generation started"
315
  })
316
 
317
  except Exception as e:
@@ -348,14 +424,26 @@ async def test_generation(auth: bool = Depends(verify_api_key)):
348
  if not generator.load_model():
349
  return {"status": "error", "error": "Model failed to load"}
350
 
351
- test_topic = "the future of artificial intelligence in healthcare"
352
- logger.info(f"🧪 Testing generation with: {test_topic}")
 
 
 
 
 
 
 
 
 
 
353
 
354
- script = generate_script(test_topic)
 
355
 
356
  return {
357
  "status": "success",
358
- "topic": test_topic,
 
359
  "script_length": len(script),
360
  "script_preview": script[:300] + "..." if len(script) > 300 else script,
361
  "estimated_duration": "60 seconds",
@@ -370,10 +458,11 @@ async def test_generation(auth: bool = Depends(verify_api_key)):
370
  async def root():
371
  """Root endpoint"""
372
  return {
373
- "message": "Video Script Generator API",
374
- "version": "1.0",
 
375
  "endpoints": {
376
- "submit_job": "POST /api/submit",
377
  "check_status": "GET /api/status/{job_id}",
378
  "health": "GET /health",
379
  "test": "GET /test/generation"
 
4
  import torch
5
  import logging
6
  import re
7
+ import json
8
  from typing import Dict, Optional, List, Union
9
  from fastapi import FastAPI, Request, BackgroundTasks, HTTPException, Depends
10
  from fastapi.responses import JSONResponse
 
85
 
86
  app = FastAPI(lifespan=lifespan)
87
 
88
def extract_topics(topic_input: Union[str, List[str]]) -> List[str]:
    """Extract and validate topics from input.

    Accepts a list of topics, a JSON-encoded string (array or scalar), or a
    plain string (optionally comma-separated). Returns a list of stripped,
    non-empty topic strings; unsupported input types yield an empty list.
    """

    def _clean(items) -> List[str]:
        # Stringify, strip, and drop entries that end up empty.
        return [str(item).strip() for item in items if str(item).strip()]

    if isinstance(topic_input, list):
        return _clean(topic_input)

    if isinstance(topic_input, str):
        try:
            decoded = json.loads(topic_input)
        except json.JSONDecodeError:
            # Not JSON: fall back to comma-separated parsing.
            if "," in topic_input:
                return _clean(topic_input.split(","))
            return [topic_input.strip()]
        if isinstance(decoded, list):
            return _clean(decoded)
        # JSON scalar (string/number/etc.) becomes a single topic.
        return [str(decoded).strip()]

    return []
106
+
107
def generate_topic_from_trends(trending_topics: List[str]) -> str:
    """Generate a single viral video topic from a list of trending topics.

    Lazily loads the shared model via ``generator.load_model()`` if needed,
    prompts it with the joined trends, samples one completion, and returns
    the first line of the completion with surrounding quotes stripped.

    Raises:
        Exception: if the model fails to load (wraps ``generator.load_error``).
    """
    # Lazy model load: only attempt if the global generator is not ready yet.
    if not generator.loaded:
        if not generator.load_model():
            raise Exception(f"Model failed to load: {generator.load_error}")

    logger.info(f"🧠 Generating viral topic from trends: {trending_topics}")

    # NOTE(review): the prompt says "these 5 trending topics" but the list
    # length is not enforced here — presumably the caller validates it.
    prompt = (
        f"Based on these 5 trending topics: {', '.join(trending_topics)}\n\n"
        "Create ONE highly engaging, viral topic for a YouTube/TikTok short video that:\n"
        "1. Combines elements from these trends in a creative way\n"
        "2. Has high viral potential (emotional, surprising, or controversial)\n"
        "3. Is suitable for a 60-second video format\n"
        "4. Appeals to a broad audience\n"
        "5. Is specific enough to be interesting but broad enough to allow creative interpretation\n\n"
        "Respond with ONLY the topic (no explanations, no bullet points, no numbering).\n"
        "The topic should be 5-10 words maximum.\n\n"
        "VIRAL TOPIC:"
    )

    # Tokenize the prompt; truncate so it never exceeds 512 tokens.
    inputs = generator.tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=512
    )

    # Move every input tensor onto the configured device.
    inputs = {k: v.to(DEVICE) for k, v in inputs.items()}

    # Sample one completion; no_grad since this is pure inference.
    with torch.no_grad():
        outputs = generator.model.generate(
            **inputs,
            max_new_tokens=100,
            do_sample=True,
            top_p=0.9,
            temperature=0.8,
            pad_token_id=generator.tokenizer.eos_token_id,
            repetition_penalty=1.1
        )

    # Decode and drop the echoed prompt so only the new text remains.
    generated_text = generator.tokenizer.decode(outputs[0], skip_special_tokens=True)
    topic = generated_text.replace(prompt, "").strip()

    # Clean up the topic: keep only the first line/sentence fragment.
    topic = re.split(r'[\n\.]', topic)[0].strip()
    topic = re.sub(r'^["\'](.*)["\']$', r'\1', topic)  # Remove surrounding quotes

    logger.info(f"🎯 Generated topic: '{topic}'")
    return topic
157
 
158
  def format_script(script: str) -> str:
159
  """Clean and format the generated script"""
 
263
  logger.error(f"❌ Script generation failed: {str(e)}")
264
  raise
265
 
266
+ async def process_job(job_id: str, topics_input: Union[str, List[str]], callback_url: str = None):
267
  """Background task to process job"""
268
  try:
269
+ # Extract and validate topics
270
+ topics = extract_topics(topics_input)
271
+ if len(topics) < 3:
272
+ raise HTTPException(status_code=400, detail="At least 3 topics are required")
273
+
274
+ logger.info(f"🎯 Processing {len(topics)} topics: {topics}")
275
 
276
+ # Step 1: Generate a viral topic from the trends
277
+ generated_topic = generate_topic_from_trends(topics)
278
+
279
+ # Step 2: Generate script based on the created topic
280
+ script = generate_script(generated_topic)
281
 
282
  # Store job results
283
  jobs[job_id] = {
284
  "status": "complete",
285
  "result": script,
286
+ "original_topics": topics,
287
+ "generated_topic": generated_topic,
288
  "script_length": len(script),
289
  "formatted": True
290
  }
 
299
  "job_id": job_id,
300
  "status": "complete",
301
  "result": script,
302
+ "original_topics": topics,
303
+ "generated_topic": generated_topic,
304
  "script_length": len(script),
305
  "formatted": True
306
  }
 
324
  jobs[job_id] = {
325
  "status": "failed",
326
  "error": error_msg,
327
+ "topics": extract_topics(topics_input) if topics_input else []
328
  }
329
 
330
  # Send failure webhook if callback URL exists
 
337
  "job_id": job_id,
338
  "status": "failed",
339
  "error": error_msg,
340
+ "topics": extract_topics(topics_input) if topics_input else []
341
  }
342
  )
343
  except Exception:
 
355
  job_id = str(uuid.uuid4())
356
 
357
  # Validate input
358
+ if not data.get("topics"):
359
+ raise HTTPException(status_code=400, detail="Topics are required")
360
 
361
  callback_url = data.get("callback_url")
362
+ topics_input = data["topics"]
363
+ topics = extract_topics(topics_input)
364
 
365
+ if len(topics) < 3:
366
+ raise HTTPException(status_code=400, detail="At least 3 topics are required")
367
+
368
+ logger.info(f"📥 Received job {job_id} with {len(topics)} topics: {topics}")
369
 
370
  # Store initial job data
371
  jobs[job_id] = {
372
  "status": "processing",
373
  "callback_url": callback_url,
374
+ "topics": topics
375
  }
376
 
377
  # Process job in background
378
  background_tasks.add_task(
379
  process_job,
380
  job_id,
381
+ topics_input,
382
  callback_url
383
  )
384
 
385
  return JSONResponse({
386
  "job_id": job_id,
387
  "status": "queued",
388
+ "topics": topics,
389
+ "estimated_time": "90-120 seconds",
390
+ "message": "Topic generation and script creation started"
391
  })
392
 
393
  except Exception as e:
 
424
  if not generator.load_model():
425
  return {"status": "error", "error": "Model failed to load"}
426
 
427
+ test_topics = [
428
+ "Artificial Intelligence",
429
+ "Sustainable Energy",
430
+ "Virtual Reality",
431
+ "Space Exploration",
432
+ "Biotechnology"
433
+ ]
434
+
435
+ logger.info(f"🧪 Testing topic generation with: {test_topics}")
436
+
437
+ # Test topic generation
438
+ generated_topic = generate_topic_from_trends(test_topics)
439
 
440
+ # Test script generation
441
+ script = generate_script(generated_topic)
442
 
443
  return {
444
  "status": "success",
445
+ "test_topics": test_topics,
446
+ "generated_topic": generated_topic,
447
  "script_length": len(script),
448
  "script_preview": script[:300] + "..." if len(script) > 300 else script,
449
  "estimated_duration": "60 seconds",
 
458
  async def root():
459
  """Root endpoint"""
460
  return {
461
+ "message": "Enhanced Video Script Generator API",
462
+ "version": "2.0",
463
+ "features": "Generates viral topics from trends, then creates scripts",
464
  "endpoints": {
465
+ "submit_job": "POST /api/submit (with 'topics' array)",
466
  "check_status": "GET /api/status/{job_id}",
467
  "health": "GET /health",
468
  "test": "GET /test/generation"