matsuap commited on
Commit
b6e32c9
·
verified ·
1 Parent(s): 951d5c6

Upload folder using huggingface_hub

Browse files
api/flashcards.py CHANGED
@@ -1,74 +1,46 @@
1
  import logging
2
- from fastapi import APIRouter, Depends, HTTPException
3
  from sqlalchemy.orm import Session
4
  from typing import List, Dict
 
 
5
 
6
  from api.auth import get_current_user
7
  from models import db_models
8
  from models.schemas import FlashcardGenerateRequest, FlashcardSetResponse, FlashcardResponse
9
- from core.database import get_db
 
10
  from services.flashcard_service import flashcard_service
11
  from core import constants
12
 
13
  router = APIRouter(prefix="/api/flashcards", tags=["flashcards"])
14
  logger = logging.getLogger(__name__)
15
 
16
- @router.get("/config")
17
- async def get_flashcard_config():
18
- """Returns available difficulties, quantities, and languages for flashcards."""
19
- return {
20
- "difficulties": constants.DIFFICULTIES,
21
- "quantities": constants.FLASHCARD_QUANTITIES,
22
- "languages": constants.LANGUAGES
23
- }
24
-
25
- @router.post("/generate", response_model=FlashcardSetResponse)
26
- async def generate_flashcards(
27
- request: FlashcardGenerateRequest,
28
- current_user: db_models.User = Depends(get_current_user),
29
- db: Session = Depends(get_db)
30
- ):
31
- """
32
- Generates a set of flashcards and saves them to the database.
33
- """
34
  try:
35
- source_id = None
36
- if request.file_key:
37
- # Verify file ownership
38
- source = db.query(db_models.Source).filter(
39
- db_models.Source.s3_key == request.file_key,
40
- db_models.Source.user_id == current_user.id
41
- ).first()
42
- if not source:
43
- raise HTTPException(status_code=403, detail="Not authorized to access this file")
44
- source_id = source.id
45
-
46
- # 1. Generate Flashcards from AI
47
  cards_data = await flashcard_service.generate_flashcards(
48
  file_key=request.file_key,
49
  text_input=request.text_input,
50
  difficulty=request.difficulty,
51
  quantity=request.quantity,
52
  topic=request.topic,
53
- language=request.language
 
 
 
54
  )
55
 
56
  if not cards_data:
57
- raise HTTPException(status_code=500, detail="AI returned an empty response")
58
 
59
- # 2. Save Flashcard Set to DB
60
- title = request.topic if request.topic else f"Flashcards {len(cards_data)}"
61
- db_set = db_models.FlashcardSet(
62
- title=title,
63
- difficulty=request.difficulty,
64
- user_id=current_user.id,
65
- source_id=source_id
66
- )
67
- db.add(db_set)
68
- db.commit()
69
- db.refresh(db_set)
70
-
71
- # 3. Save individual flashcards
72
  for item in cards_data:
