yukee1992 commited on
Commit
aa364cd
Β·
verified Β·
1 Parent(s): 386c1c8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -71
app.py CHANGED
@@ -9,13 +9,12 @@ from fastapi.responses import JSONResponse
9
  from transformers import AutoTokenizer, AutoModelForCausalLM
10
  import uvicorn
11
  from contextlib import asynccontextmanager
12
- from pydantic import BaseModel
13
 
14
  # Configuration
15
  MODEL_ID = "google/gemma-1.1-2b-it"
16
  HF_TOKEN = os.getenv("HF_TOKEN", "")
17
- MAX_TOKENS = 200 # Reduced for faster generation
18
- DEVICE = "cpu"
19
  PORT = int(os.getenv("PORT", 7860))
20
 
21
  # Setup logging
@@ -44,12 +43,11 @@ class ScriptGenerator:
44
  self.model = AutoModelForCausalLM.from_pretrained(
45
  MODEL_ID,
46
  torch_dtype=torch.float32,
47
- device_map="auto",
48
  token=HF_TOKEN,
49
  low_cpu_mem_usage=True
50
  )
51
- if DEVICE == "cuda":
52
- self.model = self.model.cuda()
53
  self.loaded = True
54
  logger.info("βœ… Model loaded successfully")
55
  except Exception as e:
@@ -69,17 +67,15 @@ def extract_topic(topic_input: Union[str, List[str]]) -> str:
69
  """Extract topic from string or array input"""
70
  if isinstance(topic_input, list):
71
  if topic_input:
72
- return str(topic_input[0]) # Take first element if it's a list
73
  return "No topic provided"
74
  return str(topic_input)
75
 
76
  def generate_script(topic: str) -> str:
77
  """Generate script with error handling"""
78
  try:
79
- # Clean the topic input
80
  clean_topic = topic.strip().strip("['").strip("']").strip('"').strip("'")
81
-
82
- logger.info(f"🎯 Generating script for topic: '{clean_topic}'")
83
 
84
  prompt = (
85
  f"Create a short 1-minute video script about: {clean_topic[:80]}\n\n"
@@ -90,22 +86,13 @@ def generate_script(topic: str) -> str:
90
  "Script:"
91
  )
92
 
93
- logger.info(f"πŸ“‹ Prompt: {prompt[:100]}...")
94
-
95
- # Tokenize with proper padding
96
  inputs = generator.tokenizer(
97
  prompt,
98
  return_tensors="pt",
99
  padding=True,
100
  truncation=True,
101
  max_length=512
102
- )
103
-
104
- # Move to device
105
- if DEVICE == "cuda":
106
- inputs = {k: v.cuda() for k, v in inputs.items()}
107
- else:
108
- inputs = {k: v for k, v in inputs.items()}
109
 
110
  # Generate with safer parameters
111
  with torch.no_grad():
@@ -115,17 +102,13 @@ def generate_script(topic: str) -> str:
115
  do_sample=True,
116
  top_p=0.8,
117
  temperature=0.7,
118
- pad_token_id=generator.tokenizer.eos_token_id,
119
- repetition_penalty=1.1
120
  )
121
 
122
- # Decode the output
123
  script = generator.tokenizer.decode(outputs[0], skip_special_tokens=True)
124
  clean_script = script.replace(prompt, "").strip()
125
 
126
- logger.info(f"πŸ“ Generated script: {clean_script[:100]}...")
127
- logger.info(f"πŸ“ Script length: {len(clean_script)} characters")
128
-
129
  return clean_script
130
 
131
  except Exception as e:
@@ -135,9 +118,8 @@ def generate_script(topic: str) -> str:
135
  async def process_job(job_id: str, topic_input: Union[str, List[str]], callback_url: str = None):
136
  """Background task to process job"""
137
  try:
138
- # Extract and clean the topic
139
  topic = extract_topic(topic_input)
140
- logger.info(f"🎯 Processing topic: '{topic}'")
141
 
142
  script = generate_script(topic)
