Update app.py
Browse files
app.py
CHANGED
|
@@ -9,13 +9,12 @@ from fastapi.responses import JSONResponse
|
|
| 9 |
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 10 |
import uvicorn
|
| 11 |
from contextlib import asynccontextmanager
|
| 12 |
-
from pydantic import BaseModel
|
| 13 |
|
| 14 |
# Configuration
|
| 15 |
MODEL_ID = "google/gemma-1.1-2b-it"
|
| 16 |
HF_TOKEN = os.getenv("HF_TOKEN", "")
|
| 17 |
-
MAX_TOKENS = 200
|
| 18 |
-
DEVICE = "cpu"
|
| 19 |
PORT = int(os.getenv("PORT", 7860))
|
| 20 |
|
| 21 |
# Setup logging
|
|
@@ -44,12 +43,11 @@ class ScriptGenerator:
|
|
| 44 |
self.model = AutoModelForCausalLM.from_pretrained(
|
| 45 |
MODEL_ID,
|
| 46 |
torch_dtype=torch.float32,
|
| 47 |
-
device_map="auto",
|
| 48 |
token=HF_TOKEN,
|
| 49 |
low_cpu_mem_usage=True
|
| 50 |
)
|
| 51 |
-
|
| 52 |
-
|
| 53 |
self.loaded = True
|
| 54 |
logger.info("β
Model loaded successfully")
|
| 55 |
except Exception as e:
|
|
@@ -69,17 +67,15 @@ def extract_topic(topic_input: Union[str, List[str]]) -> str:
|
|
| 69 |
"""Extract topic from string or array input"""
|
| 70 |
if isinstance(topic_input, list):
|
| 71 |
if topic_input:
|
| 72 |
-
return str(topic_input[0])
|
| 73 |
return "No topic provided"
|
| 74 |
return str(topic_input)
|
| 75 |
|
| 76 |
def generate_script(topic: str) -> str:
|
| 77 |
"""Generate script with error handling"""
|
| 78 |
try:
|
| 79 |
-
# Clean the topic input
|
| 80 |
clean_topic = topic.strip().strip("['").strip("']").strip('"').strip("'")
|
| 81 |
-
|
| 82 |
-
logger.info(f"π― Generating script for topic: '{clean_topic}'")
|
| 83 |
|
| 84 |
prompt = (
|
| 85 |
f"Create a short 1-minute video script about: {clean_topic[:80]}\n\n"
|
|
@@ -90,22 +86,13 @@ def generate_script(topic: str) -> str:
|
|
| 90 |
"Script:"
|
| 91 |
)
|
| 92 |
|
| 93 |
-
logger.info(f"π Prompt: {prompt[:100]}...")
|
| 94 |
-
|
| 95 |
-
# Tokenize with proper padding
|
| 96 |
inputs = generator.tokenizer(
|
| 97 |
prompt,
|
| 98 |
return_tensors="pt",
|
| 99 |
padding=True,
|
| 100 |
truncation=True,
|
| 101 |
max_length=512
|
| 102 |
-
)
|
| 103 |
-
|
| 104 |
-
# Move to device
|
| 105 |
-
if DEVICE == "cuda":
|
| 106 |
-
inputs = {k: v.cuda() for k, v in inputs.items()}
|
| 107 |
-
else:
|
| 108 |
-
inputs = {k: v for k, v in inputs.items()}
|
| 109 |
|
| 110 |
# Generate with safer parameters
|
| 111 |
with torch.no_grad():
|
|
@@ -115,17 +102,13 @@ def generate_script(topic: str) -> str:
|
|
| 115 |
do_sample=True,
|
| 116 |
top_p=0.8,
|
| 117 |
temperature=0.7,
|
| 118 |
-
pad_token_id=generator.tokenizer.eos_token_id
|
| 119 |
-
repetition_penalty=1.1
|
| 120 |
)
|
| 121 |
|
| 122 |
-
# Decode the output
|
| 123 |
script = generator.tokenizer.decode(outputs[0], skip_special_tokens=True)
|
| 124 |
clean_script = script.replace(prompt, "").strip()
|
| 125 |
|
| 126 |
-
logger.info(f"π Generated
|
| 127 |
-
logger.info(f"π Script length: {len(clean_script)} characters")
|
| 128 |
-
|
| 129 |
return clean_script
|
| 130 |
|
| 131 |
except Exception as e:
|
|
@@ -135,9 +118,8 @@ def generate_script(topic: str) -> str:
|
|
| 135 |
async def process_job(job_id: str, topic_input: Union[str, List[str]], callback_url: str = None):
|
| 136 |
"""Background task to process job"""
|
| 137 |
try:
|
| 138 |
-
# Extract and clean the topic
|
| 139 |
topic = extract_topic(topic_input)
|
| 140 |
-
logger.info(f"π― Processing
|
| 141 |
|
| 142 |
script = generate_script(topic)
|
| 143 |
jobs[job_id] = {
|
|
@@ -148,20 +130,17 @@ async def process_job(job_id: str, topic_input: Union[str, List[str]], callback_
|
|
| 148 |
"script_length": len(script)
|
| 149 |
}
|
| 150 |
|
| 151 |
-
logger.info(f"β
Completed job {job_id}
|
| 152 |
|
| 153 |
if callback_url:
|
| 154 |
try:
|
| 155 |
-
logger.info(f"π€ Sending webhook to: {callback_url}")
|
| 156 |
-
|
| 157 |
async with httpx.AsyncClient(timeout=30.0) as client:
|
| 158 |
webhook_data = {
|
| 159 |
"job_id": job_id,
|
| 160 |
"status": "complete",
|
| 161 |
"result": script,
|
| 162 |
"topic": topic,
|
| 163 |
-
"original_input": topic_input
|
| 164 |
-
"script_length": len(script)
|
| 165 |
}
|
| 166 |
|
| 167 |
response = await client.post(
|
|
@@ -170,17 +149,10 @@ async def process_job(job_id: str, topic_input: Union[str, List[str]], callback_
|
|
| 170 |
headers={"Content-Type": "application/json"}
|
| 171 |
)
|
| 172 |
|
| 173 |
-
logger.info(f"π¨ Webhook
|
| 174 |
-
|
| 175 |
-
if response.status_code == 200:
|
| 176 |
-
logger.info(f"β
Webhook delivered successfully to n8n")
|
| 177 |
-
else:
|
| 178 |
-
logger.error(f"β Webhook failed with status: {response.status_code}")
|
| 179 |
|
| 180 |
except Exception as e:
|
| 181 |
logger.error(f"β Webhook failed: {str(e)}")
|
| 182 |
-
else:
|
| 183 |
-
logger.warning(f"β οΈ No callback URL provided for job {job_id}")
|
| 184 |
|
| 185 |
except Exception as e:
|
| 186 |
error_msg = f"Job failed: {str(e)}"
|
|
@@ -201,20 +173,14 @@ async def submit_job(request: Request, background_tasks: BackgroundTasks):
|
|
| 201 |
data = await request.json()
|
| 202 |
job_id = str(uuid.uuid4())
|
| 203 |
|
| 204 |
-
# Validate input
|
| 205 |
if not data.get("topic"):
|
| 206 |
raise HTTPException(status_code=400, detail="Topic is required")
|
| 207 |
|
| 208 |
callback_url = data.get("callback_url")
|
| 209 |
topic_input = data["topic"]
|
| 210 |
-
|
| 211 |
-
# Extract and log the topic
|
| 212 |
topic = extract_topic(topic_input)
|
| 213 |
|
| 214 |
-
logger.info(f"π₯ Received
|
| 215 |
-
logger.info(f"π Raw input: {topic_input}")
|
| 216 |
-
logger.info(f"π― Cleaned topic: '{topic}'")
|
| 217 |
-
logger.info(f"π Callback URL: {callback_url or 'None'}")
|
| 218 |
|
| 219 |
jobs[job_id] = {
|
| 220 |
"status": "processing",
|
|
@@ -235,9 +201,7 @@ async def submit_job(request: Request, background_tasks: BackgroundTasks):
|
|
| 235 |
"job_id": job_id,
|
| 236 |
"status": "queued",
|
| 237 |
"received_topic": topic,
|
| 238 |
-
"
|
| 239 |
-
"callback_url": callback_url,
|
| 240 |
-
"message": "Job is being processed"
|
| 241 |
})
|
| 242 |
|
| 243 |
except Exception as e:
|
|
@@ -260,9 +224,7 @@ async def debug_jobs():
|
|
| 260 |
job_id: {
|
| 261 |
"status": data["status"],
|
| 262 |
"topic": data.get("topic", "unknown"),
|
| 263 |
-
"original_input": data.get("original_input", "unknown"),
|
| 264 |
"script_length": data.get("script_length", 0),
|
| 265 |
-
"callback_url": data.get("callback_url"),
|
| 266 |
"error": data.get("error", "none")
|
| 267 |
}
|
| 268 |
for job_id, data in jobs.items()
|
|
@@ -275,32 +237,18 @@ async def health_check():
|
|
| 275 |
return {
|
| 276 |
"status": "healthy",
|
| 277 |
"model_loaded": generator.loaded,
|
| 278 |
-
"
|
| 279 |
-
"completed_jobs": sum(1 for job in jobs.values() if job.get("status") == "complete"),
|
| 280 |
-
"failed_jobs": sum(1 for job in jobs.values() if job.get("status") == "failed")
|
| 281 |
}
|
| 282 |
|
| 283 |
@app.get("/test/generation")
|
| 284 |
async def test_generation():
|
| 285 |
-
"""Test
|
| 286 |
try:
|
| 287 |
test_topic = "healthy lifestyle tips"
|
| 288 |
-
logger.info(f"π§ͺ Testing generation with topic: {test_topic}")
|
| 289 |
-
|
| 290 |
script = generate_script(test_topic)
|
| 291 |
-
|
| 292 |
-
return {
|
| 293 |
-
"status": "success",
|
| 294 |
-
"topic": test_topic,
|
| 295 |
-
"script": script,
|
| 296 |
-
"length": len(script)
|
| 297 |
-
}
|
| 298 |
except Exception as e:
|
| 299 |
-
|
| 300 |
-
return {
|
| 301 |
-
"status": "error",
|
| 302 |
-
"error": str(e)
|
| 303 |
-
}
|
| 304 |
|
| 305 |
if __name__ == "__main__":
|
| 306 |
uvicorn.run(
|
|
|
|
| 9 |
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 10 |
import uvicorn
|
| 11 |
from contextlib import asynccontextmanager
|
|
|
|
| 12 |
|
| 13 |
# Configuration — module-level constants. HF_TOKEN and PORT can be
# overridden via the environment (Hugging Face Spaces convention).
MODEL_ID = "google/gemma-1.1-2b-it"   # HF model repo loaded at startup
HF_TOKEN = os.getenv("HF_TOKEN", "")  # gated-model token; empty string = anonymous access
MAX_TOKENS = 200                      # budget for newly generated tokens
DEVICE = "cpu"  # Force CPU to avoid device_map issues
# Use a string default so os.getenv always yields a str before int() converts it.
PORT = int(os.getenv("PORT", "7860"))
|
| 19 |
|
| 20 |
# Setup logging
|
|
|
|
| 43 |
self.model = AutoModelForCausalLM.from_pretrained(
|
| 44 |
MODEL_ID,
|
| 45 |
torch_dtype=torch.float32,
|
|
|
|
| 46 |
token=HF_TOKEN,
|
| 47 |
low_cpu_mem_usage=True
|
| 48 |
)
|
| 49 |
+
# Simple device assignment without device_map
|
| 50 |
+
self.model = self.model.to(DEVICE)
|
| 51 |
self.loaded = True
|
| 52 |
logger.info("β
Model loaded successfully")
|
| 53 |
except Exception as e:
|
|
|
|
| 67 |
"""Extract topic from string or array input"""
|
| 68 |
if isinstance(topic_input, list):
|
| 69 |
if topic_input:
|
| 70 |
+
return str(topic_input[0])
|
| 71 |
return "No topic provided"
|
| 72 |
return str(topic_input)
|
| 73 |
|
| 74 |
def generate_script(topic: str) -> str:
|
| 75 |
"""Generate script with error handling"""
|
| 76 |
try:
|
|
|
|
| 77 |
clean_topic = topic.strip().strip("['").strip("']").strip('"').strip("'")
|
| 78 |
+
logger.info(f"π― Generating script for: '{clean_topic}'")
|
|
|
|
| 79 |
|
| 80 |
prompt = (
|
| 81 |
f"Create a short 1-minute video script about: {clean_topic[:80]}\n\n"
|
|
|
|
| 86 |
"Script:"
|
| 87 |
)
|
| 88 |
|
|
|
|
|
|
|
|
|
|
| 89 |
inputs = generator.tokenizer(
|
| 90 |
prompt,
|
| 91 |
return_tensors="pt",
|
| 92 |
padding=True,
|
| 93 |
truncation=True,
|
| 94 |
max_length=512
|
| 95 |
+
).to(DEVICE)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
|
| 97 |
# Generate with safer parameters
|
| 98 |
with torch.no_grad():
|
|
|
|
| 102 |
do_sample=True,
|
| 103 |
top_p=0.8,
|
| 104 |
temperature=0.7,
|
| 105 |
+
pad_token_id=generator.tokenizer.eos_token_id
|
|
|
|
| 106 |
)
|
| 107 |
|
|
|
|
| 108 |
script = generator.tokenizer.decode(outputs[0], skip_special_tokens=True)
|
| 109 |
clean_script = script.replace(prompt, "").strip()
|
| 110 |
|
| 111 |
+
logger.info(f"π Generated {len(clean_script)} characters")
|
|
|
|
|
|
|
| 112 |
return clean_script
|
| 113 |
|
| 114 |
except Exception as e:
|
|
|
|
| 118 |
async def process_job(job_id: str, topic_input: Union[str, List[str]], callback_url: str = None):
|
| 119 |
"""Background task to process job"""
|
| 120 |
try:
|
|
|
|
| 121 |
topic = extract_topic(topic_input)
|
| 122 |
+
logger.info(f"π― Processing: '{topic}'")
|
| 123 |
|
| 124 |
script = generate_script(topic)
|
| 125 |
jobs[job_id] = {
|
|
|
|
| 130 |
"script_length": len(script)
|
| 131 |
}
|
| 132 |
|
| 133 |
+
logger.info(f"β
Completed job {job_id}")
|
| 134 |
|
| 135 |
if callback_url:
|
| 136 |
try:
|
|
|
|
|
|
|
| 137 |
async with httpx.AsyncClient(timeout=30.0) as client:
|
| 138 |
webhook_data = {
|
| 139 |
"job_id": job_id,
|
| 140 |
"status": "complete",
|
| 141 |
"result": script,
|
| 142 |
"topic": topic,
|
| 143 |
+
"original_input": topic_input
|
|
|
|
| 144 |
}
|
| 145 |
|
| 146 |
response = await client.post(
|
|
|
|
| 149 |
headers={"Content-Type": "application/json"}
|
| 150 |
)
|
| 151 |
|
| 152 |
+
logger.info(f"π¨ Webhook status: {response.status_code}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 153 |
|
| 154 |
except Exception as e:
|
| 155 |
logger.error(f"β Webhook failed: {str(e)}")
|
|
|
|
|
|
|
| 156 |
|
| 157 |
except Exception as e:
|
| 158 |
error_msg = f"Job failed: {str(e)}"
|
|
|
|
| 173 |
data = await request.json()
|
| 174 |
job_id = str(uuid.uuid4())
|
| 175 |
|
|
|
|
| 176 |
if not data.get("topic"):
|
| 177 |
raise HTTPException(status_code=400, detail="Topic is required")
|
| 178 |
|
| 179 |
callback_url = data.get("callback_url")
|
| 180 |
topic_input = data["topic"]
|
|
|
|
|
|
|
| 181 |
topic = extract_topic(topic_input)
|
| 182 |
|
| 183 |
+
logger.info(f"π₯ Received job {job_id}: '{topic}'")
|
|
|
|
|
|
|
|
|
|
| 184 |
|
| 185 |
jobs[job_id] = {
|
| 186 |
"status": "processing",
|
|
|
|
| 201 |
"job_id": job_id,
|
| 202 |
"status": "queued",
|
| 203 |
"received_topic": topic,
|
| 204 |
+
"callback_url": callback_url
|
|
|
|
|
|
|
| 205 |
})
|
| 206 |
|
| 207 |
except Exception as e:
|
|
|
|
| 224 |
job_id: {
|
| 225 |
"status": data["status"],
|
| 226 |
"topic": data.get("topic", "unknown"),
|
|
|
|
| 227 |
"script_length": data.get("script_length", 0),
|
|
|
|
| 228 |
"error": data.get("error", "none")
|
| 229 |
}
|
| 230 |
for job_id, data in jobs.items()
|
|
|
|
| 237 |
return {
|
| 238 |
"status": "healthy",
|
| 239 |
"model_loaded": generator.loaded,
|
| 240 |
+
"total_jobs": len(jobs)
|
|
|
|
|
|
|
| 241 |
}
|
| 242 |
|
| 243 |
@app.get("/test/generation")
|
| 244 |
async def test_generation():
|
| 245 |
+
"""Test script generation"""
|
| 246 |
try:
|
| 247 |
test_topic = "healthy lifestyle tips"
|
|
|
|
|
|
|
| 248 |
script = generate_script(test_topic)
|
| 249 |
+
return {"status": "success", "topic": test_topic, "script": script}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 250 |
except Exception as e:
|
| 251 |
+
return {"status": "error", "error": str(e)}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 252 |
|
| 253 |
if __name__ == "__main__":
|
| 254 |
uvicorn.run(
|