73
  db_card = db_models.Flashcard(
74
  flashcard_set_id=db_set.id,
@@ -77,24 +49,74 @@ async def generate_flashcards(
77
  )
78
  db.add(db_card)
79
 
 
80
  db.commit()
81
- db.refresh(db_set)
82
 
83
- return {
 
 
84
  "id": db_set.id,
85
- "title": db_set.title,
86
- "difficulty": db_set.difficulty,
87
- "created_at": db_set.created_at,
88
- "parent_file_id": db_set.source_id,
89
- "parent_file_key": source.s3_key if source else None,
90
- "flashcards": db_set.flashcards
91
- }
92
-
93
- except HTTPException:
94
- raise
95
  except Exception as e:
96
- logger.error(f"Flashcard generation endpoint failed: {e}")
97
- raise HTTPException(status_code=500, detail=str(e))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
 
99
  @router.get("/sets", response_model=List[FlashcardSetResponse])
100
  async def list_flashcard_sets(
@@ -108,18 +130,7 @@ async def list_flashcard_sets(
108
  sets = db.query(db_models.FlashcardSet).filter(
109
  db_models.FlashcardSet.user_id == current_user.id
110
  ).order_by(db_models.FlashcardSet.created_at.desc()).all()
111
- return [
112
- {
113
- "id": s.id,
114
- "title": s.title,
115
- "difficulty": s.difficulty,
116
- "created_at": s.created_at,
117
- "parent_file_id": s.source_id,
118
- "parent_file_key": s.source.s3_key if s.source else None,
119
- "flashcards": s.flashcards
120
- }
121
- for s in sets
122
- ]
123
  except Exception as e:
124
  raise HTTPException(status_code=500, detail=str(e))
125
 
@@ -140,15 +151,7 @@ async def get_flashcard_set(
140
  if not db_set:
141
  raise HTTPException(status_code=404, detail="Flashcard set not found")
142
 
143
- return {
144
- "id": db_set.id,
145
- "title": db_set.title,
146
- "difficulty": db_set.difficulty,
147
- "created_at": db_set.created_at,
148
- "parent_file_id": db_set.source_id,
149
- "parent_file_key": db_set.source.s3_key if db_set.source else None,
150
- "flashcards": db_set.flashcards
151
- }
152
 
153
  @router.post("/explain")
154
  async def explain_flashcard(
 
1
  import logging
2
+ from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks
3
  from sqlalchemy.orm import Session
4
  from typing import List, Dict
5
+ import asyncio
6
+ from datetime import datetime
7
 
8
  from api.auth import get_current_user
9
  from models import db_models
10
  from models.schemas import FlashcardGenerateRequest, FlashcardSetResponse, FlashcardResponse
11
+ from core.database import get_db, SessionLocal
12
+ from api.websocket_routes import manager
13
  from services.flashcard_service import flashcard_service
14
  from core import constants
15
 
16
  router = APIRouter(prefix="/api/flashcards", tags=["flashcards"])
17
  logger = logging.getLogger(__name__)
18
 
19
+ async def run_flashcard_generation(set_id: int, request: FlashcardGenerateRequest, user_id: int):
20
+ """Background task for flashcard generation"""
21
+ db = SessionLocal()
22
+ connection_id = f"user_{user_id}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  try:
24
+ db_set = db.query(db_models.FlashcardSet).filter(db_models.FlashcardSet.id == set_id).first()
25
+ if not db_set: return
26
+
27
+ # Call AI service
 
 
 
 
 
 
 
 
28
  cards_data = await flashcard_service.generate_flashcards(
29
  file_key=request.file_key,
30
  text_input=request.text_input,
31
  difficulty=request.difficulty,
32
  quantity=request.quantity,
33
  topic=request.topic,
34
+ language=request.language,
35
+ progress_callback=lambda p, m: asyncio.create_task(
36
+ manager.send_progress(connection_id, p, "processing", m)
37
+ )
38
  )
39
 
40
  if not cards_data:
41
+ raise Exception("AI returned empty flashcards data")
42
 
43
+ # Save individual cards
 
 
 
 
 
 
 
 
 
 
 
 
44
  for item in cards_data:
45
  db_card = db_models.Flashcard(
46
  flashcard_set_id=db_set.id,
 
49
  )
50
  db.add(db_card)
51
 
52
+ db_set.status = "completed"
53
  db.commit()
 
54
 
55
+ # Notify via WebSocket
56
+ await manager.send_result(connection_id, {
57
+ "type": "flashcards",
58
  "id": db_set.id,
59
+ "status": "completed",
60
+ "title": db_set.title
61
+ })
62
+
 
 
 
 
 
 
63
  except Exception as e:
64
+ logger.error(f"Background flashcard generation failed: {e}")
65
+ db_set = db.query(db_models.FlashcardSet).filter(db_models.FlashcardSet.id == set_id).first()
66
+ if db_set:
67
+ db_set.status = "failed"
68
+ db_set.error_message = str(e)
69
+ db.commit()
70
+ await manager.send_error(connection_id, f"Flashcard generation failed: {str(e)}")
71
+ finally:
72
+ db.close()
73
+
74
+ @router.get("/config")
75
+ async def get_flashcard_config():
76
+ """Returns available difficulties, quantities, and languages for flashcards."""
77
+ return {
78
+ "difficulties": constants.DIFFICULTIES,
79
+ "quantities": constants.FLASHCARD_QUANTITIES,
80
+ "languages": constants.LANGUAGES
81
+ }
82
+
83
+ @router.post("/generate", response_model=FlashcardSetResponse)
84
+ async def generate_flashcards(
85
+ request: FlashcardGenerateRequest,
86
+ background_tasks: BackgroundTasks,
87
+ current_user: db_models.User = Depends(get_current_user),
88
+ db: Session = Depends(get_db)
89
+ ):
90
+ """
91
+ Initiates flashcard generation in the background.
92
+ """
93
+ source_id = None
94
+ if request.file_key:
95
+ source = db.query(db_models.Source).filter(
96
+ db_models.Source.s3_key == request.file_key,
97
+ db_models.Source.user_id == current_user.id
98
+ ).first()
99
+ if not source:
100
+ raise HTTPException(status_code=403, detail="Not authorized to access this file")
101
+ source_id = source.id
102
+
103
+ # Create initial processing record
104
+ title = request.topic if request.topic else f"Flashcards {datetime.utcnow().strftime('%Y-%m-%d %H:%M')}"
105
+ db_set = db_models.FlashcardSet(
106
+ title=title,
107
+ difficulty=request.difficulty,
108
+ user_id=current_user.id,
109
+ source_id=source_id,
110
+ status="processing"
111
+ )
112
+ db.add(db_set)
113
+ db.commit()
114
+ db.refresh(db_set)
115
+
116
+ # Offload to background task
117
+ background_tasks.add_task(run_flashcard_generation, db_set.id, request, current_user.id)
118
+
119
+ return db_set
120
 
121
  @router.get("/sets", response_model=List[FlashcardSetResponse])
122
  async def list_flashcard_sets(
 
130
  sets = db.query(db_models.FlashcardSet).filter(
131
  db_models.FlashcardSet.user_id == current_user.id
132
  ).order_by(db_models.FlashcardSet.created_at.desc()).all()
133
+ return [FlashcardSetResponse.model_validate(s) for s in sets]
 
 
 
 
 
 
 
 
 
 
 
134
  except Exception as e:
135
  raise HTTPException(status_code=500, detail=str(e))
136
 
 
151
  if not db_set:
152
  raise HTTPException(status_code=404, detail="Flashcard set not found")
153
 
154
+ return FlashcardSetResponse.model_validate(db_set)
 
 
 
 
 
 
 
 
155
 
156
  @router.post("/explain")
157
  async def explain_flashcard(
api/mindmaps.py CHANGED
@@ -1,74 +1,98 @@
1
  import logging
2
- from fastapi import APIRouter, Depends, HTTPException
3
  from sqlalchemy.orm import Session
4
  from typing import List
 
5
 
6
  from api.auth import get_current_user
7
  from models import db_models
8
  from models.schemas import MindMapGenerateRequest, MindMapResponse
9
- from core.database import get_db
 
10
  from services.mindmap_service import mindmap_service
11
 
12
  router = APIRouter(prefix="/api/mindmaps", tags=["mindmaps"])
13
  logger = logging.getLogger(__name__)
14
 
15
- @router.post("/generate", response_model=MindMapResponse)
16
- async def generate_mindmap(
17
- request: MindMapGenerateRequest,
18
- current_user: db_models.User = Depends(get_current_user),
19
- db: Session = Depends(get_db)
20
- ):
21
- """
22
- Generates a mind map in Mermaid format and saves it to the database.
23
- """
24
  try:
25
- source_id = None
26
- if request.file_key:
27
- # Verify file ownership
28
- source = db.query(db_models.Source).filter(
29
- db_models.Source.s3_key == request.file_key,
30
- db_models.Source.user_id == current_user.id
31
- ).first()
32
- if not source:
33
- raise HTTPException(status_code=403, detail="Not authorized to access this file")
34
- source_id = source.id
35
 
36
- # 1. Generate Mind Map from AI
37
  mermaid_code = await mindmap_service.generate_mindmap(
38
  file_key=request.file_key,
39
  text_input=request.text_input
40
  )
41
 
42
  if not mermaid_code:
43
- raise HTTPException(status_code=500, detail="Failed to generate mind map")
44
 
45
- # 2. Save to DB
46
- title = request.title if request.title else (request.file_key.split('/')[-1] if request.file_key else "Untitled Mind Map")
47
- db_mindmap = db_models.MindMap(
48
- title=title,
49
- mermaid_code=mermaid_code,
50
- user_id=current_user.id,
51
- source_id=source_id
52
- )
53
- db.add(db_mindmap)
54
  db.commit()
55
- db.refresh(db_mindmap)
56
 
57
- return MindMapResponse(
58
- id=db_mindmap.id,
59
- title=db_mindmap.title,
60
- mermaid_code=db_mindmap.mermaid_code,
61
- parent_file_id=db_mindmap.source_id,
62
- parent_file_key=source.s3_key if source else None,
63
- created_at=db_mindmap.created_at,
64
- message="Mind map generated successfully"
65
- )
66
 
67
- except HTTPException:
68
- raise
69
  except Exception as e:
70
- logger.error(f"Mind map generation endpoint failed: {e}")
71
- raise HTTPException(status_code=500, detail=str(e))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
 
73
  @router.get("/list", response_model=List[MindMapResponse])
74
  async def list_mindmaps(
@@ -83,17 +107,7 @@ async def list_mindmaps(
83
  db_models.MindMap.user_id == current_user.id
84
  ).order_by(db_models.MindMap.created_at.desc()).all()
85
 
86
- return [
87
- MindMapResponse(
88
- id=m.id,
89
- title=m.title,
90
- mermaid_code=m.mermaid_code,
91
- parent_file_id=m.source_id,
92
- parent_file_key=m.source.s3_key if m.source else None,
93
- created_at=m.created_at,
94
- message="Retrieved successfully"
95
- ) for m in mindmaps
96
- ]
97
  except Exception as e:
98
  raise HTTPException(status_code=500, detail=str(e))
99
 
 
1
  import logging
2
+ from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks
3
  from sqlalchemy.orm import Session
4
  from typing import List
5
+ from datetime import datetime
6
 
7
  from api.auth import get_current_user
8
  from models import db_models
9
  from models.schemas import MindMapGenerateRequest, MindMapResponse
10
+ from core.database import get_db, SessionLocal
11
+ from api.websocket_routes import manager
12
  from services.mindmap_service import mindmap_service
13
 
14
  router = APIRouter(prefix="/api/mindmaps", tags=["mindmaps"])
15
  logger = logging.getLogger(__name__)
16
 
17
+ async def run_mindmap_generation(mindmap_id: int, request: MindMapGenerateRequest, user_id: int):
18
+ """Background task for mind map generation"""
19
+ db = SessionLocal()
20
+ connection_id = f"user_{user_id}"
 
 
 
 
 
21
  try:
22
+ db_mindmap = db.query(db_models.MindMap).filter(db_models.MindMap.id == mindmap_id).first()
23
+ if not db_mindmap: return
 
 
 
 
 
 
 
 
24
 
25
+ # Call AI service
26
  mermaid_code = await mindmap_service.generate_mindmap(
27
  file_key=request.file_key,
28
  text_input=request.text_input
29
  )
30
 
31
  if not mermaid_code:
32
+ raise Exception("AI failed to generate mind map code")
33
 
34
+ db_mindmap.mermaid_code = mermaid_code
35
+ db_mindmap.status = "completed"
 
 
 
 
 
 
 
36
  db.commit()
 
37
 
38
+ # Notify via WebSocket
39
+ await manager.send_result(connection_id, {
40
+ "type": "mindmap",
41
+ "id": db_mindmap.id,
42
+ "status": "completed",
43
+ "title": db_mindmap.title
44
+ })
 
 
45
 
 
 
46
  except Exception as e:
47
+ logger.error(f"Background mindmap generation failed: {e}")
48
+ db_mindmap = db.query(db_models.MindMap).filter(db_models.MindMap.id == mindmap_id).first()
49
+ if db_mindmap:
50
+ db_mindmap.status = "failed"
51
+ db_mindmap.error_message = str(e)
52
+ db.commit()
53
+ await manager.send_error(connection_id, f"Mind map generation failed: {str(e)}")
54
+ finally:
55
+ db.close()
56
+
57
+ @router.post("/generate", response_model=MindMapResponse)
58
+ async def generate_mindmap(
59
+ request: MindMapGenerateRequest,
60
+ background_tasks: BackgroundTasks,
61
+ current_user: db_models.User = Depends(get_current_user),
62
+ db: Session = Depends(get_db)
63
+ ):
64
+ """
65
+ Initiates mind map generation in the background.
66
+ """
67
+ source_id = None
68
+ if request.file_key:
69
+ source = db.query(db_models.Source).filter(
70
+ db_models.Source.s3_key == request.file_key,
71
+ db_models.Source.user_id == current_user.id
72
+ ).first()
73
+ if not source:
74
+ raise HTTPException(status_code=403, detail="Not authorized to access this file")
75
+ source_id = source.id
76
+
77
+ # Create initial processing record
78
+ title = request.title if request.title else (request.file_key.split('/')[-1] if request.file_key else f"Mind Map {datetime.utcnow().strftime('%Y-%m-%d %H:%M')}")
79
+ db_mindmap = db_models.MindMap(
80
+ title=title,
81
+ user_id=current_user.id,
82
+ source_id=source_id,
83
+ status="processing"
84
+ )
85
+ db.add(db_mindmap)
86
+ db.commit()
87
+ db.refresh(db_mindmap)
88
+
89
+ # Offload to background task
90
+ background_tasks.add_task(run_mindmap_generation, db_mindmap.id, request, current_user.id)
91
+
92
+ # return processing state
93
+ resp = MindMapResponse.model_validate(db_mindmap)
94
+ resp.message = "Mind map generation started"
95
+ return resp
96
 
97
  @router.get("/list", response_model=List[MindMapResponse])
98
  async def list_mindmaps(
 
107
  db_models.MindMap.user_id == current_user.id
108
  ).order_by(db_models.MindMap.created_at.desc()).all()
109
 
110
+ return [MindMapResponse.model_validate(m) for m in mindmaps]
 
 
 
 
 
 
 
 
 
 
111
  except Exception as e:
112
  raise HTTPException(status_code=500, detail=str(e))
113
 
api/quizzes.py CHANGED
@@ -1,102 +1,125 @@
1
  import logging
2
- from fastapi import APIRouter, Depends, HTTPException
3
  from sqlalchemy.orm import Session
4
- from typing import List
 
 
5
 
6
  from api.auth import get_current_user
7
  from models import db_models
8
  from models.schemas import QuizGenerateRequest, QuizSetResponse
9
- from core.database import get_db
 
10
  from services.quiz_service import quiz_service
11
  from core import constants
12
 
13
  router = APIRouter(prefix="/api/quizzes", tags=["quizzes"])
14
  logger = logging.getLogger(__name__)
15
 
16
- @router.get("/config")
17
- async def get_quiz_config():
18
- """Returns available difficulties, count options, and languages for quizzes."""
19
- return {
20
- "difficulties": constants.DIFFICULTIES,
21
- "counts": constants.QUIZ_COUNTS,
22
- "languages": constants.LANGUAGES
23
- }
24
-
25
- @router.post("/generate", response_model=QuizSetResponse)
26
- async def generate_quiz(
27
- request: QuizGenerateRequest,
28
- current_user: db_models.User = Depends(get_current_user),
29
- db: Session = Depends(get_db)
30
- ):
31
- """
32
- Generates a set of quiz questions and saves them to the database.
33
- """
34
  try:
35
- source_id = None
36
- if request.file_key:
37
- source = db.query(db_models.Source).filter(
38
- db_models.Source.s3_key == request.file_key,
39
- db_models.Source.user_id == current_user.id
40
- ).first()
41
- if not source:
42
- raise HTTPException(status_code=403, detail="Not authorized to access this file")
43
- source_id = source.id
44
-
45
- # 1. Generate Quiz from AI
46
  quizzes_data = await quiz_service.generate_quiz(
47
  file_key=request.file_key,
48
  text_input=request.text_input,
49
  difficulty=request.difficulty,
50
  topic=request.topic,
51
  language=request.language,
52
- count_mode=request.count
 
 
 
53
  )
54
 
55
  if not quizzes_data:
56
- raise HTTPException(status_code=500, detail="Failed to generate quiz")
57
 
58
- # 2. Save Quiz Set
59
- title = request.topic if request.topic else f"Quiz {len(quizzes_data)}"
60
- db_set = db_models.QuizSet(
61
- title=title,
62
- difficulty=request.difficulty,
63
- user_id=current_user.id,
64
- source_id=source_id
65
- )
66
- db.add(db_set)
67
- db.commit()
68
- db.refresh(db_set)
69
-
70
- # 3. Save Questions
71
  for item in quizzes_data:
72
  db_question = db_models.QuizQuestion(
73
  quiz_set_id=db_set.id,
74
  question=item.get("question", ""),
75
  hint=item.get("hint", ""),
76
  choices=item.get("choices", {}),
77
- answer=item.get("answer", "1"),
78
  explanation=item.get("explanation", "")
79
  )
80
  db.add(db_question)
81
 
 
82
  db.commit()
83
- db.refresh(db_set)
84
 
85
- return {
 
 
86
  "id": db_set.id,
87
- "title": db_set.title,
88
- "difficulty": db_set.difficulty,
89
- "created_at": db_set.created_at,
90
- "parent_file_id": db_set.source_id,
91
- "parent_file_key": source.s3_key if source else None,
92
- "questions": db_set.questions
93
- }
94
-
95
- except HTTPException:
96
- raise
97
  except Exception as e:
98
- logger.error(f"Quiz generation endpoint failed: {e}")
99
- raise HTTPException(status_code=500, detail=str(e))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
 
101
  @router.get("/sets", response_model=List[QuizSetResponse])
102
  async def list_quiz_sets(
@@ -110,18 +133,7 @@ async def list_quiz_sets(
110
  sets = db.query(db_models.QuizSet).filter(
111
  db_models.QuizSet.user_id == current_user.id
112
  ).order_by(db_models.QuizSet.created_at.desc()).all()
113
- return [
114
- {
115
- "id": s.id,
116
- "title": s.title,
117
- "difficulty": s.difficulty,
118
- "created_at": s.created_at,
119
- "parent_file_id": s.source_id,
120
- "parent_file_key": s.source.s3_key if s.source else None,
121
- "questions": s.questions
122
- }
123
- for s in sets
124
- ]
125
  except Exception as e:
126
  raise HTTPException(status_code=500, detail=str(e))
127
 
@@ -142,15 +154,7 @@ async def get_quiz_set(
142
  if not db_set:
143
  raise HTTPException(status_code=404, detail="Quiz set not found")
144
 
145
- return {
146
- "id": db_set.id,
147
- "title": db_set.title,
148
- "difficulty": db_set.difficulty,
149
- "created_at": db_set.created_at,
150
- "parent_file_id": db_set.source_id,
151
- "parent_file_key": db_set.source.s3_key if db_set.source else None,
152
- "questions": db_set.questions
153
- }
154
 
155
  @router.delete("/set/{set_id}")
156
  async def delete_quiz_set(
 
1
  import logging
2
+ from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks
3
  from sqlalchemy.orm import Session
4
+ from typing import List, Dict
5
+ import asyncio
6
+ from datetime import datetime
7
 
8
  from api.auth import get_current_user
9
  from models import db_models
10
  from models.schemas import QuizGenerateRequest, QuizSetResponse
11
+ from core.database import get_db, SessionLocal
12
+ from api.websocket_routes import manager
13
  from services.quiz_service import quiz_service
14
  from core import constants
15
 
16
  router = APIRouter(prefix="/api/quizzes", tags=["quizzes"])
17
  logger = logging.getLogger(__name__)
18
 
19
+ async def run_quiz_generation(set_id: int, request: QuizGenerateRequest, user_id: int):
20
+ """Background task for quiz generation"""
21
+ db = SessionLocal()
22
+ connection_id = f"user_{user_id}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  try:
24
+ db_set = db.query(db_models.QuizSet).filter(db_models.QuizSet.id == set_id).first()
25
+ if not db_set: return
26
+
27
+ # Call AI service
 
 
 
 
 
 
 
28
  quizzes_data = await quiz_service.generate_quiz(
29
  file_key=request.file_key,
30
  text_input=request.text_input,
31
  difficulty=request.difficulty,
32
  topic=request.topic,
33
  language=request.language,
34
+ count_mode=request.count,
35
+ progress_callback=lambda p, m: asyncio.create_task(
36
+ manager.send_progress(connection_id, p, "processing", m)
37
+ )
38
  )
39
 
40
  if not quizzes_data:
41
+ raise Exception("AI failed to generate quiz questions")
42
 
43
+ # Save individual questions
 
 
 
 
 
 
 
 
 
 
 
 
44
  for item in quizzes_data:
45
  db_question = db_models.QuizQuestion(
46
  quiz_set_id=db_set.id,
47
  question=item.get("question", ""),
48
  hint=item.get("hint", ""),
49
  choices=item.get("choices", {}),
50
+ answer=str(item.get("answer", "1")),
51
  explanation=item.get("explanation", "")
52
  )
53
  db.add(db_question)
54
 
55
+ db_set.status = "completed"
56
  db.commit()
 
57
 
58
+ # Notify via WebSocket
59
+ await manager.send_result(connection_id, {
60
+ "type": "quiz",
61
  "id": db_set.id,
62
+ "status": "completed",
63
+ "title": db_set.title
64
+ })
65
+
 
 
 
 
 
 
66
  except Exception as e:
67
+ logger.error(f"Background quiz generation failed: {e}")
68
+ db_set = db.query(db_models.QuizSet).filter(db_models.QuizSet.id == set_id).first()
69
+ if db_set:
70
+ db_set.status = "failed"
71
+ db_set.error_message = str(e)
72
+ db.commit()
73
+ await manager.send_error(connection_id, f"Quiz generation failed: {str(e)}")
74
+ finally:
75
+ db.close()
76
+
77
+ @router.get("/config")
78
+ async def get_quiz_config():
79
+ """Returns available difficulties, count options, and languages for quizzes."""
80
+ return {
81
+ "difficulties": constants.DIFFICULTIES,
82
+ "counts": constants.QUIZ_COUNTS,
83
+ "languages": constants.LANGUAGES
84
+ }
85
+
86
+ @router.post("/generate", response_model=QuizSetResponse)
87
+ async def generate_quiz(
88
+ request: QuizGenerateRequest,
89
+ background_tasks: BackgroundTasks,
90
+ current_user: db_models.User = Depends(get_current_user),
91
+ db: Session = Depends(get_db)
92
+ ):
93
+ """
94
+ Initiates quiz generation in the background.
95
+ """
96
+ source_id = None
97
+ if request.file_key:
98
+ source = db.query(db_models.Source).filter(
99
+ db_models.Source.s3_key == request.file_key,
100
+ db_models.Source.user_id == current_user.id
101
+ ).first()
102
+ if not source:
103
+ raise HTTPException(status_code=403, detail="Not authorized to access this file")
104
+ source_id = source.id
105
+
106
+ # Create initial processing record
107
+ title = request.topic if request.topic else f"Quiz {datetime.utcnow().strftime('%Y-%m-%d %H:%M')}"
108
+ db_set = db_models.QuizSet(
109
+ title=title,
110
+ difficulty=request.difficulty,
111
+ user_id=current_user.id,
112
+ source_id=source_id,
113
+ status="processing"
114
+ )
115
+ db.add(db_set)
116
+ db.commit()
117
+ db.refresh(db_set)
118
+
119
+ # Offload to background task
120
+ background_tasks.add_task(run_quiz_generation, db_set.id, request, current_user.id)
121
+
122
+ return db_set
123
 
124
  @router.get("/sets", response_model=List[QuizSetResponse])
125
  async def list_quiz_sets(
 
133
  sets = db.query(db_models.QuizSet).filter(
134
  db_models.QuizSet.user_id == current_user.id
135
  ).order_by(db_models.QuizSet.created_at.desc()).all()
136
+ return [QuizSetResponse.model_validate(s) for s in sets]
 
 
 
 
 
 
 
 
 
 
 
137
  except Exception as e:
138
  raise HTTPException(status_code=500, detail=str(e))
139
 
 
154
  if not db_set:
155
  raise HTTPException(status_code=404, detail="Quiz set not found")
156
 
157
+ return QuizSetResponse.model_validate(db_set)
 
 
 
 
 
 
 
 
158
 
159
  @router.delete("/set/{set_id}")
160
  async def delete_quiz_set(
api/reports.py CHANGED
@@ -1,18 +1,69 @@
1
  import logging
2
- from fastapi import APIRouter, Depends, HTTPException
3
  from sqlalchemy.orm import Session
4
  from typing import List, Optional
 
5
 
6
  from api.auth import get_current_user
7
  from models import db_models
8
  from models.schemas import ReportGenerateRequest, ReportResponse, ReportFormatSuggestionResponse
9
- from core.database import get_db
 
10
  from services.report_service import report_service
11
  from core import constants
12
 
13
  router = APIRouter(prefix="/api/reports", tags=["reports"])
14
  logger = logging.getLogger(__name__)
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  @router.get("/config")
17
  async def get_report_config():
18
  """Returns available formats and languages for report generation."""
@@ -41,63 +92,40 @@ async def suggest_formats(
41
  @router.post("/generate", response_model=ReportResponse)
42
  async def generate_report(
43
  request: ReportGenerateRequest,
 
44
  current_user: db_models.User = Depends(get_current_user),
45
  db: Session = Depends(get_db)
46
  ):
47
  """
48
- Generates a full report and saves it to the database.
49
  """
50
- try:
51
- source_id = None
52
- if request.file_key:
53
- source = db.query(db_models.Source).filter(
54
- db_models.Source.s3_key == request.file_key,
55
- db_models.Source.user_id == current_user.id
56
- ).first()
57
- if not source:
58
- raise HTTPException(status_code=403, detail="Not authorized to access this file")
59
- source_id = source.id
60
-
61
- # 1. Generate Report from AI
62
- content = await report_service.generate_report(
63
- file_key=request.file_key,
64
- text_input=request.text_input,
65
- format_key=request.format_key,
66
- custom_prompt=request.custom_prompt,
67
- language=request.language
68
- )
69
-
70
- if not content:
71
- raise HTTPException(status_code=500, detail="Failed to generate report")
72
-
73
- # 2. Extract title (usually the first line)
74
- title = content.split('\n')[0].replace('#', '').strip()
75
- if not title or len(title) < 3:
76
- title = f"Report {request.format_key}"
77
 
78
- # 3. Save to DB
79
- db_report = db_models.Report(
80
- title=title,
81
- content=content,
82
- format_key=request.format_key,
83
- user_id=current_user.id,
84
- source_id=source_id
85
- )
86
- db.add(db_report)
87
- db.commit()
88
- db.refresh(db_report)
 
89
 
90
- return {
91
- **db_report.__dict__,
92
- "parent_file_id": db_report.source_id,
93
- "parent_file_key": source.s3_key if source else None
94
- }
95
 
96
- except HTTPException:
97
- raise
98
- except Exception as e:
99
- logger.error(f"Report generation endpoint failed: {e}")
100
- raise HTTPException(status_code=500, detail=str(e))
101
 
102
  @router.get("/list", response_model=List[ReportResponse])
103
  async def list_reports(
@@ -111,18 +139,7 @@ async def list_reports(
111
  reports = db.query(db_models.Report).filter(
112
  db_models.Report.user_id == current_user.id
113
  ).order_by(db_models.Report.created_at.desc()).all()
114
- return [
115
- {
116
- "id": r.id,
117
- "title": r.title,
118
- "content": r.content,
119
- "format_key": r.format_key,
120
- "parent_file_id": r.source_id,
121
- "parent_file_key": r.source.s3_key if r.source else None,
122
- "created_at": r.created_at
123
- }
124
- for r in reports
125
- ]
126
  except Exception as e:
127
  raise HTTPException(status_code=500, detail=str(e))
128
 
@@ -143,15 +160,7 @@ async def get_report(
143
  if not report:
144
  raise HTTPException(status_code=404, detail="Report not found")
145
 
146
- return {
147
- "id": report.id,
148
- "title": report.title,
149
- "content": report.content,
150
- "format_key": report.format_key,
151
- "parent_file_id": report.source_id,
152
- "parent_file_key": report.source.s3_key if report.source else None,
153
- "created_at": report.created_at
154
- }
155
 
156
  @router.delete("/{report_id}")
157
  async def delete_report(
 
1
  import logging
2
+ from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks
3
  from sqlalchemy.orm import Session
4
  from typing import List, Optional
5
+ from datetime import datetime
6
 
7
  from api.auth import get_current_user
8
  from models import db_models
9
  from models.schemas import ReportGenerateRequest, ReportResponse, ReportFormatSuggestionResponse
10
+ from core.database import get_db, SessionLocal
11
+ from api.websocket_routes import manager
12
  from services.report_service import report_service
13
  from core import constants
14
 
15
  router = APIRouter(prefix="/api/reports", tags=["reports"])
16
  logger = logging.getLogger(__name__)
17
 
18
+ async def run_report_generation(report_id: int, request: ReportGenerateRequest, user_id: int):
19
+ """Background task for report generation"""
20
+ db = SessionLocal()
21
+ connection_id = f"user_{user_id}"
22
+ try:
23
+ db_report = db.query(db_models.Report).filter(db_models.Report.id == report_id).first()
24
+ if not db_report: return
25
+
26
+ # Call AI service
27
+ content = await report_service.generate_report(
28
+ file_key=request.file_key,
29
+ text_input=request.text_input,
30
+ format_key=request.format_key,
31
+ custom_prompt=request.custom_prompt,
32
+ language=request.language
33
+ )
34
+
35
+ if not content:
36
+ raise Exception("AI failed to generate report content")
37
+
38
+ # Extract title (usually the first line)
39
+ title = content.split('\n')[0].replace('#', '').strip()
40
+ if not title or len(title) < 3:
41
+ title = f"Report {request.format_key}"
42
+
43
+ db_report.title = title
44
+ db_report.content = content
45
+ db_report.status = "completed"
46
+ db.commit()
47
+
48
+ # Notify via WebSocket
49
+ await manager.send_result(connection_id, {
50
+ "type": "report",
51
+ "id": db_report.id,
52
+ "status": "completed",
53
+ "title": db_report.title
54
+ })
55
+
56
+ except Exception as e:
57
+ logger.error(f"Background report generation failed: {e}")
58
+ db_report = db.query(db_models.Report).filter(db_models.Report.id == report_id).first()
59
+ if db_report:
60
+ db_report.status = "failed"
61
+ db_report.error_message = str(e)
62
+ db.commit()
63
+ await manager.send_error(connection_id, f"Report generation failed: {str(e)}")
64
+ finally:
65
+ db.close()
66
+
67
  @router.get("/config")
68
  async def get_report_config():
69
  """Returns available formats and languages for report generation."""
 
92
  @router.post("/generate", response_model=ReportResponse)
93
  async def generate_report(
94
  request: ReportGenerateRequest,
95
+ background_tasks: BackgroundTasks,
96
  current_user: db_models.User = Depends(get_current_user),
97
  db: Session = Depends(get_db)
98
  ):
99
  """
100
+ Initiates report generation in the background.
101
  """
102
+ source_id = None
103
+ if request.file_key:
104
+ source = db.query(db_models.Source).filter(
105
+ db_models.Source.s3_key == request.file_key,
106
+ db_models.Source.user_id == current_user.id
107
+ ).first()
108
+ if not source:
109
+ raise HTTPException(status_code=403, detail="Not authorized to access this file")
110
+ source_id = source.id
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
 
112
+ # Create initial processing record
113
+ title = f"Report {request.format_key} {datetime.utcnow().strftime('%Y-%m-%d %H:%M')}"
114
+ db_report = db_models.Report(
115
+ title=title,
116
+ format_key=request.format_key,
117
+ user_id=current_user.id,
118
+ source_id=source_id,
119
+ status="processing"
120
+ )
121
+ db.add(db_report)
122
+ db.commit()
123
+ db.refresh(db_report)
124
 
125
+ # Offload to background task
126
+ background_tasks.add_task(run_report_generation, db_report.id, request, current_user.id)
 
 
 
127
 
128
+ return db_report
 
 
 
 
129
 
130
  @router.get("/list", response_model=List[ReportResponse])
131
  async def list_reports(
 
139
  reports = db.query(db_models.Report).filter(
140
  db_models.Report.user_id == current_user.id
141
  ).order_by(db_models.Report.created_at.desc()).all()
142
+ return [ReportResponse.model_validate(r) for r in reports]
 
 
 
 
 
 
 
 
 
 
 
143
  except Exception as e:
144
  raise HTTPException(status_code=500, detail=str(e))
145
 
 
160
  if not report:
161
  raise HTTPException(status_code=404, detail="Report not found")
162
 
163
+ return ReportResponse.model_validate(report)
 
 
 
 
 
 
 
 
164
 
165
  @router.delete("/{report_id}")
166
  async def delete_report(
api/video_generator.py CHANGED
@@ -1,12 +1,14 @@
1
  import logging
2
- from fastapi import APIRouter, Depends, HTTPException
3
  from sqlalchemy.orm import Session
4
  from typing import List
 
5
 
6
  from api.auth import get_current_user
7
  from models import db_models
8
  from models.schemas import VideoSummaryGenerateRequest, VideoSummaryResponse
9
- from core.database import get_db
 
10
  from services.video_generator_service import video_generator_service
11
  from services.slides_video_service import slides_video_service
12
  from services.s3_service import s3_service
@@ -14,27 +16,19 @@ from services.s3_service import s3_service
14
  router = APIRouter(prefix="/api/videos", tags=["video-generator"])
15
  logger = logging.getLogger(__name__)
16
 
17
- @router.post("/generate", response_model=VideoSummaryResponse)
18
- async def generate_video_summary(
19
- request: VideoSummaryGenerateRequest,
20
- current_user: db_models.User = Depends(get_current_user),
21
- db: Session = Depends(get_db)
22
- ):
23
- """
24
- Analyzes a PDF and generates a narrated video summary.
25
- """
26
  try:
27
- # Check source ownership
28
- source = db.query(db_models.Source).filter(
29
- db_models.Source.s3_key == request.file_key,
30
- db_models.Source.user_id == current_user.id
31
- ).first()
32
-
33
- if not source:
34
- raise HTTPException(status_code=403, detail="Not authorized to access this file")
35
 
36
  if request.use_slides_transformation:
37
- # Full PDF -> Slides -> Video pipeline
38
  result = await slides_video_service.generate_transformed_video_summary(
39
  file_key=request.file_key,
40
  language=request.language,
@@ -42,39 +36,74 @@ async def generate_video_summary(
42
  custom_prompt=request.custom_prompt
43
  )
44
  else:
45
- # Standard PDF -> Video pipeline (high fidelity version)
46
  result = await video_generator_service.generate_video_summary(
47
  file_key=request.file_key,
48
  language=request.language,
49
  voice_name=request.voice_name
50
  )
51
 
52
- # Save to DB
53
- db_summary = db_models.VideoSummary(
54
- title=result["title"],
55
- s3_key=result["s3_key"],
56
- s3_url=result["s3_url"],
57
- user_id=current_user.id,
58
- source_id=source.id
59
- )
60
- db.add(db_summary)
61
  db.commit()
62
- db.refresh(db_summary)
63
 
64
- return {
 
 
65
  "id": db_summary.id,
66
- "title": db_summary.title,
67
- "s3_key": db_summary.s3_key,
68
- "public_url": db_summary.s3_url,
69
- "private_url": s3_service.get_presigned_url(db_summary.s3_key),
70
- "parent_file_id": db_summary.source_id,
71
- "parent_file_key": db_summary.source.s3_key if db_summary.source else None,
72
- "created_at": db_summary.created_at
73
- }
74
 
75
  except Exception as e:
76
- logger.error(f"Video summary endpoint failed: {e}")
77
- raise HTTPException(status_code=500, detail=str(e))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
 
79
  @router.get("/list", response_model=List[VideoSummaryResponse])
80
  async def list_video_summaries(
@@ -89,22 +118,30 @@ async def list_video_summaries(
89
  db_models.VideoSummary.user_id == current_user.id
90
  ).order_by(db_models.VideoSummary.created_at.desc()).all()
91
 
92
- return [
93
- {
94
- "id": s.id,
95
- "title": s.title,
96
- "s3_key": s.s3_key,
97
- "public_url": s.s3_url,
98
- "private_url": s3_service.get_presigned_url(s.s3_key),
99
- "parent_file_id": s.source_id,
100
- "parent_file_key": s.source.s3_key if s.source else None,
101
- "created_at": s.created_at
102
- }
103
- for s in summaries
104
- ]
105
  except Exception as e:
 
106
  raise HTTPException(status_code=500, detail=str(e))
107
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  @router.delete("/{video_id}")
109
  async def delete_video_summary(
110
  video_id: int,
 
1
  import logging
2
+ from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks
3
  from sqlalchemy.orm import Session
4
  from typing import List
5
+ from datetime import datetime
6
 
7
  from api.auth import get_current_user
8
  from models import db_models
9
  from models.schemas import VideoSummaryGenerateRequest, VideoSummaryResponse
10
+ from core.database import get_db, SessionLocal
11
+ from api.websocket_routes import manager
12
  from services.video_generator_service import video_generator_service
13
  from services.slides_video_service import slides_video_service
14
  from services.s3_service import s3_service
 
16
  router = APIRouter(prefix="/api/videos", tags=["video-generator"])
17
  logger = logging.getLogger(__name__)
18
 
19
+ async def run_video_generation(summary_id: int, request: VideoSummaryGenerateRequest, user_id: int):
20
+ """Background task for video summary generation"""
21
+ logger.info(f"Starting background video generation for ID: {summary_id}")
22
+ db = SessionLocal()
23
+ connection_id = f"user_{user_id}"
 
 
 
 
24
  try:
25
+ db_summary = db.query(db_models.VideoSummary).filter(db_models.VideoSummary.id == summary_id).first()
26
+ if not db_summary:
27
+ logger.error(f"Video summary {summary_id} not found in database")
28
+ return
 
 
 
 
29
 
30
  if request.use_slides_transformation:
31
+ logger.info(f"Task {summary_id}: Using slides transformation pipeline")
32
  result = await slides_video_service.generate_transformed_video_summary(
33
  file_key=request.file_key,
34
  language=request.language,
 
36
  custom_prompt=request.custom_prompt
37
  )
38
  else:
39
+ logger.info(f"Task {summary_id}: Using standard video pipeline")
40
  result = await video_generator_service.generate_video_summary(
41
  file_key=request.file_key,
42
  language=request.language,
43
  voice_name=request.voice_name
44
  )
45
 
46
+ db_summary.title = result["title"]
47
+ db_summary.s3_key = result["s3_key"]
48
+ db_summary.s3_url = result["s3_url"]
49
+ db_summary.status = "completed"
 
 
 
 
 
50
  db.commit()
51
+ logger.info(f"Task {summary_id}: Successfully completed")
52
 
53
+ # Notify via WebSocket
54
+ await manager.send_result(connection_id, {
55
+ "type": "video",
56
  "id": db_summary.id,
57
+ "status": "completed",
58
+ "title": db_summary.title
59
+ })
 
 
 
 
 
60
 
61
  except Exception as e:
62
+ logger.error(f"Task {summary_id}: Background video generation failed: {e}")
63
+ db_summary = db.query(db_models.VideoSummary).filter(db_models.VideoSummary.id == summary_id).first()
64
+ if db_summary:
65
+ db_summary.status = "failed"
66
+ db_summary.error_message = str(e)
67
+ db.commit()
68
+ await manager.send_error(connection_id, f"Video generation failed: {str(e)}")
69
+ finally:
70
+ db.close()
71
+
72
+ @router.post("/generate", response_model=VideoSummaryResponse)
73
+ async def generate_video_summary(
74
+ request: VideoSummaryGenerateRequest,
75
+ background_tasks: BackgroundTasks,
76
+ current_user: db_models.User = Depends(get_current_user),
77
+ db: Session = Depends(get_db)
78
+ ):
79
+ """
80
+ Initiates video summary generation in the background.
81
+ """
82
+ # Check source ownership
83
+ source = db.query(db_models.Source).filter(
84
+ db_models.Source.s3_key == request.file_key,
85
+ db_models.Source.user_id == current_user.id
86
+ ).first()
87
+
88
+ if not source:
89
+ raise HTTPException(status_code=403, detail="Not authorized to access this file")
90
+
91
+ # Create initial processing record
92
+ title = f"Video Summary {datetime.utcnow().strftime('%Y-%m-%d %H:%M')}"
93
+ db_summary = db_models.VideoSummary(
94
+ title=title,
95
+ user_id=current_user.id,
96
+ source_id=source.id,
97
+ status="processing"
98
+ )
99
+ db.add(db_summary)
100
+ db.commit()
101
+ db.refresh(db_summary)
102
+
103
+ # Offload to background task
104
+ background_tasks.add_task(run_video_generation, db_summary.id, request, current_user.id)
105
+
106
+ return db_summary
107
 
108
  @router.get("/list", response_model=List[VideoSummaryResponse])
109
  async def list_video_summaries(
 
118
  db_models.VideoSummary.user_id == current_user.id
119
  ).order_by(db_models.VideoSummary.created_at.desc()).all()
120
 
121
+ return [VideoSummaryResponse.model_validate(s) for s in summaries]
 
 
 
 
 
 
 
 
 
 
 
 
122
  except Exception as e:
123
+ logger.error(f"Failed to list video summaries: {e}")
124
  raise HTTPException(status_code=500, detail=str(e))
125
 
126
+ @router.get("/{video_id}", response_model=VideoSummaryResponse)
127
+ async def get_video_summary(
128
+ video_id: int,
129
+ current_user: db_models.User = Depends(get_current_user),
130
+ db: Session = Depends(get_db)
131
+ ):
132
+ """
133
+ Retrieves a specific video summary.
134
+ """
135
+ summary = db.query(db_models.VideoSummary).filter(
136
+ db_models.VideoSummary.id == video_id,
137
+ db_models.VideoSummary.user_id == current_user.id
138
+ ).first()
139
+
140
+ if not summary:
141
+ raise HTTPException(status_code=404, detail="Video summary not found")
142
+
143
+ return VideoSummaryResponse.model_validate(summary)
144
+
145
  @router.delete("/{video_id}")
146
  async def delete_video_summary(
147
  video_id: int,
api/websocket_routes.py CHANGED
@@ -15,6 +15,8 @@ from services.mindmap_service import mindmap_service
15
  from services.podcast_service import podcast_service
16
  from services.s3_service import s3_service
17
  from services.video_generator_service import video_generator_service
 
 
18
 
19
  router = APIRouter(prefix="/ws", tags=["websockets"])
20
  logger = logging.getLogger(__name__)
@@ -107,11 +109,14 @@ async def unified_generate_ws(
107
  if task_type == "podcast":
108
  await handle_podcast_task(connection_id, data, current_user, db)
109
 
110
- elif task_type == "flashcards":
111
- await handle_flashcards_task(connection_id, data, current_user, db)
112
 
113
- elif task_type == "quiz":
114
- await handle_quiz_task(connection_id, data, current_user, db)
 
 
 
115
 
116
  else:
117
  await manager.send_error(connection_id, f"Unsupported task type: {task_type}")
@@ -243,6 +248,19 @@ async def handle_flashcards_task(connection_id: str, data: Dict, current_user: d
243
  return
244
  source_id = source.id
245
 
 
 
 
 
 
 
 
 
 
 
 
 
 
246
  await manager.send_progress(connection_id, 10, "processing", "Generating flashcards...")
247
 
248
  cards_data = await flashcard_service.generate_flashcards(
@@ -258,22 +276,10 @@ async def handle_flashcards_task(connection_id: str, data: Dict, current_user: d
258
  )
259
 
260
  if not cards_data:
261
- await manager.send_error(connection_id, "AI returned an empty response")
262
- return
263
 
264
  await manager.send_progress(connection_id, 85, "processing", "Saving to database...")
265
 
266
- title = data.get("topic", f"Flashcards {len(cards_data)}")
267
- db_set = db_models.FlashcardSet(
268
- title=title,
269
- difficulty=data.get("difficulty", "medium"),
270
- user_id=current_user.id,
271
- source_id=source_id
272
- )
273
- db.add(db_set)
274
- db.commit()
275
- db.refresh(db_set)
276
-
277
  for item in cards_data:
278
  db_card = db_models.Flashcard(
279
  flashcard_set_id=db_set.id,
@@ -282,16 +288,21 @@ async def handle_flashcards_task(connection_id: str, data: Dict, current_user: d
282
  )
283
  db.add(db_card)
284
 
 
285
  db.commit()
286
- db.refresh(db_set)
287
 
288
  await manager.send_result(connection_id, {
289
  "id": db_set.id,
290
  "title": db_set.title,
291
- "flashcards_count": len(db_set.flashcards)
 
292
  })
293
  except Exception as e:
294
  logger.error(f"Flashcard task failed: {e}")
 
 
 
 
295
  await manager.send_error(connection_id, str(e))
296
 
297
  async def handle_quiz_task(connection_id: str, data: Dict, current_user: db_models.User, db: Session):
@@ -308,6 +319,19 @@ async def handle_quiz_task(connection_id: str, data: Dict, current_user: db_mode
308
  return
309
  source_id = source.id
310
 
 
 
 
 
 
 
 
 
 
 
 
 
 
311
  await manager.send_progress(connection_id, 10, "processing", "Generating quiz...")
312
 
313
  quizzes_data = await quiz_service.generate_quiz(
@@ -323,32 +347,197 @@ async def handle_quiz_task(connection_id: str, data: Dict, current_user: db_mode
323
  )
324
 
325
  if not quizzes_data:
326
- await manager.send_error(connection_id, "Failed to generate quiz")
327
- return
328
-
329
- db_set = db_models.QuizSet(
330
- title=data.get("topic", "Quiz"),
331
- difficulty=data.get("difficulty", "medium"),
332
- user_id=current_user.id,
333
- source_id=source_id
334
- )
335
- db.add(db_set)
336
- db.commit()
337
- db.refresh(db_set)
338
 
339
  for item in quizzes_data:
340
  db_question = db_models.QuizQuestion(
341
  quiz_set_id=db_set.id,
342
  question=item.get("question", ""),
343
  choices=item.get("choices", {}),
344
- answer=item.get("answer", "1"),
345
  explanation=item.get("explanation", "")
346
  )
347
  db.add(db_question)
348
 
 
349
  db.commit()
350
- await manager.send_result(connection_id, {"id": db_set.id, "title": db_set.title})
351
  except Exception as e:
352
  logger.error(f"Quiz task failed: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
353
  await manager.send_error(connection_id, str(e))
354
 
 
15
  from services.podcast_service import podcast_service
16
  from services.s3_service import s3_service
17
  from services.video_generator_service import video_generator_service
18
+ from services.slides_video_service import slides_video_service
19
+ from models.schemas import VideoSummaryGenerateRequest, ReportGenerateRequest, MindMapGenerateRequest
20
 
21
  router = APIRouter(prefix="/ws", tags=["websockets"])
22
  logger = logging.getLogger(__name__)
 
109
  if task_type == "podcast":
110
  await handle_podcast_task(connection_id, data, current_user, db)
111
 
112
+ elif task_type == "video":
113
+ await handle_video_task(connection_id, data, current_user, db)
114
 
115
+ elif task_type == "report":
116
+ await handle_report_task(connection_id, data, current_user, db)
117
+
118
+ elif task_type == "mindmap":
119
+ await handle_mindmap_task(connection_id, data, current_user, db)
120
 
121
  else:
122
  await manager.send_error(connection_id, f"Unsupported task type: {task_type}")
 
248
  return
249
  source_id = source.id
250
 
251
+ # Create initial processing record
252
+ title = data.get("topic") if data.get("topic") else f"Flashcards {datetime.utcnow().strftime('%Y-%m-%d %H:%M')}"
253
+ db_set = db_models.FlashcardSet(
254
+ title=title,
255
+ difficulty=data.get("difficulty", "medium"),
256
+ user_id=current_user.id,
257
+ source_id=source_id,
258
+ status="processing"
259
+ )
260
+ db.add(db_set)
261
+ db.commit()
262
+ db.refresh(db_set)
263
+
264
  await manager.send_progress(connection_id, 10, "processing", "Generating flashcards...")
265
 
266
  cards_data = await flashcard_service.generate_flashcards(
 
276
  )
277
 
278
  if not cards_data:
279
+ raise Exception("AI returned empty flashcard data")
 
280
 
281
  await manager.send_progress(connection_id, 85, "processing", "Saving to database...")
282
 
 
 
 
 
 
 
 
 
 
 
 
283
  for item in cards_data:
284
  db_card = db_models.Flashcard(
285
  flashcard_set_id=db_set.id,
 
288
  )
289
  db.add(db_card)
290
 
291
+ db_set.status = "completed"
292
  db.commit()
 
293
 
294
  await manager.send_result(connection_id, {
295
  "id": db_set.id,
296
  "title": db_set.title,
297
+ "flashcards_count": len(db_set.flashcards),
298
+ "status": "completed"
299
  })
300
  except Exception as e:
301
  logger.error(f"Flashcard task failed: {e}")
302
+ if 'db_set' in locals():
303
+ db_set.status = "failed"
304
+ db_set.error_message = str(e)
305
+ db.commit()
306
  await manager.send_error(connection_id, str(e))
307
 
308
  async def handle_quiz_task(connection_id: str, data: Dict, current_user: db_models.User, db: Session):
 
319
  return
320
  source_id = source.id
321
 
322
+ # Create initial processing record
323
+ title = data.get("topic") if data.get("topic") else f"Quiz {datetime.utcnow().strftime('%Y-%m-%d %H:%M')}"
324
+ db_set = db_models.QuizSet(
325
+ title=title,
326
+ difficulty=data.get("difficulty", "medium"),
327
+ user_id=current_user.id,
328
+ source_id=source_id,
329
+ status="processing"
330
+ )
331
+ db.add(db_set)
332
+ db.commit()
333
+ db.refresh(db_set)
334
+
335
  await manager.send_progress(connection_id, 10, "processing", "Generating quiz...")
336
 
337
  quizzes_data = await quiz_service.generate_quiz(
 
347
  )
348
 
349
  if not quizzes_data:
350
+ raise Exception("AI failed to generate quiz data")
 
 
 
 
 
 
 
 
 
 
 
351
 
352
  for item in quizzes_data:
353
  db_question = db_models.QuizQuestion(
354
  quiz_set_id=db_set.id,
355
  question=item.get("question", ""),
356
  choices=item.get("choices", {}),
357
+ answer=str(item.get("answer", "1")),
358
  explanation=item.get("explanation", "")
359
  )
360
  db.add(db_question)
361
 
362
+ db_set.status = "completed"
363
  db.commit()
364
+ await manager.send_result(connection_id, {"id": db_set.id, "title": db_set.title, "status": "completed"})
365
  except Exception as e:
366
  logger.error(f"Quiz task failed: {e}")
367
+ if 'db_set' in locals():
368
+ db_set.status = "failed"
369
+ db_set.error_message = str(e)
370
+ db.commit()
371
+ await manager.send_error(connection_id, str(e))
372
+
373
+ async def handle_video_task(connection_id: str, data: Dict, current_user: db_models.User, db: Session):
374
+ """Internal handler for video summary generation"""
375
+ try:
376
+ source = db.query(db_models.Source).filter(
377
+ db_models.Source.s3_key == data.get("file_key"),
378
+ db_models.Source.user_id == current_user.id
379
+ ).first()
380
+ if not source:
381
+ await manager.send_error(connection_id, "Not authorized to access this file")
382
+ return
383
+
384
+ db_summary = db_models.VideoSummary(
385
+ title=f"Video Summary {datetime.utcnow().strftime('%Y-%m-%d %H:%M')}",
386
+ user_id=current_user.id,
387
+ source_id=source.id,
388
+ status="processing"
389
+ )
390
+ db.add(db_summary)
391
+ db.commit()
392
+ db.refresh(db_summary)
393
+
394
+ await manager.send_progress(connection_id, 10, "processing", "Starting video generation...")
395
+
396
+ if data.get("use_slides_transformation", True):
397
+ result = await slides_video_service.generate_transformed_video_summary(
398
+ file_key=data["file_key"],
399
+ language=data.get("language", "Japanese"),
400
+ voice_name=data.get("voice_name", "Kore"),
401
+ custom_prompt=data.get("custom_prompt", "")
402
+ )
403
+ else:
404
+ result = await video_generator_service.generate_video_summary(
405
+ file_key=data["file_key"],
406
+ language=data.get("language", "Japanese"),
407
+ voice_name=data.get("voice_name", "Kore")
408
+ )
409
+
410
+ db_summary.title = result["title"]
411
+ db_summary.s3_key = result["s3_key"]
412
+ db_summary.s3_url = result["s3_url"]
413
+ db_summary.status = "completed"
414
+ db.commit()
415
+
416
+ await manager.send_result(connection_id, {
417
+ "type": "video",
418
+ "id": db_summary.id,
419
+ "status": "completed",
420
+ "title": db_summary.title,
421
+ "public_url": db_summary.s3_url
422
+ })
423
+ except Exception as e:
424
+ logger.error(f"Video task failed: {e}")
425
+ if 'db_summary' in locals():
426
+ db_summary.status = "failed"
427
+ db_summary.error_message = str(e)
428
+ db.commit()
429
+ await manager.send_error(connection_id, str(e))
430
+
431
+ async def handle_report_task(connection_id: str, data: Dict, current_user: db_models.User, db: Session):
432
+ """Internal handler for report generation"""
433
+ try:
434
+ source_id = None
435
+ if data.get("file_key"):
436
+ source = db.query(db_models.Source).filter(
437
+ db_models.Source.s3_key == data["file_key"],
438
+ db_models.Source.user_id == current_user.id
439
+ ).first()
440
+ if not source:
441
+ await manager.send_error(connection_id, "Not authorized to access this file")
442
+ return
443
+ source_id = source.id
444
+
445
+ db_report = db_models.Report(
446
+ title=f"Report {data.get('format_key', 'custom')} {datetime.utcnow().strftime('%Y-%m-%d %H:%M')}",
447
+ format_key=data.get("format_key", "custom"),
448
+ user_id=current_user.id,
449
+ source_id=source_id,
450
+ status="processing"
451
+ )
452
+ db.add(db_report)
453
+ db.commit()
454
+ db.refresh(db_report)
455
+
456
+ await manager.send_progress(connection_id, 15, "processing", "Generating report content...")
457
+
458
+ content = await report_service.generate_report(
459
+ file_key=data.get("file_key"),
460
+ text_input=data.get("text_input"),
461
+ format_key=data.get("format_key", "briefing_doc"),
462
+ custom_prompt=data.get("custom_prompt"),
463
+ language=data.get("language", "Japanese")
464
+ )
465
+
466
+ if not content:
467
+ raise Exception("AI failed to generate report content")
468
+
469
+ title = content.split('\n')[0].replace('#', '').strip()
470
+ if not title or len(title) < 3:
471
+ title = f"Report {data.get('format_key')}"
472
+
473
+ db_report.title = title
474
+ db_report.content = content
475
+ db_report.status = "completed"
476
+ db.commit()
477
+
478
+ await manager.send_result(connection_id, {
479
+ "type": "report",
480
+ "id": db_report.id,
481
+ "status": "completed",
482
+ "title": db_report.title
483
+ })
484
+ except Exception as e:
485
+ logger.error(f"Report task failed: {e}")
486
+ if 'db_report' in locals():
487
+ db_report.status = "failed"
488
+ db_report.error_message = str(e)
489
+ db.commit()
490
+ await manager.send_error(connection_id, str(e))
491
+
492
+ async def handle_mindmap_task(connection_id: str, data: Dict, current_user: db_models.User, db: Session):
493
+ """Internal handler for mindmap generation"""
494
+ try:
495
+ source_id = None
496
+ if data.get("file_key"):
497
+ source = db.query(db_models.Source).filter(
498
+ db_models.Source.s3_key == data["file_key"],
499
+ db_models.Source.user_id == current_user.id
500
+ ).first()
501
+ if not source:
502
+ await manager.send_error(connection_id, "Not authorized to access this file")
503
+ return
504
+ source_id = source.id
505
+
506
+ db_mindmap = db_models.MindMap(
507
+ title=data.get("title") if data.get("title") else f"Mind Map {datetime.utcnow().strftime('%Y-%m-%d %H:%M')}",
508
+ user_id=current_user.id,
509
+ source_id=source_id,
510
+ status="processing"
511
+ )
512
+ db.add(db_mindmap)
513
+ db.commit()
514
+ db.refresh(db_mindmap)
515
+
516
+ await manager.send_progress(connection_id, 20, "processing", "Generating mind map visualization...")
517
+
518
+ mermaid_code = await mindmap_service.generate_mindmap(
519
+ file_key=data.get("file_key"),
520
+ text_input=data.get("text_input")
521
+ )
522
+
523
+ if not mermaid_code:
524
+ raise Exception("AI failed to generate mind map code")
525
+
526
+ db_mindmap.mermaid_code = mermaid_code
527
+ db_mindmap.status = "completed"
528
+ db.commit()
529
+
530
+ await manager.send_result(connection_id, {
531
+ "type": "mindmap",
532
+ "id": db_mindmap.id,
533
+ "status": "completed",
534
+ "title": db_mindmap.title
535
+ })
536
+ except Exception as e:
537
+ logger.error(f"Mindmap task failed: {e}")
538
+ if 'db_mindmap' in locals():
539
+ db_mindmap.status = "failed"
540
+ db_mindmap.error_message = str(e)
541
+ db.commit()
542
  await manager.send_error(connection_id, str(e))
543
 
models/db_models.py CHANGED
@@ -63,6 +63,15 @@ class Podcast(Base):
63
  def parent_file_key(self):
64
  return self.source.s3_key if self.source else None
65
 
 
 
 
 
 
 
 
 
 
66
 
67
  class FlashcardSet(Base):
68
  __tablename__ = "flashcard_sets"
@@ -72,25 +81,37 @@ class FlashcardSet(Base):
72
  difficulty = Column(String(50))
73
  user_id = Column(Integer, ForeignKey("users.id"))
74
  source_id = Column(Integer, ForeignKey("sources.id"), nullable=True)
 
 
75
  created_at = Column(DateTime(timezone=True), server_default=func.now())
76
 
77
  owner = relationship("User", back_populates="flashcard_sets")
78
  source = relationship("Source", back_populates="flashcard_sets")
79
  flashcards = relationship("Flashcard", back_populates="flashcard_set", cascade="all, delete-orphan")
80
 
 
 
 
 
81
  class MindMap(Base):
82
  __tablename__ = "mind_maps"
83
 
84
  id = Column(Integer, primary_key=True, index=True)
85
  title = Column(Unicode(255))
86
- mermaid_code = Column(UnicodeText, nullable=False)
87
  user_id = Column(Integer, ForeignKey("users.id"))
88
  source_id = Column(Integer, ForeignKey("sources.id"), nullable=True)
 
 
89
  created_at = Column(DateTime(timezone=True), server_default=func.now())
90
 
91
  owner = relationship("User", back_populates="mind_maps")
92
  source = relationship("Source", back_populates="mind_maps")
93
 
 
 
 
 
94
  class QuizSet(Base):
95
  __tablename__ = "quiz_sets"
96
 
@@ -99,12 +120,18 @@ class QuizSet(Base):
99
  difficulty = Column(String(50))
100
  user_id = Column(Integer, ForeignKey("users.id"))
101
  source_id = Column(Integer, ForeignKey("sources.id"), nullable=True)
 
 
102
  created_at = Column(DateTime(timezone=True), server_default=func.now())
103
 
104
  owner = relationship("User", back_populates="quiz_sets")
105
  source = relationship("Source", back_populates="quiz_sets")
106
  questions = relationship("QuizQuestion", back_populates="quiz_set", cascade="all, delete-orphan")
107
 
 
 
 
 
108
  class QuizQuestion(Base):
109
  __tablename__ = "quiz_questions"
110
 
@@ -123,29 +150,50 @@ class Report(Base):
123
 
124
  id = Column(Integer, primary_key=True, index=True)
125
  title = Column(Unicode(255))
126
- content = Column(UnicodeText, nullable=False)
127
  format_key = Column(String(100))
128
  user_id = Column(Integer, ForeignKey("users.id"))
129
  source_id = Column(Integer, ForeignKey("sources.id"), nullable=True)
 
 
130
  created_at = Column(DateTime(timezone=True), server_default=func.now())
131
 
132
  owner = relationship("User", back_populates="reports")
133
  source = relationship("Source", back_populates="reports")
134
 
 
 
 
 
135
  class VideoSummary(Base):
136
  __tablename__ = "video_summaries"
137
 
138
  id = Column(Integer, primary_key=True, index=True)
139
  title = Column(Unicode(255))
140
- s3_key = Column(String(512), nullable=False)
141
- s3_url = Column(String(1024), nullable=False)
142
  user_id = Column(Integer, ForeignKey("users.id"))
143
  source_id = Column(Integer, ForeignKey("sources.id"), nullable=True)
 
 
144
  created_at = Column(DateTime(timezone=True), server_default=func.now())
145
 
146
  owner = relationship("User", back_populates="video_summaries")
147
  source = relationship("Source", back_populates="video_summaries")
148
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
  class Flashcard(Base):
150
  __tablename__ = "flashcards"
151
 
 
63
  def parent_file_key(self):
64
  return self.source.s3_key if self.source else None
65
 
66
+ @property
67
+ def public_url(self):
68
+ return self.s3_url
69
+
70
+ @property
71
+ def private_url(self):
72
+ from services.s3_service import s3_service
73
+ return s3_service.get_presigned_url(self.s3_key) if self.s3_key else None
74
+
75
 
76
  class FlashcardSet(Base):
77
  __tablename__ = "flashcard_sets"
 
81
  difficulty = Column(String(50))
82
  user_id = Column(Integer, ForeignKey("users.id"))
83
  source_id = Column(Integer, ForeignKey("sources.id"), nullable=True)
84
+ status = Column(String(50), default="processing")
85
+ error_message = Column(UnicodeText, nullable=True)
86
  created_at = Column(DateTime(timezone=True), server_default=func.now())
87
 
88
  owner = relationship("User", back_populates="flashcard_sets")
89
  source = relationship("Source", back_populates="flashcard_sets")
90
  flashcards = relationship("Flashcard", back_populates="flashcard_set", cascade="all, delete-orphan")
91
 
92
+ @property
93
+ def parent_file_key(self):
94
+ return self.source.s3_key if self.source else None
95
+
96
  class MindMap(Base):
97
  __tablename__ = "mind_maps"
98
 
99
  id = Column(Integer, primary_key=True, index=True)
100
  title = Column(Unicode(255))
101
+ mermaid_code = Column(UnicodeText, nullable=True)
102
  user_id = Column(Integer, ForeignKey("users.id"))
103
  source_id = Column(Integer, ForeignKey("sources.id"), nullable=True)
104
+ status = Column(String(50), default="processing")
105
+ error_message = Column(UnicodeText, nullable=True)
106
  created_at = Column(DateTime(timezone=True), server_default=func.now())
107
 
108
  owner = relationship("User", back_populates="mind_maps")
109
  source = relationship("Source", back_populates="mind_maps")
110
 
111
+ @property
112
+ def parent_file_key(self):
113
+ return self.source.s3_key if self.source else None
114
+
115
  class QuizSet(Base):
116
  __tablename__ = "quiz_sets"
117
 
 
120
  difficulty = Column(String(50))
121
  user_id = Column(Integer, ForeignKey("users.id"))
122
  source_id = Column(Integer, ForeignKey("sources.id"), nullable=True)
123
+ status = Column(String(50), default="processing")
124
+ error_message = Column(UnicodeText, nullable=True)
125
  created_at = Column(DateTime(timezone=True), server_default=func.now())
126
 
127
  owner = relationship("User", back_populates="quiz_sets")
128
  source = relationship("Source", back_populates="quiz_sets")
129
  questions = relationship("QuizQuestion", back_populates="quiz_set", cascade="all, delete-orphan")
130
 
131
+ @property
132
+ def parent_file_key(self):
133
+ return self.source.s3_key if self.source else None
134
+
135
  class QuizQuestion(Base):
136
  __tablename__ = "quiz_questions"
137
 
 
150
 
151
  id = Column(Integer, primary_key=True, index=True)
152
  title = Column(Unicode(255))
153
+ content = Column(UnicodeText, nullable=True)
154
  format_key = Column(String(100))
155
  user_id = Column(Integer, ForeignKey("users.id"))
156
  source_id = Column(Integer, ForeignKey("sources.id"), nullable=True)
157
+ status = Column(String(50), default="processing")
158
+ error_message = Column(UnicodeText, nullable=True)
159
  created_at = Column(DateTime(timezone=True), server_default=func.now())
160
 
161
  owner = relationship("User", back_populates="reports")
162
  source = relationship("Source", back_populates="reports")
163
 
164
+ @property
165
+ def parent_file_key(self):
166
+ return self.source.s3_key if self.source else None
167
+
168
  class VideoSummary(Base):
169
  __tablename__ = "video_summaries"
170
 
171
  id = Column(Integer, primary_key=True, index=True)
172
  title = Column(Unicode(255))
173
+ s3_key = Column(String(512), nullable=True)
174
+ s3_url = Column(String(1024), nullable=True)
175
  user_id = Column(Integer, ForeignKey("users.id"))
176
  source_id = Column(Integer, ForeignKey("sources.id"), nullable=True)
177
+ status = Column(String(50), default="processing")
178
+ error_message = Column(UnicodeText, nullable=True)
179
  created_at = Column(DateTime(timezone=True), server_default=func.now())
180
 
181
  owner = relationship("User", back_populates="video_summaries")
182
  source = relationship("Source", back_populates="video_summaries")
183
 
184
+ @property
185
+ def parent_file_key(self):
186
+ return self.source.s3_key if self.source else None
187
+
188
+ @property
189
+ def public_url(self):
190
+ return self.s3_url
191
+
192
+ @property
193
+ def private_url(self):
194
+ from services.s3_service import s3_service
195
+ return s3_service.get_presigned_url(self.s3_key) if self.s3_key else None
196
+
197
  class Flashcard(Base):
198
  __tablename__ = "flashcards"
199
 
models/schemas.py CHANGED
@@ -62,7 +62,7 @@ class PodcastResponse(BaseModel):
62
  s3_key: Optional[str]
63
  s3_url: Optional[str]
64
  script: Optional[str]
65
- status: str = "completed"
66
  error_message: Optional[str]
67
  parent_file_id: Optional[int] = None
68
  parent_file_key: Optional[str] = None
@@ -90,14 +90,19 @@ class FlashcardResponse(BaseModel):
90
  question: str
91
  answer: str
92
 
 
 
 
93
  class FlashcardSetResponse(BaseModel):
94
  id: int
95
  title: Optional[str]
96
  difficulty: str
 
 
97
  created_at: datetime
98
  parent_file_id: Optional[int] = None
99
  parent_file_key: Optional[str] = None
100
- flashcards: List[FlashcardResponse]
101
 
102
  class Config:
103
  from_attributes = True
@@ -111,11 +116,16 @@ class MindMapGenerateRequest(BaseModel):
111
  class MindMapResponse(BaseModel):
112
  id: Optional[int] = None
113
  title: str
114
- mermaid_code: str
 
 
115
  parent_file_id: Optional[int] = None
116
  parent_file_key: Optional[str] = None
117
  created_at: Optional[datetime] = None
118
- message: str
 
 
 
119
 
120
  # Quiz Schemas
121
  class QuizGenerateRequest(BaseModel):
@@ -134,14 +144,19 @@ class QuizQuestionResponse(BaseModel):
134
  answer: str
135
  explanation: Optional[str]
136
 
 
 
 
137
  class QuizSetResponse(BaseModel):
138
  id: int
139
  title: Optional[str]
140
  difficulty: str
 
 
141
  created_at: datetime
142
  parent_file_id: Optional[int] = None
143
  parent_file_key: Optional[str] = None
144
- questions: List[QuizQuestionResponse]
145
 
146
  class Config:
147
  from_attributes = True
@@ -165,8 +180,10 @@ class ReportGenerateRequest(BaseModel):
165
  class ReportResponse(BaseModel):
166
  id: int
167
  title: str
168
- content: str
169
  format_key: str
 
 
170
  parent_file_id: Optional[int] = None
171
  parent_file_key: Optional[str] = None
172
  created_at: datetime
@@ -185,9 +202,11 @@ class VideoSummaryGenerateRequest(BaseModel):
185
  class VideoSummaryResponse(BaseModel):
186
  id: int
187
  title: str
188
- s3_key: str
189
- public_url: str
190
  private_url: Optional[str] = None
 
 
191
  parent_file_id: Optional[int] = None
192
  parent_file_key: Optional[str] = None
193
  created_at: datetime
 
62
  s3_key: Optional[str]
63
  s3_url: Optional[str]
64
  script: Optional[str]
65
+ status: Optional[str] = "completed"
66
  error_message: Optional[str]
67
  parent_file_id: Optional[int] = None
68
  parent_file_key: Optional[str] = None
 
90
  question: str
91
  answer: str
92
 
93
+ class Config:
94
+ from_attributes = True
95
+
96
  class FlashcardSetResponse(BaseModel):
97
  id: int
98
  title: Optional[str]
99
  difficulty: str
100
+ status: Optional[str] = "completed"
101
+ error_message: Optional[str] = None
102
  created_at: datetime
103
  parent_file_id: Optional[int] = None
104
  parent_file_key: Optional[str] = None
105
+ flashcards: List[FlashcardResponse] = []
106
 
107
  class Config:
108
  from_attributes = True
 
116
  class MindMapResponse(BaseModel):
117
  id: Optional[int] = None
118
  title: str
119
+ mermaid_code: Optional[str] = None
120
+ status: Optional[str] = "completed"
121
+ error_message: Optional[str] = None
122
  parent_file_id: Optional[int] = None
123
  parent_file_key: Optional[str] = None
124
  created_at: Optional[datetime] = None
125
+ message: Optional[str] = None
126
+
127
+ class Config:
128
+ from_attributes = True
129
 
130
  # Quiz Schemas
131
  class QuizGenerateRequest(BaseModel):
 
144
  answer: str
145
  explanation: Optional[str]
146
 
147
+ class Config:
148
+ from_attributes = True
149
+
150
  class QuizSetResponse(BaseModel):
151
  id: int
152
  title: Optional[str]
153
  difficulty: str
154
+ status: Optional[str] = "completed"
155
+ error_message: Optional[str] = None
156
  created_at: datetime
157
  parent_file_id: Optional[int] = None
158
  parent_file_key: Optional[str] = None
159
+ questions: List[QuizQuestionResponse] = []
160
 
161
  class Config:
162
  from_attributes = True
 
180
  class ReportResponse(BaseModel):
181
  id: int
182
  title: str
183
+ content: Optional[str] = None
184
  format_key: str
185
+ status: Optional[str] = "completed"
186
+ error_message: Optional[str] = None
187
  parent_file_id: Optional[int] = None
188
  parent_file_key: Optional[str] = None
189
  created_at: datetime
 
202
  class VideoSummaryResponse(BaseModel):
203
  id: int
204
  title: str
205
+ s3_key: Optional[str] = None
206
+ public_url: Optional[str] = None
207
  private_url: Optional[str] = None
208
+ status: Optional[str] = "completed"
209
+ error_message: Optional[str] = None
210
  parent_file_id: Optional[int] = None
211
  parent_file_key: Optional[str] = None
212
  created_at: datetime
services/slides_video_service.py CHANGED
@@ -1,6 +1,6 @@
1
  import json
 
2
  from typing import Dict, List, Optional, Any, Tuple
3
- import logging
4
  import os
5
  import tempfile
6
  import time
@@ -10,7 +10,7 @@ import re
10
  import wave
11
  import yaml
12
  import requests
13
- import openai
14
  from google.cloud import storage
15
  from googleapiclient.discovery import build
16
  from googleapiclient.http import MediaIoBaseUpload
@@ -31,8 +31,6 @@ from core.prompts import (
31
  )
32
  from services.s3_service import s3_service
33
 
34
- logger = logging.getLogger(__name__)
35
-
36
  # Constants from temp project
37
  TEMPLATE_HINT: Dict[str, str] = {
38
  "cover": "COVER.MAIN",
@@ -49,10 +47,10 @@ TEMPLATE_HINT: Dict[str, str] = {
49
 
50
  class SlidesVideoService:
51
  def __init__(self):
52
- self.openai_client = openai.OpenAI(api_key=settings.OPENAI_API_KEY)
53
 
54
  # Match Temp project: Use API Key for Gemini TTS
55
- logger.info("Initializing Gemini Client with API Key for Slides (as in Temp project)")
56
  self.gemini_client = genai.Client(api_key=settings.GEMINI_API_KEY)
57
 
58
  self.scopes = [
@@ -61,18 +59,11 @@ class SlidesVideoService:
61
  ]
62
 
63
  def _get_sa_info(self) -> Optional[Dict[str, Any]]:
64
- """Parse GCP_SA_JSON - matches original Temp project logic exactly."""
65
  sa_json = os.environ.get("GCP_SA_JSON") or os.environ.get("GCS_SA_JSON")
66
- if not sa_json:
67
- return None
68
- # Just parse it directly like the original
69
  return json.loads(sa_json)
70
 
71
  def _get_google_creds(self):
72
- """
73
- Builds Google credentials from environment variables.
74
- Matches Temp project logic.
75
- """
76
  token_json = settings.GOOGLE_OAUTH_TOKEN_JSON
77
  if token_json:
78
  creds = Credentials.from_authorized_user_info(json.loads(token_json), self.scopes)
@@ -83,8 +74,7 @@ class SlidesVideoService:
83
  info = self._get_sa_info()
84
  if info:
85
  return service_account.Credentials.from_service_account_info(info, scopes=self.scopes)
86
-
87
- raise RuntimeError("Google API credentials not configured (GOOGLE_OAUTH_TOKEN_JSON or GCP_SA_JSON required)")
88
 
89
  def _get_clients(self):
90
  creds = self._get_google_creds()
@@ -93,12 +83,18 @@ class SlidesVideoService:
93
  return slides, drive
94
 
95
  async def extract_text_from_pdf(self, pdf_path: str) -> str:
96
- """Extract text from PDF using OpenAI."""
 
97
  with open(pdf_path, "rb") as f:
98
- openai_file = self.openai_client.files.create(file=f, purpose="assistants")
 
 
 
 
 
99
 
100
  prompt = get_pdf_text_extraction_prompt()
101
- response = self.openai_client.chat.completions.create(
102
  model="gpt-4o-mini",
103
  messages=[
104
  {
@@ -109,14 +105,14 @@ class SlidesVideoService:
109
  temperature=0
110
  )
111
  text = response.choices[0].message.content
112
- self.openai_client.files.delete(openai_file.id)
113
  return text
114
 
115
  async def generate_outline(self, source_text: str, language: str = "Japanese", custom_prompt: str = "") -> Dict[str, Any]:
116
- """Step 1: Generate Slide Outline (JSON) from text."""
 
117
  template_path = "core/templates/ja_slide_template.yaml" if language == "Japanese" else "core/templates/eng_slide_template.yaml"
118
  if not os.path.exists(template_path):
119
- # Fallback if I missed copying
120
  template_path = f"Temp/AI-Video-Summary-Generator/{'ja' if language == 'Japanese' else 'eng'}_slide_template.yaml"
121
 
122
  with open(template_path, "r", encoding="utf-8") as f:
@@ -124,7 +120,7 @@ class SlidesVideoService:
124
 
125
  prompt = get_outline_prompt(template_yaml, source_text, custom_prompt, language)
126
 
127
- response = self.openai_client.chat.completions.create(
128
  model="gpt-4o-mini",
129
  messages=[{"role": "user", "content": prompt}],
130
  temperature=0.2,
@@ -133,55 +129,39 @@ class SlidesVideoService:
133
  return json.loads(response.choices[0].message.content)
134
 
135
  async def create_slides_and_export_pdf(self, outline: Dict[str, Any], template_filename: str = "slide_template_v001.pptx") -> bytes:
136
- """Step 2 & 3: Create Google Slides and export to PDF."""
137
- slides_api, drive_api = self._get_clients()
138
-
139
- # 1. Get Template: Try local first, then GCS
140
- pptx_path = os.path.join("core", "templates", template_filename)
141
- if os.path.exists(pptx_path):
142
- with open(pptx_path, "rb") as f:
143
- pptx_bytes = f.read()
144
- else:
145
- logger.info(f"Template {template_filename} not found locally, trying GCS...")
146
- try:
147
- pptx_bytes = self._download_template_from_gcs(template_filename)
148
- except Exception as e:
149
- raise FileNotFoundError(f"Template {template_filename} not found locally or on GCS: {e}")
150
-
151
- # 2. Upload and convert
152
- media = MediaIoBaseUpload(io.BytesIO(pptx_bytes), mimetype="application/vnd.openxmlformats-officedocument.presentationml.presentation")
153
- body = {
154
- "name": f"Generated Video Source {int(time.time())}",
155
- "mimeType": "application/vnd.google-apps.presentation",
156
- }
157
-
158
- folder_id = os.environ.get("DRIVE_FOLDER_ID")
159
- if folder_id:
160
- body["parents"] = [folder_id]
161
 
162
- created = drive_api.files().create(body=body, media_body=media, supportsAllDrives=True, fields="id").execute()
163
- pres_id = created["id"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
 
165
- try:
166
- # 3. Build slides from outline
167
- self._build_from_outline(slides_api, pres_id, outline)
168
-
169
- # 4. Export to PDF
170
- pdf_bytes = drive_api.files().export(
171
- fileId=pres_id,
172
- mimeType="application/pdf",
173
- ).execute()
174
-
175
- return pdf_bytes
176
- finally:
177
- # Cleanup temp presentation
178
  try:
179
- drive_api.files().delete(fileId=pres_id).execute()
180
- except:
181
- pass
 
 
 
 
182
 
183
  def _build_from_outline(self, slides, pres_id, outline):
184
- """Port of build_from_outline from temp project."""
185
  items = outline.get("slides", [])
186
  initial = slides.presentations().get(presentationId=pres_id).execute()
187
  original_page_ids = [p["objectId"] for p in initial.get("slides", [])]
@@ -189,33 +169,23 @@ class SlidesVideoService:
189
  for item in items:
190
  tpl = item.get("template", "")
191
  fields = item.get("fields", {})
192
-
193
- # Find base page
194
  rep_key = TEMPLATE_HINT.get(tpl) or next(iter(fields.keys()), "")
195
  base_page = self._find_page(slides, pres_id, rep_key)
196
  if not base_page: continue
197
 
198
- # Duplicate
199
  resp = slides.presentations().batchUpdate(
200
  presentationId=pres_id,
201
  body={"requests": [{"duplicateObject": {"objectId": base_page}}]}
202
  ).execute()
203
  new_page = resp["replies"][0]["duplicateObject"]["objectId"]
204
 
205
- # Move to end
206
  pres_detail = slides.presentations().get(presentationId=pres_id).execute()
207
  insertion_index = max(0, len(pres_detail.get("slides", [])) - 1)
208
  slides.presentations().batchUpdate(
209
  presentationId=pres_id,
210
- body={"requests": [{
211
- "updateSlidesPosition": {
212
- "slideObjectIds": [new_page],
213
- "insertionIndex": insertion_index
214
- }
215
- }]}
216
  ).execute()
217
 
218
- # Replace text
219
  reqs = []
220
  for k, v in fields.items():
221
  reqs.append({
@@ -228,39 +198,25 @@ class SlidesVideoService:
228
  if reqs:
229
  slides.presentations().batchUpdate(presentationId=pres_id, body={"requests": reqs}).execute()
230
 
231
- # Cleanup unused placeholders {{...}} on this slide (Matches original implementation)
232
- try:
233
- self._cleanup_placeholders(slides, pres_id, new_page, fields)
234
- except Exception as e:
235
- logger.warning(f"Placeholder cleanup failed for slide {new_page}: {e}")
236
 
237
- # Delete originals
238
  if original_page_ids:
239
  reqs = [{"deleteObject": {"objectId": pid}} for pid in original_page_ids]
240
  slides.presentations().batchUpdate(presentationId=pres_id, body={"requests": reqs}).execute()
241
 
242
  def _cleanup_placeholders(self, slides, pres_id, page_id, fields):
243
- """Finds all remaining {{TAGS}} and replaces them with empty strings."""
244
  pres = slides.presentations().get(presentationId=pres_id).execute()
245
  slide = next(s for s in pres.get("slides", []) if s.get("objectId") == page_id)
246
-
247
  found_tags = set()
248
  for el in slide.get("pageElements", []):
249
  text = el.get("shape", {}).get("text", {})
250
  for te in text.get("textElements", []):
251
  content = te.get("textRun", {}).get("content", "")
252
- for m in re.findall(r"\{\{([A-Z0-9_.-]+)\}\}", content):
253
- found_tags.add(m)
254
-
255
  unused = [t for t in found_tags if t not in fields]
256
  if unused:
257
- reqs = [{
258
- "replaceAllText": {
259
- "containsText": {"text": f"{{{{{t}}}}}", "matchCase": True},
260
- "replaceText": "",
261
- "pageObjectIds": [page_id]
262
- }
263
- } for t in unused]
264
  slides.presentations().batchUpdate(presentationId=pres_id, body={"requests": reqs}).execute()
265
 
266
  def _find_page(self, slides, pres_id, placeholder_key):
@@ -270,59 +226,50 @@ class SlidesVideoService:
270
  for el in page.get("pageElements", []):
271
  text = el.get("shape", {}).get("text", {})
272
  for te in text.get("textElements", []):
273
- if needle in te.get("textRun", {}).get("content", ""):
274
- return page["objectId"]
275
  return None
276
 
277
  def _download_template_from_gcs(self, filename: str) -> bytes:
278
- """Download template from GCS bucket (mimics Temp project logic)."""
279
  bucket_name = settings.GCS_BUCKET
280
- if not bucket_name:
281
- raise RuntimeError("GCS_BUCKET environment variable is missing")
282
-
283
- # Path in bucket from Temp project: templates/filename
284
  object_name = f"templates/{filename}"
285
-
286
- # Use SA if available, else default
287
  info = self._get_sa_info()
288
  if info:
289
  creds = service_account.Credentials.from_service_account_info(info)
290
  client = storage.Client(project=info.get("project_id"), credentials=creds)
291
  else:
292
  client = storage.Client()
293
-
294
  bucket = client.bucket(bucket_name)
295
  blob = bucket.blob(object_name)
296
  return blob.download_as_bytes()
297
 
298
- async def generate_video_from_pdf_bytes(
299
- self,
300
- pdf_bytes: bytes,
301
- language: str = "Japanese",
302
- voice_name: str = "Kore"
303
- ) -> Dict[str, Any]:
304
- """Step 4, 5, 6: PDF bytes -> Video Pipeline."""
305
  temp_dir = tempfile.mkdtemp(prefix="video_final_")
306
  try:
307
  pdf_path = os.path.join(temp_dir, "source.pdf")
308
- with open(pdf_path, "wb") as f:
309
- f.write(pdf_bytes)
310
 
311
  # 1. Images
312
- images = convert_from_path(pdf_path, dpi=200)
 
313
  total_pages = len(images)
314
  image_paths = []
315
  for i, img in enumerate(images, start=1):
316
- p = os.path.join(temp_dir, f"p_{i:02d}.png")
317
- img.save(p, "PNG")
318
- image_paths.append(p)
319
 
320
  # 2. Narration Script
 
321
  with open(pdf_path, "rb") as f:
322
- openai_file = self.openai_client.files.create(file=f, purpose="assistants")
 
 
 
 
 
323
 
324
  prompt = get_video_script_prompt(language, total_pages)
325
- resp = self.openai_client.chat.completions.create(
326
  model="gpt-4o-mini",
327
  messages=[{"role": "user", "content": [{"type": "text", "text": prompt}, {"type": "file", "file": {"file_id": openai_file.id}}]}],
328
  response_format={"type": "json_object"},
@@ -330,133 +277,76 @@ class SlidesVideoService:
330
  )
331
  script_data = json.loads(resp.choices[0].message.content)
332
  scripts = script_data.get("scripts", [])
333
- self.openai_client.files.delete(openai_file.id)
334
 
335
- # 3. Audio & Video assembly (similar to existing logic but more refined)
336
  page_clips = []
337
  target_size = (1920, 1080)
338
-
339
  for i, img_path in enumerate(image_paths):
340
- # Skip last slide narration if it's the logo slide (standard logic in temp project)
341
  if i < len(scripts) and i < len(image_paths) - 1:
342
  text = scripts[i].get("script_text", "")
343
  audio_path = os.path.join(temp_dir, f"a_{i}.wav")
344
-
345
- # TTS with fallback
346
  try:
347
- model_name = "gemini-2.5-flash-preview-tts"
348
- logger.info(f"Generating audio for slide {i} using {model_name}...")
349
- tts_resp = self.gemini_client.models.generate_content(
350
- model=model_name,
351
  contents=text,
352
  config=types.GenerateContentConfig(
353
  response_modalities=["AUDIO"],
354
- speech_config=types.SpeechConfig(
355
- voice_config=types.VoiceConfig(
356
- prebuilt_voice_config=types.PrebuiltVoiceConfig(
357
- voice_name=voice_name
358
- )
359
- )
360
- )
361
  )
362
  )
363
- except Exception as tts_err:
364
- logger.warning(f"Failed with {model_name}, trying fallback gemini-1.5-flash: {tts_err}")
365
- model_name = "gemini-1.5-flash"
366
- tts_resp = self.gemini_client.models.generate_content(
367
- model=model_name,
368
  contents=text,
369
  config=types.GenerateContentConfig(
370
  response_modalities=["AUDIO"],
371
- speech_config=types.SpeechConfig(
372
- voice_config=types.VoiceConfig(
373
- prebuilt_voice_config=types.PrebuiltVoiceConfig(
374
- voice_name=voice_name
375
- )
376
- )
377
- )
378
  )
379
  )
380
  audio_data = tts_resp.candidates[0].content.parts[0].inline_data.data
381
- with wave.open(audio_path, "wb") as wf:
382
- wf.setnchannels(1); wf.setsampwidth(2); wf.setframerate(24000); wf.writeframes(audio_data)
383
 
384
  aud_clip = AudioFileClip(audio_path)
385
- duration = aud_clip.duration
386
- img_clip = ImageClip(self._prepare_img(img_path, target_size, temp_dir, i), duration=duration)
387
  page_clips.append(img_clip.with_audio(aud_clip))
388
- time.sleep(2)
389
  else:
390
- # Silent 3s for last slide or missing scripts
391
- img_clip = ImageClip(self._prepare_img(img_path, target_size, temp_dir, i), duration=3.0)
392
- page_clips.append(img_clip)
393
 
394
  final_path = os.path.join(temp_dir, "output.mp4")
395
  final_clip = concatenate_videoclips(page_clips, method="compose")
396
- final_clip.write_videofile(final_path, fps=24, codec="libx264", audio_codec="aac", logger=None)
 
397
 
398
- # Cleanup clips
399
  for c in page_clips: c.close()
400
  final_clip.close()
401
 
402
- # Upload to S3
403
- ts = int(time.time())
404
- s3_key = f"users/video_summaries/{ts}_summary.mp4"
405
- s3_service.s3_client.upload_file(final_path, settings.AWS_S3_BUCKET, s3_key)
406
- s3_url = f"https://{settings.AWS_S3_BUCKET}.s3.{settings.AWS_REGION}.amazonaws.com/{s3_key}"
407
-
408
- return {"s3_key": s3_key, "s3_url": s3_url}
409
-
410
  finally:
411
  shutil.rmtree(temp_dir, ignore_errors=True)
412
 
413
  def _prepare_img(self, path, size, temp_dir, idx):
414
- img = Image.open(path)
415
- img.thumbnail(size, Image.Resampling.LANCZOS)
416
- new_img = Image.new("RGB", size, (0, 0, 0))
417
- new_img.paste(img, ((size[0] - img.size[0]) // 2, (size[1] - img.size[1]) // 2))
418
- res_path = os.path.join(temp_dir, f"ready_{idx}.png")
419
- new_img.save(res_path)
420
- return res_path
421
-
422
- async def generate_transformed_video_summary(
423
- self,
424
- file_key: str,
425
- language: str = "Japanese",
426
- voice_name: str = "Kore",
427
- custom_prompt: str = ""
428
- ) -> Dict[str, Any]:
429
- """
430
- The Full Transformation Workflow: PDF -> Text -> Outline -> Slides -> PDF -> Video.
431
- """
432
  temp_dir = tempfile.mkdtemp(prefix="trans_video_")
433
  try:
434
- # 1. Download original PDF
435
  pdf_path = os.path.join(temp_dir, "input.pdf")
436
- s3_service.s3_client.download_file(settings.AWS_S3_BUCKET, file_key, pdf_path)
437
-
438
- # 2. Extract Text
439
- logger.info("Extracting text from PDF...")
440
  source_text = await self.extract_text_from_pdf(pdf_path)
441
-
442
- # 3. Generate Outline
443
- logger.info("Generating slide outline...")
444
  outline = await self.generate_outline(source_text, language, custom_prompt)
445
-
446
- # 4. Create Slides and Export back to PDF (The Transformation)
447
- logger.info("Building Google Slides and exporting...")
448
  transformed_pdf_bytes = await self.create_slides_and_export_pdf(outline)
449
-
450
- # 5. Generate Video from the Transformed PDF
451
- logger.info("Generating video from transformed slides...")
452
  result = await self.generate_video_from_pdf_bytes(transformed_pdf_bytes, language, voice_name)
453
-
454
- return {
455
- "title": f"Transformed Summary - {os.path.basename(file_key)}",
456
- "s3_key": result["s3_key"],
457
- "s3_url": result["s3_url"]
458
- }
459
-
460
  finally:
461
  shutil.rmtree(temp_dir, ignore_errors=True)
462
 
 
1
  import json
2
+ import asyncio
3
  from typing import Dict, List, Optional, Any, Tuple
 
4
  import os
5
  import tempfile
6
  import time
 
10
  import wave
11
  import yaml
12
  import requests
13
+ from openai import AsyncOpenAI
14
  from google.cloud import storage
15
  from googleapiclient.discovery import build
16
  from googleapiclient.http import MediaIoBaseUpload
 
31
  )
32
  from services.s3_service import s3_service
33
 
 
 
34
  # Constants from temp project
35
  TEMPLATE_HINT: Dict[str, str] = {
36
  "cover": "COVER.MAIN",
 
47
 
48
  class SlidesVideoService:
49
  def __init__(self):
50
+ self.openai_client = AsyncOpenAI(api_key=settings.OPENAI_API_KEY)
51
 
52
  # Match Temp project: Use API Key for Gemini TTS
53
+ print("[SlidesVideo] Initializing Gemini Client with API Key")
54
  self.gemini_client = genai.Client(api_key=settings.GEMINI_API_KEY)
55
 
56
  self.scopes = [
 
59
  ]
60
 
61
  def _get_sa_info(self) -> Optional[Dict[str, Any]]:
 
62
  sa_json = os.environ.get("GCP_SA_JSON") or os.environ.get("GCS_SA_JSON")
63
+ if not sa_json: return None
 
 
64
  return json.loads(sa_json)
65
 
66
  def _get_google_creds(self):
 
 
 
 
67
  token_json = settings.GOOGLE_OAUTH_TOKEN_JSON
68
  if token_json:
69
  creds = Credentials.from_authorized_user_info(json.loads(token_json), self.scopes)
 
74
  info = self._get_sa_info()
75
  if info:
76
  return service_account.Credentials.from_service_account_info(info, scopes=self.scopes)
77
+ raise RuntimeError("Google API credentials not configured")
 
78
 
79
  def _get_clients(self):
80
  creds = self._get_google_creds()
 
83
  return slides, drive
84
 
85
  async def extract_text_from_pdf(self, pdf_path: str) -> str:
86
+ """Extract text from PDF using Native Async OpenAI."""
87
+ print("[SlidesVideo] Extracting text from PDF via OpenAI...")
88
  with open(pdf_path, "rb") as f:
89
+ content = f.read()
90
+
91
+ openai_file = await self.openai_client.files.create(
92
+ file=("source.pdf", content),
93
+ purpose="assistants"
94
+ )
95
 
96
  prompt = get_pdf_text_extraction_prompt()
97
+ response = await self.openai_client.chat.completions.create(
98
  model="gpt-4o-mini",
99
  messages=[
100
  {
 
105
  temperature=0
106
  )
107
  text = response.choices[0].message.content
108
+ await self.openai_client.files.delete(openai_file.id)
109
  return text
110
 
111
  async def generate_outline(self, source_text: str, language: str = "Japanese", custom_prompt: str = "") -> Dict[str, Any]:
112
+ """Step 1: Generate Slide Outline (JSON) from text via Native Async OpenAI."""
113
+ print("[SlidesVideo] Generating slide outline...")
114
  template_path = "core/templates/ja_slide_template.yaml" if language == "Japanese" else "core/templates/eng_slide_template.yaml"
115
  if not os.path.exists(template_path):
 
116
  template_path = f"Temp/AI-Video-Summary-Generator/{'ja' if language == 'Japanese' else 'eng'}_slide_template.yaml"
117
 
118
  with open(template_path, "r", encoding="utf-8") as f:
 
120
 
121
  prompt = get_outline_prompt(template_yaml, source_text, custom_prompt, language)
122
 
123
+ response = await self.openai_client.chat.completions.create(
124
  model="gpt-4o-mini",
125
  messages=[{"role": "user", "content": prompt}],
126
  temperature=0.2,
 
129
  return json.loads(response.choices[0].message.content)
130
 
131
  async def create_slides_and_export_pdf(self, outline: Dict[str, Any], template_filename: str = "slide_template_v001.pptx") -> bytes:
132
+ """Step 2 & 3: Create Google Slides (Sync inside to_thread)."""
133
+ def _execute():
134
+ print("[SlidesVideo] Interacting with Google Slides API...")
135
+ slides_api, drive_api = self._get_clients()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
 
137
+ pptx_path = os.path.join("core", "templates", template_filename)
138
+ if os.path.exists(pptx_path):
139
+ with open(pptx_path, "rb") as f:
140
+ pptx_bytes = f.read()
141
+ else:
142
+ pptx_bytes = self._download_template_from_gcs(template_filename)
143
+
144
+ media = MediaIoBaseUpload(io.BytesIO(pptx_bytes), mimetype="application/vnd.openxmlformats-officedocument.presentationml.presentation")
145
+ body = {
146
+ "name": f"Generated Video Source {int(time.time())}",
147
+ "mimeType": "application/vnd.google-apps.presentation",
148
+ }
149
+ folder_id = os.environ.get("DRIVE_FOLDER_ID")
150
+ if folder_id: body["parents"] = [folder_id]
151
+
152
+ created = drive_api.files().create(body=body, media_body=media, supportsAllDrives=True, fields="id").execute()
153
+ pres_id = created["id"]
154
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
  try:
156
+ self._build_from_outline(slides_api, pres_id, outline)
157
+ pdf_bytes = drive_api.files().export(fileId=pres_id, mimeType="application/pdf").execute()
158
+ return pdf_bytes
159
+ finally:
160
+ try: drive_api.files().delete(fileId=pres_id).execute()
161
+ except: pass
162
+ return await asyncio.to_thread(_execute)
163
 
164
  def _build_from_outline(self, slides, pres_id, outline):
 
165
  items = outline.get("slides", [])
166
  initial = slides.presentations().get(presentationId=pres_id).execute()
167
  original_page_ids = [p["objectId"] for p in initial.get("slides", [])]
 
169
  for item in items:
170
  tpl = item.get("template", "")
171
  fields = item.get("fields", {})
 
 
172
  rep_key = TEMPLATE_HINT.get(tpl) or next(iter(fields.keys()), "")
173
  base_page = self._find_page(slides, pres_id, rep_key)
174
  if not base_page: continue
175
 
 
176
  resp = slides.presentations().batchUpdate(
177
  presentationId=pres_id,
178
  body={"requests": [{"duplicateObject": {"objectId": base_page}}]}
179
  ).execute()
180
  new_page = resp["replies"][0]["duplicateObject"]["objectId"]
181
 
 
182
  pres_detail = slides.presentations().get(presentationId=pres_id).execute()
183
  insertion_index = max(0, len(pres_detail.get("slides", [])) - 1)
184
  slides.presentations().batchUpdate(
185
  presentationId=pres_id,
186
+ body={"requests": [{"updateSlidesPosition": {"slideObjectIds": [new_page], "insertionIndex": insertion_index}}]}
 
 
 
 
 
187
  ).execute()
188
 
 
189
  reqs = []
190
  for k, v in fields.items():
191
  reqs.append({
 
198
  if reqs:
199
  slides.presentations().batchUpdate(presentationId=pres_id, body={"requests": reqs}).execute()
200
 
201
+ try: self._cleanup_placeholders(slides, pres_id, new_page, fields)
202
+ except: pass
 
 
 
203
 
 
204
  if original_page_ids:
205
  reqs = [{"deleteObject": {"objectId": pid}} for pid in original_page_ids]
206
  slides.presentations().batchUpdate(presentationId=pres_id, body={"requests": reqs}).execute()
207
 
208
  def _cleanup_placeholders(self, slides, pres_id, page_id, fields):
 
209
  pres = slides.presentations().get(presentationId=pres_id).execute()
210
  slide = next(s for s in pres.get("slides", []) if s.get("objectId") == page_id)
 
211
  found_tags = set()
212
  for el in slide.get("pageElements", []):
213
  text = el.get("shape", {}).get("text", {})
214
  for te in text.get("textElements", []):
215
  content = te.get("textRun", {}).get("content", "")
216
+ for m in re.findall(r"\{\{([A-Z0-9_.-]+)\}\}", content): found_tags.add(m)
 
 
217
  unused = [t for t in found_tags if t not in fields]
218
  if unused:
219
+ reqs = [{"replaceAllText": {"containsText": {"text": f"{{{{{t}}}}}", "matchCase": True}, "replaceText": "", "pageObjectIds": [page_id]}} for t in unused]
 
 
 
 
 
 
220
  slides.presentations().batchUpdate(presentationId=pres_id, body={"requests": reqs}).execute()
221
 
222
  def _find_page(self, slides, pres_id, placeholder_key):
 
226
  for el in page.get("pageElements", []):
227
  text = el.get("shape", {}).get("text", {})
228
  for te in text.get("textElements", []):
229
+ if needle in te.get("textRun", {}).get("content", ""): return page["objectId"]
 
230
  return None
231
 
232
  def _download_template_from_gcs(self, filename: str) -> bytes:
 
233
  bucket_name = settings.GCS_BUCKET
234
+ if not bucket_name: raise RuntimeError("GCS_BUCKET missing")
 
 
 
235
  object_name = f"templates/{filename}"
 
 
236
  info = self._get_sa_info()
237
  if info:
238
  creds = service_account.Credentials.from_service_account_info(info)
239
  client = storage.Client(project=info.get("project_id"), credentials=creds)
240
  else:
241
  client = storage.Client()
 
242
  bucket = client.bucket(bucket_name)
243
  blob = bucket.blob(object_name)
244
  return blob.download_as_bytes()
245
 
246
+ async def generate_video_from_pdf_bytes(self, pdf_bytes: bytes, language: str = "Japanese", voice_name: str = "Kore") -> Dict[str, Any]:
247
+ """Step 4, 5, 6: PDF bytes -> Video Pipeline using Async Gemini and Threads."""
 
 
 
 
 
248
  temp_dir = tempfile.mkdtemp(prefix="video_final_")
249
  try:
250
  pdf_path = os.path.join(temp_dir, "source.pdf")
251
+ with open(pdf_path, "wb") as f: f.write(pdf_bytes)
 
252
 
253
  # 1. Images
254
+ print("[SlidesVideo] Converting PDF to images...")
255
+ images = await asyncio.to_thread(convert_from_path, pdf_path, dpi=200)
256
  total_pages = len(images)
257
  image_paths = []
258
  for i, img in enumerate(images, start=1):
259
+ p = os.path.join(temp_dir, f"p_{i:02d}.png"); img.save(p, "PNG"); image_paths.append(p)
 
 
260
 
261
  # 2. Narration Script
262
+ print(f"[SlidesVideo] Generating script via OpenAI for {total_pages} pages...")
263
  with open(pdf_path, "rb") as f:
264
+ content = f.read()
265
+
266
+ openai_file = await self.openai_client.files.create(
267
+ file=("source.pdf", content),
268
+ purpose="assistants"
269
+ )
270
 
271
  prompt = get_video_script_prompt(language, total_pages)
272
+ resp = await self.openai_client.chat.completions.create(
273
  model="gpt-4o-mini",
274
  messages=[{"role": "user", "content": [{"type": "text", "text": prompt}, {"type": "file", "file": {"file_id": openai_file.id}}]}],
275
  response_format={"type": "json_object"},
 
277
  )
278
  script_data = json.loads(resp.choices[0].message.content)
279
  scripts = script_data.get("scripts", [])
280
+ await self.openai_client.files.delete(openai_file.id)
281
 
282
+ # 3. Audio & Video Assembly
283
  page_clips = []
284
  target_size = (1920, 1080)
 
285
  for i, img_path in enumerate(image_paths):
 
286
  if i < len(scripts) and i < len(image_paths) - 1:
287
  text = scripts[i].get("script_text", "")
288
  audio_path = os.path.join(temp_dir, f"a_{i}.wav")
289
+ print(f"[SlidesVideo] Generating TTS for slide {i}...")
 
290
  try:
291
+ tts_resp = await self.gemini_client.aio.models.generate_content(
292
+ model="gemini-2.5-flash-preview-tts",
 
 
293
  contents=text,
294
  config=types.GenerateContentConfig(
295
  response_modalities=["AUDIO"],
296
+ speech_config=types.SpeechConfig(voice_config=types.VoiceConfig(prebuilt_voice_config=types.PrebuiltVoiceConfig(voice_name=voice_name)))
 
 
 
 
 
 
297
  )
298
  )
299
+ except Exception as e:
300
+ print(f"[SlidesVideo] TTS fallback used for slide {i}: {e}")
301
+ tts_resp = await self.gemini_client.aio.models.generate_content(
302
+ model="gemini-1.5-flash",
 
303
  contents=text,
304
  config=types.GenerateContentConfig(
305
  response_modalities=["AUDIO"],
306
+ speech_config=types.SpeechConfig(voice_config=types.VoiceConfig(prebuilt_voice_config=types.PrebuiltVoiceConfig(voice_name=voice_name)))
 
 
 
 
 
 
307
  )
308
  )
309
  audio_data = tts_resp.candidates[0].content.parts[0].inline_data.data
310
+ with wave.open(audio_path, "wb") as wf: wf.setnchannels(1); wf.setsampwidth(2); wf.setframerate(24000); wf.writeframes(audio_data)
 
311
 
312
  aud_clip = AudioFileClip(audio_path)
313
+ img_clip = ImageClip(self._prepare_img(img_path, target_size, temp_dir, i), duration=aud_clip.duration)
 
314
  page_clips.append(img_clip.with_audio(aud_clip))
315
+ await asyncio.sleep(2)
316
  else:
317
+ page_clips.append(ImageClip(self._prepare_img(img_path, target_size, temp_dir, i), duration=3.0))
 
 
318
 
319
  final_path = os.path.join(temp_dir, "output.mp4")
320
  final_clip = concatenate_videoclips(page_clips, method="compose")
321
+ print("[SlidesVideo] Rendering final transformed video in background thread...")
322
+ await asyncio.to_thread(final_clip.write_videofile, final_path, fps=24, codec="libx264", audio_codec="aac", logger=None)
323
 
 
324
  for c in page_clips: c.close()
325
  final_clip.close()
326
 
327
+ ts = int(time.time()); s3_key = f"users/video_summaries/{ts}_summary.mp4"
328
+ print(f"[SlidesVideo] Uploading final transformed video to S3: {s3_key}")
329
+ await asyncio.to_thread(s3_service.s3_client.upload_file, final_path, settings.AWS_S3_BUCKET, s3_key)
330
+ return {"s3_key": s3_key, "s3_url": f"https://{settings.AWS_S3_BUCKET}.s3.{settings.AWS_REGION}.amazonaws.com/{s3_key}"}
 
 
 
 
331
  finally:
332
  shutil.rmtree(temp_dir, ignore_errors=True)
333
 
334
  def _prepare_img(self, path, size, temp_dir, idx):
335
+ img = Image.open(path); img.thumbnail(size, Image.Resampling.LANCZOS)
336
+ new_img = Image.new("RGB", size, (0, 0, 0)); new_img.paste(img, ((size[0] - img.size[0]) // 2, (size[1] - img.size[1]) // 2))
337
+ res_path = os.path.join(temp_dir, f"ready_{idx}.png"); new_img.save(res_path); return res_path
338
+
339
+ async def generate_transformed_video_summary(self, file_key: str, language: str = "Japanese", voice_name: str = "Kore", custom_prompt: str = "") -> Dict[str, Any]:
 
 
 
 
 
 
 
 
 
 
 
 
 
340
  temp_dir = tempfile.mkdtemp(prefix="trans_video_")
341
  try:
 
342
  pdf_path = os.path.join(temp_dir, "input.pdf")
343
+ print(f"[SlidesVideo] Downloading source PDF: {file_key}")
344
+ await asyncio.to_thread(s3_service.s3_client.download_file, settings.AWS_S3_BUCKET, file_key, pdf_path)
 
 
345
  source_text = await self.extract_text_from_pdf(pdf_path)
 
 
 
346
  outline = await self.generate_outline(source_text, language, custom_prompt)
 
 
 
347
  transformed_pdf_bytes = await self.create_slides_and_export_pdf(outline)
 
 
 
348
  result = await self.generate_video_from_pdf_bytes(transformed_pdf_bytes, language, voice_name)
349
+ return {"title": f"Transformed Summary - {os.path.basename(file_key)}", "s3_key": result["s3_key"], "s3_url": result["s3_url"]}
 
 
 
 
 
 
350
  finally:
351
  shutil.rmtree(temp_dir, ignore_errors=True)
352
 
services/video_generator_service.py CHANGED
@@ -1,5 +1,5 @@
1
  import json
2
- import logging
3
  import os
4
  import tempfile
5
  import time
@@ -7,8 +7,7 @@ import shutil
7
  from typing import List, Dict, Optional, Any
8
  import wave
9
 
10
- import openai
11
- from google import genai
12
  from google.genai import types
13
  from PIL import Image
14
  from pdf2image import convert_from_path
@@ -17,15 +16,14 @@ from moviepy import ImageClip, AudioFileClip, VideoFileClip, concatenate_videocl
17
  from core.config import settings
18
  from core.prompts import get_video_script_prompt
19
  from services.s3_service import s3_service
20
-
21
- logger = logging.getLogger(__name__)
22
 
23
  class VideoGeneratorService:
24
  def __init__(self):
25
- self.openai_client = openai.OpenAI(api_key=settings.OPENAI_API_KEY)
26
 
27
  # Match Temp project: Use API Key for Gemini TTS
28
- logger.info("Initializing Gemini Client with API Key (as in Temp project)")
29
  self.gemini_client = genai.Client(api_key=settings.GEMINI_API_KEY)
30
 
31
  async def generate_video_summary(
@@ -40,19 +38,20 @@ class VideoGeneratorService:
40
  temp_dir = tempfile.mkdtemp(prefix="video_gen_")
41
  try:
42
  # 1. Download PDF from S3
 
43
  pdf_path = os.path.join(temp_dir, "input.pdf")
44
- s3_service.s3_client.download_file(settings.AWS_S3_BUCKET, file_key, pdf_path)
45
 
46
- # 2. Convert PDF to Images to get page count and for later use
47
  image_dir = os.path.join(temp_dir, "images")
48
  os.makedirs(image_dir, exist_ok=True)
49
 
50
- # Poppler check (Windows usually needs path)
51
  poppler_path = os.environ.get("POPPLER_PATH")
 
52
  if poppler_path:
53
- images = convert_from_path(pdf_path, dpi=200, poppler_path=poppler_path)
54
  else:
55
- images = convert_from_path(pdf_path, dpi=200)
56
 
57
  total_pages = len(images)
58
  image_paths = []
@@ -61,14 +60,18 @@ class VideoGeneratorService:
61
  img.save(img_path, "PNG")
62
  image_paths.append(img_path)
63
 
64
- # 3. Generate Narration Script (OpenAI)
 
65
  with open(pdf_path, "rb") as f:
66
- openai_file = self.openai_client.files.create(file=f, purpose="assistants")
 
 
 
 
 
67
 
68
- # Using the new high-fidelity prompt
69
  prompt = get_video_script_prompt(language, total_pages)
70
-
71
- response = self.openai_client.chat.completions.create(
72
  model="gpt-4o-mini",
73
  messages=[
74
  {
@@ -85,21 +88,16 @@ class VideoGeneratorService:
85
 
86
  script_data = json.loads(response.choices[0].message.content)
87
  scripts = script_data.get("scripts", [])
88
-
89
- # Cleanup OpenAI file
90
- self.openai_client.files.delete(openai_file.id)
91
 
92
- # 4. Generate Audio for each page (Gemini TTS)
93
  audio_dir = os.path.join(temp_dir, "audio")
94
  os.makedirs(audio_dir, exist_ok=True)
95
  audio_paths = []
96
 
97
- # We iterate through scripts. Usually total_pages.
98
- # Mirror original repo: last page (logo) is often skipped for audio.
99
  for i, script in enumerate(scripts):
100
- # If it's the last page, skip audio (standard behavior in the template project)
101
  if i == len(scripts) - 1:
102
- logger.info(f"Skipping audio for last page (logo slide)")
103
  continue
104
 
105
  page_num = script.get("page_number", i+1)
@@ -107,14 +105,12 @@ class VideoGeneratorService:
107
  if not text: continue
108
 
109
  audio_path = os.path.join(audio_dir, f"audio_{page_num:02d}.wav")
 
110
 
111
- # Gemini TTS with fallback
112
  try:
113
- # Default model from original repo
114
  model_name = "gemini-2.5-flash-preview-tts"
115
- logger.info(f"Generating audio for page {page_num} using {model_name}...")
116
-
117
- tts_resp = self.gemini_client.models.generate_content(
118
  model=model_name,
119
  contents=text,
120
  config=types.GenerateContentConfig(
@@ -129,11 +125,9 @@ class VideoGeneratorService:
129
  )
130
  )
131
  except Exception as tts_err:
132
- logger.warning(f"Failed with {model_name}, trying fallback gemini-1.5-flash: {tts_err}")
133
- # Fallback to a highly stable multimodal model
134
- model_name = "gemini-1.5-flash"
135
- tts_resp = self.gemini_client.models.generate_content(
136
- model=model_name,
137
  contents=text,
138
  config=types.GenerateContentConfig(
139
  response_modalities=["AUDIO"],
@@ -149,58 +143,52 @@ class VideoGeneratorService:
149
 
150
  audio_bytes = tts_resp.candidates[0].content.parts[0].inline_data.data
151
  with wave.open(audio_path, "wb") as wf:
152
- wf.setnchannels(1)
153
- wf.setsampwidth(2)
154
- wf.setframerate(24000)
155
- wf.writeframes(audio_bytes)
156
 
157
  audio_paths.append(audio_path)
158
- # Rate limiting guard: wait between audio gens
159
- time.sleep(3)
160
 
161
- # 5. Combine into individual videos and then final video (MoviePy)
 
162
  page_clips = []
163
  target_size = (1920, 1080)
164
 
165
  for i, img_path in enumerate(image_paths):
166
- # Match audio if available (some pages might not have script if script gen failed or skipped)
167
- # Usually we want 1 image per audio.
 
 
 
168
  if i < len(audio_paths):
169
  aud_clip = AudioFileClip(audio_paths[i])
170
- duration = aud_clip.duration
171
-
172
- # Process image to fit 1080p
173
- img = Image.open(img_path)
174
- img = self._resize_and_pad(img, target_size)
175
- temp_img_res = os.path.join(temp_dir, f"res_{i}.png")
176
- img.save(temp_img_res)
177
-
178
- img_clip = ImageClip(temp_img_res, duration=duration)
179
- vid_clip = img_clip.with_audio(aud_clip)
180
- page_clips.append(vid_clip)
181
  else:
182
- # Final page or extra pages - silent 3s
183
- img = Image.open(img_path)
184
- img = self._resize_and_pad(img, target_size)
185
- temp_img_res = os.path.join(temp_dir, f"res_{i}.png")
186
- img.save(temp_img_res)
187
- img_clip = ImageClip(temp_img_res, duration=3.0)
188
- page_clips.append(img_clip)
189
 
190
  final_video_path = os.path.join(temp_dir, "final.mp4")
191
  final_clip = concatenate_videoclips(page_clips, method="compose")
192
- final_clip.write_videofile(final_video_path, fps=24, codec="libx264", audio_codec="aac", logger=None)
 
 
 
 
 
 
 
 
193
 
194
- # Cleanup clips
195
  for clip in page_clips: clip.close()
196
- if final_clip: final_clip.close()
197
 
198
  # 6. Upload to S3
199
  timestamp = int(time.time())
200
  s3_key = f"users/video_summaries/{timestamp}_summary.mp4"
201
- s3_service.s3_client.upload_file(final_video_path, settings.AWS_S3_BUCKET, s3_key)
 
202
  s3_url = f"https://{settings.AWS_S3_BUCKET}.s3.{settings.AWS_REGION}.amazonaws.com/{s3_key}"
203
 
 
204
  return {
205
  "title": f"Video Summary - {os.path.basename(file_key)}",
206
  "s3_key": s3_key,
@@ -208,7 +196,7 @@ class VideoGeneratorService:
208
  }
209
 
210
  except Exception as e:
211
- logger.error(f"Video generation failed: {e}")
212
  import traceback
213
  traceback.print_exc()
214
  raise
 
1
  import json
2
+ import asyncio
3
  import os
4
  import tempfile
5
  import time
 
7
  from typing import List, Dict, Optional, Any
8
  import wave
9
 
10
+ from openai import AsyncOpenAI
 
11
  from google.genai import types
12
  from PIL import Image
13
  from pdf2image import convert_from_path
 
16
  from core.config import settings
17
  from core.prompts import get_video_script_prompt
18
  from services.s3_service import s3_service
19
+ from google import genai
 
20
 
21
  class VideoGeneratorService:
22
  def __init__(self):
23
+ self.openai_client = AsyncOpenAI(api_key=settings.OPENAI_API_KEY)
24
 
25
  # Match Temp project: Use API Key for Gemini TTS
26
+ print("[VideoGenerator] Initializing Gemini Client with API Key")
27
  self.gemini_client = genai.Client(api_key=settings.GEMINI_API_KEY)
28
 
29
  async def generate_video_summary(
 
38
  temp_dir = tempfile.mkdtemp(prefix="video_gen_")
39
  try:
40
  # 1. Download PDF from S3
41
+ print(f"[VideoGenerator] Starting generation for: {file_key}")
42
  pdf_path = os.path.join(temp_dir, "input.pdf")
43
+ await asyncio.to_thread(s3_service.s3_client.download_file, settings.AWS_S3_BUCKET, file_key, pdf_path)
44
 
45
+ # 2. Convert PDF to Images
46
  image_dir = os.path.join(temp_dir, "images")
47
  os.makedirs(image_dir, exist_ok=True)
48
 
 
49
  poppler_path = os.environ.get("POPPLER_PATH")
50
+ print("[VideoGenerator] Converting PDF to images...")
51
  if poppler_path:
52
+ images = await asyncio.to_thread(convert_from_path, pdf_path, dpi=200, poppler_path=poppler_path)
53
  else:
54
+ images = await asyncio.to_thread(convert_from_path, pdf_path, dpi=200)
55
 
56
  total_pages = len(images)
57
  image_paths = []
 
60
  img.save(img_path, "PNG")
61
  image_paths.append(img_path)
62
 
63
+ # 3. Generate Narration Script (Native Async OpenAI)
64
+ print(f"[VideoGenerator] Generating script with OpenAI for {total_pages} pages...")
65
  with open(pdf_path, "rb") as f:
66
+ content = f.read()
67
+
68
+ openai_file = await self.openai_client.files.create(
69
+ file=("source.pdf", content),
70
+ purpose="assistants"
71
+ )
72
 
 
73
  prompt = get_video_script_prompt(language, total_pages)
74
+ response = await self.openai_client.chat.completions.create(
 
75
  model="gpt-4o-mini",
76
  messages=[
77
  {
 
88
 
89
  script_data = json.loads(response.choices[0].message.content)
90
  scripts = script_data.get("scripts", [])
91
+ await self.openai_client.files.delete(openai_file.id)
 
 
92
 
93
+ # 4. Generate Audio (Native Async Gemini)
94
  audio_dir = os.path.join(temp_dir, "audio")
95
  os.makedirs(audio_dir, exist_ok=True)
96
  audio_paths = []
97
 
 
 
98
  for i, script in enumerate(scripts):
 
99
  if i == len(scripts) - 1:
100
+ print("[VideoGenerator] Skipping audio for last page (logo slide)")
101
  continue
102
 
103
  page_num = script.get("page_number", i+1)
 
105
  if not text: continue
106
 
107
  audio_path = os.path.join(audio_dir, f"audio_{page_num:02d}.wav")
108
+ print(f"[VideoGenerator] Generating TTS for page {page_num}...")
109
 
 
110
  try:
111
+ # Use Native Async Gemini
112
  model_name = "gemini-2.5-flash-preview-tts"
113
+ tts_resp = await self.gemini_client.aio.models.generate_content(
 
 
114
  model=model_name,
115
  contents=text,
116
  config=types.GenerateContentConfig(
 
125
  )
126
  )
127
  except Exception as tts_err:
128
+ print(f"[VideoGenerator] TTS Primary failed, using fallback: {tts_err}")
129
+ tts_resp = await self.gemini_client.aio.models.generate_content(
130
+ model="gemini-1.5-flash",
 
 
131
  contents=text,
132
  config=types.GenerateContentConfig(
133
  response_modalities=["AUDIO"],
 
143
 
144
  audio_bytes = tts_resp.candidates[0].content.parts[0].inline_data.data
145
  with wave.open(audio_path, "wb") as wf:
146
+ wf.setnchannels(1); wf.setsampwidth(2); wf.setframerate(24000); wf.writeframes(audio_bytes)
 
 
 
147
 
148
  audio_paths.append(audio_path)
149
+ await asyncio.sleep(2) # Non-blocking sleep
 
150
 
151
+ # 5. Assembly (MoviePy in Thread)
152
+ print("[VideoGenerator] Assembled audio/images. Now rendering final video with MoviePy (background thread)...")
153
  page_clips = []
154
  target_size = (1920, 1080)
155
 
156
  for i, img_path in enumerate(image_paths):
157
+ img = Image.open(img_path)
158
+ img = self._resize_and_pad(img, target_size)
159
+ temp_img_res = os.path.join(temp_dir, f"res_{i}.png")
160
+ img.save(temp_img_res)
161
+
162
  if i < len(audio_paths):
163
  aud_clip = AudioFileClip(audio_paths[i])
164
+ img_clip = ImageClip(temp_img_res, duration=aud_clip.duration)
165
+ page_clips.append(img_clip.with_audio(aud_clip))
 
 
 
 
 
 
 
 
 
166
  else:
167
+ page_clips.append(ImageClip(temp_img_res, duration=3.0))
 
 
 
 
 
 
168
 
169
  final_video_path = os.path.join(temp_dir, "final.mp4")
170
  final_clip = concatenate_videoclips(page_clips, method="compose")
171
+
172
+ await asyncio.to_thread(
173
+ final_clip.write_videofile,
174
+ final_video_path,
175
+ fps=24,
176
+ codec="libx264",
177
+ audio_codec="aac",
178
+ logger=None
179
+ )
180
 
 
181
  for clip in page_clips: clip.close()
182
+ final_clip.close()
183
 
184
  # 6. Upload to S3
185
  timestamp = int(time.time())
186
  s3_key = f"users/video_summaries/{timestamp}_summary.mp4"
187
+ print(f"[VideoGenerator] Uploading final video to S3: {s3_key}")
188
+ await asyncio.to_thread(s3_service.s3_client.upload_file, final_video_path, settings.AWS_S3_BUCKET, s3_key)
189
  s3_url = f"https://{settings.AWS_S3_BUCKET}.s3.{settings.AWS_REGION}.amazonaws.com/{s3_key}"
190
 
191
+ print(f"[VideoGenerator] Success! Video ready at: {s3_url}")
192
  return {
193
  "title": f"Video Summary - {os.path.basename(file_key)}",
194
  "s3_key": s3_key,
 
196
  }
197
 
198
  except Exception as e:
199
+ print(f"[VideoGenerator] ERROR: {str(e)}")
200
  import traceback
201
  traceback.print_exc()
202
  raise