143
  jobs[job_id] = {
@@ -148,20 +130,17 @@ async def process_job(job_id: str, topic_input: Union[str, List[str]], callback_
148
  "script_length": len(script)
149
  }
150
 
151
- logger.info(f"βœ… Completed job {job_id} - Script generated successfully")
152
 
153
  if callback_url:
154
  try:
155
- logger.info(f"πŸ“€ Sending webhook to: {callback_url}")
156
-
157
  async with httpx.AsyncClient(timeout=30.0) as client:
158
  webhook_data = {
159
  "job_id": job_id,
160
  "status": "complete",
161
  "result": script,
162
  "topic": topic,
163
- "original_input": topic_input,
164
- "script_length": len(script)
165
  }
166
 
167
  response = await client.post(
@@ -170,17 +149,10 @@ async def process_job(job_id: str, topic_input: Union[str, List[str]], callback_
170
  headers={"Content-Type": "application/json"}
171
  )
172
 
173
- logger.info(f"πŸ“¨ Webhook response status: {response.status_code}")
174
-
175
- if response.status_code == 200:
176
- logger.info(f"βœ… Webhook delivered successfully to n8n")
177
- else:
178
- logger.error(f"❌ Webhook failed with status: {response.status_code}")
179
 
180
  except Exception as e:
181
  logger.error(f"❌ Webhook failed: {str(e)}")
182
- else:
183
- logger.warning(f"⚠️ No callback URL provided for job {job_id}")
184
 
185
  except Exception as e:
186
  error_msg = f"Job failed: {str(e)}"
@@ -201,20 +173,14 @@ async def submit_job(request: Request, background_tasks: BackgroundTasks):
201
  data = await request.json()
202
  job_id = str(uuid.uuid4())
203
 
204
- # Validate input
205
  if not data.get("topic"):
206
  raise HTTPException(status_code=400, detail="Topic is required")
207
 
208
  callback_url = data.get("callback_url")
209
  topic_input = data["topic"]
210
-
211
- # Extract and log the topic
212
  topic = extract_topic(topic_input)
213
 
214
- logger.info(f"πŸ“₯ Received new job - ID: {job_id}")
215
- logger.info(f"πŸ“ Raw input: {topic_input}")
216
- logger.info(f"🎯 Cleaned topic: '{topic}'")
217
- logger.info(f"πŸ”— Callback URL: {callback_url or 'None'}")
218
 
219
  jobs[job_id] = {
220
  "status": "processing",
@@ -235,9 +201,7 @@ async def submit_job(request: Request, background_tasks: BackgroundTasks):
235
  "job_id": job_id,
236
  "status": "queued",
237
  "received_topic": topic,
238
- "original_input": topic_input,
239
- "callback_url": callback_url,
240
- "message": "Job is being processed"
241
  })
242
 
243
  except Exception as e:
@@ -260,9 +224,7 @@ async def debug_jobs():
260
  job_id: {
261
  "status": data["status"],
262
  "topic": data.get("topic", "unknown"),
263
- "original_input": data.get("original_input", "unknown"),
264
  "script_length": data.get("script_length", 0),
265
- "callback_url": data.get("callback_url"),
266
  "error": data.get("error", "none")
267
  }
268
  for job_id, data in jobs.items()
@@ -275,32 +237,18 @@ async def health_check():
275
  return {
276
  "status": "healthy",
277
  "model_loaded": generator.loaded,
278
- "total_jobs_processed": len(jobs),
279
- "completed_jobs": sum(1 for job in jobs.values() if job.get("status") == "complete"),
280
- "failed_jobs": sum(1 for job in jobs.values() if job.get("status") == "failed")
281
  }
282
 
283
  @app.get("/test/generation")
284
  async def test_generation():
285
- """Test endpoint to verify script generation works"""
286
  try:
287
  test_topic = "healthy lifestyle tips"
288
- logger.info(f"πŸ§ͺ Testing generation with topic: {test_topic}")
289
-
290
  script = generate_script(test_topic)
291
-
292
- return {
293
- "status": "success",
294
- "topic": test_topic,
295
- "script": script,
296
- "length": len(script)
297
- }
298
  except Exception as e:
299
- logger.error(f"❌ Test generation failed: {str(e)}", exc_info=True)
300
- return {
301
- "status": "error",
302
- "error": str(e)
303
- }
304
 
305
  if __name__ == "__main__":
306
  uvicorn.run(
 
9
  from transformers import AutoTokenizer, AutoModelForCausalLM
10
  import uvicorn
11
  from contextlib import asynccontextmanager
 
12
 
13
  # Configuration
14
  MODEL_ID = "google/gemma-1.1-2b-it"
15
  HF_TOKEN = os.getenv("HF_TOKEN", "")
16
+ MAX_TOKENS = 200
17
+ DEVICE = "cpu" # Force CPU to avoid device_map issues
18
  PORT = int(os.getenv("PORT", 7860))
19
 
20
  # Setup logging
 
43
  self.model = AutoModelForCausalLM.from_pretrained(
44
  MODEL_ID,
45
  torch_dtype=torch.float32,
 
46
  token=HF_TOKEN,
47
  low_cpu_mem_usage=True
48
  )
49
+ # Simple device assignment without device_map
50
+ self.model = self.model.to(DEVICE)
51
  self.loaded = True
52
  logger.info("βœ… Model loaded successfully")
53
  except Exception as e:
 
67
  """Extract topic from string or array input"""
68
  if isinstance(topic_input, list):
69
  if topic_input:
70
+ return str(topic_input[0])
71
  return "No topic provided"
72
  return str(topic_input)
73
 
74
  def generate_script(topic: str) -> str:
75
  """Generate script with error handling"""
76
  try:
 
77
  clean_topic = topic.strip().strip("['").strip("']").strip('"').strip("'")
78
+ logger.info(f"🎯 Generating script for: '{clean_topic}'")
 
79
 
80
  prompt = (
81
  f"Create a short 1-minute video script about: {clean_topic[:80]}\n\n"
 
86
  "Script:"
87
  )
88
 
 
 
 
89
  inputs = generator.tokenizer(
90
  prompt,
91
  return_tensors="pt",
92
  padding=True,
93
  truncation=True,
94
  max_length=512
95
+ ).to(DEVICE)
 
 
 
 
 
 
96
 
97
  # Generate with safer parameters
98
  with torch.no_grad():
 
102
  do_sample=True,
103
  top_p=0.8,
104
  temperature=0.7,
105
+ pad_token_id=generator.tokenizer.eos_token_id
 
106
  )
107
 
 
108
  script = generator.tokenizer.decode(outputs[0], skip_special_tokens=True)
109
  clean_script = script.replace(prompt, "").strip()
110
 
111
+ logger.info(f"πŸ“ Generated {len(clean_script)} characters")
 
 
112
  return clean_script
113
 
114
  except Exception as e:
 
118
  async def process_job(job_id: str, topic_input: Union[str, List[str]], callback_url: str = None):
119
  """Background task to process job"""
120
  try:
 
121
  topic = extract_topic(topic_input)
122
+ logger.info(f"🎯 Processing: '{topic}'")
123
 
124
  script = generate_script(topic)
125
  jobs[job_id] = {
 
130
  "script_length": len(script)
131
  }
132
 
133
+ logger.info(f"βœ… Completed job {job_id}")
134
 
135
  if callback_url:
136
  try:
 
 
137
  async with httpx.AsyncClient(timeout=30.0) as client:
138
  webhook_data = {
139
  "job_id": job_id,
140
  "status": "complete",
141
  "result": script,
142
  "topic": topic,
143
+ "original_input": topic_input
 
144
  }
145
 
146
  response = await client.post(
 
149
  headers={"Content-Type": "application/json"}
150
  )
151
 
152
+ logger.info(f"πŸ“¨ Webhook status: {response.status_code}")
 
 
 
 
 
153
 
154
  except Exception as e:
155
  logger.error(f"❌ Webhook failed: {str(e)}")
 
 
156
 
157
  except Exception as e:
158
  error_msg = f"Job failed: {str(e)}"
 
173
  data = await request.json()
174
  job_id = str(uuid.uuid4())
175
 
 
176
  if not data.get("topic"):
177
  raise HTTPException(status_code=400, detail="Topic is required")
178
 
179
  callback_url = data.get("callback_url")
180
  topic_input = data["topic"]
 
 
181
  topic = extract_topic(topic_input)
182
 
183
+ logger.info(f"πŸ“₯ Received job {job_id}: '{topic}'")
 
 
 
184
 
185
  jobs[job_id] = {
186
  "status": "processing",
 
201
  "job_id": job_id,
202
  "status": "queued",
203
  "received_topic": topic,
204
+ "callback_url": callback_url
 
 
205
  })
206
 
207
  except Exception as e:
 
224
  job_id: {
225
  "status": data["status"],
226
  "topic": data.get("topic", "unknown"),
 
227
  "script_length": data.get("script_length", 0),
 
228
  "error": data.get("error", "none")
229
  }
230
  for job_id, data in jobs.items()
 
237
  return {
238
  "status": "healthy",
239
  "model_loaded": generator.loaded,
240
+ "total_jobs": len(jobs)
 
 
241
  }
242
 
243
  @app.get("/test/generation")
244
  async def test_generation():
245
+ """Test script generation"""
246
  try:
247
  test_topic = "healthy lifestyle tips"
 
 
248
  script = generate_script(test_topic)
249
+ return {"status": "success", "topic": test_topic, "script": script}
 
 
 
 
 
 
250
  except Exception as e:
251
+ return {"status": "error", "error": str(e)}
 
 
 
 
252
 
253
  if __name__ == "__main__":
254
  uvicorn.run(