matsuap committed on
Commit
951d5c6
·
verified ·
1 Parent(s): 4064f62

Upload folder using huggingface_hub

Browse files
api/auth.py CHANGED
@@ -33,6 +33,30 @@ async def get_current_user(token: str = Depends(oauth2_scheme), db: Session = De
33
  raise credentials_exception
34
  return user
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  @router.post("/register", response_model=UserResponse)
37
  async def register(user_in: UserCreate, db: Session = Depends(get_db)):
38
  db_user = db.query(db_models.User).filter(db_models.User.email == user_in.email).first()
 
33
  raise credentials_exception
34
  return user
35
 
36
+
37
async def get_current_user_ws(token: str, db: Session):
    """
    Authenticate a WebSocket client from a JWT passed as a query parameter.

    The token is decoded with ``settings.SECRET_KEY`` / ``settings.ALGORITHM``;
    its ``sub`` claim is treated as the user's email and used to look up the
    matching ``db_models.User`` row.

    Returns:
        The user whose email matches the token subject.

    Raises:
        HTTPException: 401 when the token cannot be decoded, carries no
            ``sub`` claim, or names an email with no matching user.
    """
    unauthorized = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
    )

    try:
        claims = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
    except JWTError:
        raise unauthorized

    subject: str = claims.get("sub")
    if subject is None:
        raise unauthorized
    token_data = TokenData(email=subject)

    matched_user = (
        db.query(db_models.User)
        .filter(db_models.User.email == token_data.email)
        .first()
    )
    if matched_user is None:
        raise unauthorized
    return matched_user
59
+
60
  @router.post("/register", response_model=UserResponse)
61
  async def register(user_in: UserCreate, db: Session = Depends(get_db)):
62
  db_user = db.query(db_models.User).filter(db_models.User.email == user_in.email).first()
api/podcast.py CHANGED
@@ -1,14 +1,16 @@
1
  import os
 
2
  import logging
3
  from datetime import datetime
4
- from fastapi import APIRouter, Depends, HTTPException
5
  from sqlalchemy.orm import Session
6
  from typing import Dict, List
 
7
 
8
  from api.auth import get_current_user
9
- from models.schemas import PodcastAnalyzeRequest, PodcastGenerateRequest
10
  from models import db_models
11
- from core.database import get_db
12
  from services.podcast_service import podcast_service
13
  from services.s3_service import s3_service
14
  from core import constants
@@ -27,57 +29,29 @@ async def get_podcast_config():
27
  "models": constants.PODCAST_MODALS
28
  }
29
 
30
- @router.post("/analyze")
31
- async def analyze_source(
32
- request: PodcastAnalyzeRequest,
33
- current_user: db_models.User = Depends(get_current_user),
34
- db: Session = Depends(get_db)):
35
- """
36
- Analyzes a source file from S3 and proposes podcast structures.
37
- """
38
  try:
39
- # Verify file ownership via DB
40
- source = db.query(db_models.Source).filter(
41
- db_models.Source.s3_key == request.file_key,
42
- db_models.Source.user_id == current_user.id
43
- ).first()
44
-
45
- if not source:
46
- raise HTTPException(status_code=403, detail="Not authorized to access this file or file does not exist")
47
 
48
- analysis = await podcast_service.analyze_pdf(
49
- file_key=request.file_key,
50
- duration_minutes=request.duration_minutes
51
- )
52
- return {"analysis": analysis}
53
- except HTTPException:
54
- raise
55
- except Exception as e:
56
- logger.error(f"Analysis failed: {e}")
57
- raise HTTPException(status_code=500, detail=str(e))
58
 
59
- @router.post("/generate")
60
- async def generate_podcast(
61
- request: PodcastGenerateRequest,
62
- current_user: db_models.User = Depends(get_current_user),
63
- db: Session = Depends(get_db)
64
- ):
65
- """
66
- Generates a podcast script and then the audio.
67
- Saves metadata to DB and returns the generated info.
68
- """
69
- try:
70
- # 1. Verify file ownership if provided
71
  if request.file_key:
72
- source = db.query(db_models.Source).filter(
73
- db_models.Source.s3_key == request.file_key,
74
- db_models.Source.user_id == current_user.id
75
- ).first()
76
- if not source:
77
- raise HTTPException(status_code=403, detail="Not authorized to access this file")
78
- source_id = source.id
79
- else:
80
- source_id = None
81
 
82
  # 2. Generate Script
83
  script = await podcast_service.generate_script(
@@ -85,12 +59,14 @@ async def generate_podcast(
85
  model=request.model,
86
  duration_minutes=request.duration_minutes,
87
  podcast_format=request.podcast_format,
88
- pdf_suggestions=request.pdf_suggestions,
89
  file_key=request.file_key
90
  )
91
 
92
  if not script:
93
- raise HTTPException(status_code=500, detail="Failed to generate script")
 
 
94
 
95
  # 3. Generate Audio
96
  audio_path = await podcast_service.generate_full_audio(
@@ -103,83 +79,115 @@ async def generate_podcast(
103
  )
104
 
105
  if not audio_path:
106
- raise HTTPException(status_code=500, detail="Failed to generate audio")
 
 
107
 
108
  # 4. Upload to S3
109
  filename = os.path.basename(audio_path)
110
- with open(audio_path, "rb") as f:
111
- content = f.read()
112
-
113
- s3_key = f"users/{current_user.id}/outputs/podcasts/{filename}"
114
-
115
- import boto3
116
- from core.config import settings
117
- s3_client = boto3.client('s3',
118
- aws_access_key_id=settings.AWS_ACCESS_KEY_ID,
119
- aws_secret_access_key=settings.AWS_SECRET_ACCESS_KEY,
120
- region_name=settings.AWS_REGION)
121
- s3_client.put_object(Bucket=settings.AWS_S3_BUCKET, Key=s3_key, Body=content)
122
-
123
- public_url = s3_service.get_public_url(s3_key)
124
- private_url = s3_service.get_presigned_url(s3_key)
 
125
 
126
- # 5. Save to DB
127
- db_podcast = db_models.Podcast(
128
- title=f"Podcast {datetime.utcnow().strftime('%Y-%m-%d %H:%M')}",
129
- s3_key=s3_key,
130
- s3_url=public_url,
131
- script=script,
132
- user_id=current_user.id,
133
- source_id=source_id
134
- )
135
- db.add(db_podcast)
136
  db.commit()
137
- db.refresh(db_podcast)
138
 
139
- # Clean up local file
140
- os.remove(audio_path)
 
 
 
 
 
141
 
142
- return {
143
- "id": db_podcast.id,
144
- "message": "Podcast generated successfully",
145
- "script": script,
146
- "public_url": public_url,
147
- "private_url": private_url
148
- }
149
 
150
- except HTTPException:
151
- raise
152
  except Exception as e:
153
- logger.error(f"Podcast generation failed: {e}")
154
- raise HTTPException(status_code=500, detail=str(e))
 
 
 
 
 
 
 
 
 
155
 
156
- @router.get("/list")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
  async def list_podcasts(
158
  current_user: db_models.User = Depends(get_current_user),
159
  db: Session = Depends(get_db)
160
  ):
161
  """
162
- Lists all podcasts for the current user.
163
  """
164
  try:
165
  podcasts = db.query(db_models.Podcast).filter(
166
  db_models.Podcast.user_id == current_user.id
167
  ).order_by(db_models.Podcast.created_at.desc()).all()
168
 
169
- return [
170
- {
171
- "id": p.id,
172
- "title": p.title,
173
- "s3_key": p.s3_key,
174
- "public_url": p.s3_url,
175
- "private_url": s3_service.get_presigned_url(p.s3_key),
176
- "script_preview": (p.script[:200] + "...") if p.script else "",
177
- "parent_file_id": p.source_id,
178
- "parent_file_key": p.source.s3_key if p.source else None,
179
- "created_at": p.created_at
180
- }
181
- for p in podcasts
182
- ]
183
  except Exception as e:
184
  raise HTTPException(status_code=500, detail=str(e))
185
 
@@ -201,8 +209,9 @@ async def delete_podcast(
201
  raise HTTPException(status_code=404, detail="Podcast not found")
202
 
203
  try:
204
- # 1. Delete from S3
205
- await s3_service.delete_file(podcast.s3_key)
 
206
 
207
  # 2. Delete from DB
208
  db.delete(podcast)
 
1
  import os
2
+ import asyncio
3
  import logging
4
  from datetime import datetime
5
+ from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks
6
  from sqlalchemy.orm import Session
7
  from typing import Dict, List
8
+ from api.websocket_routes import manager
9
 
10
  from api.auth import get_current_user
11
+ from models.schemas import PodcastGenerateRequest, PodcastResponse
12
  from models import db_models
13
+ from core.database import get_db, SessionLocal
14
  from services.podcast_service import podcast_service
15
  from services.s3_service import s3_service
16
  from core import constants
 
29
  "models": constants.PODCAST_MODALS
30
  }
31
 
32
+ async def run_podcast_generation(podcast_id: int, request: PodcastGenerateRequest, user_id: int):
33
+ """Background task to generate podcast and update status."""
34
+ db = SessionLocal()
 
 
 
 
 
35
  try:
36
+ podcast = db.query(db_models.Podcast).filter(db_models.Podcast.id == podcast_id).first()
37
+ if not podcast:
38
+ return
 
 
 
 
 
39
 
40
+ podcast.status = "processing"
41
+ db.commit()
42
+
43
+ # Notify via WebSocket if connected
44
+ connection_id = f"user_{user_id}"
45
+ await manager.send_progress(connection_id, 10, "processing", "Analyzing source file...")
 
 
 
 
46
 
47
+ # 1. Analyze first if file is provided
48
+ analysis_report = ""
 
 
 
 
 
 
 
 
 
 
49
  if request.file_key:
50
+ analysis_report = await podcast_service.analyze_pdf(
51
+ file_key=request.file_key,
52
+ duration_minutes=request.duration_minutes
53
+ )
54
+ await manager.send_progress(connection_id, 20, "processing", "Generating podcast script...")
 
 
 
 
55
 
56
  # 2. Generate Script
57
  script = await podcast_service.generate_script(
 
59
  model=request.model,
60
  duration_minutes=request.duration_minutes,
61
  podcast_format=request.podcast_format,
62
+ pdf_suggestions=analysis_report,
63
  file_key=request.file_key
64
  )
65
 
66
  if not script:
67
+ raise Exception("Failed to generate script")
68
+
69
+ await manager.send_progress(connection_id, 40, "processing", "Generating audio (this may take several minutes)...")
70
 
71
  # 3. Generate Audio
72
  audio_path = await podcast_service.generate_full_audio(
 
79
  )
80
 
81
  if not audio_path:
82
+ raise Exception("Failed to generate audio")
83
+
84
+ await manager.send_progress(connection_id, 85, "processing", "Uploading to S3...")
85
 
86
  # 4. Upload to S3
87
  filename = os.path.basename(audio_path)
88
+ s3_key = f"users/{user_id}/outputs/podcasts/{filename}"
89
+
90
+ def upload_audio():
91
+ with open(audio_path, "rb") as f:
92
+ content = f.read()
93
+
94
+ import boto3
95
+ from core.config import settings
96
+ s3_client = boto3.client('s3',
97
+ aws_access_key_id=settings.AWS_ACCESS_KEY_ID,
98
+ aws_secret_access_key=settings.AWS_SECRET_ACCESS_KEY,
99
+ region_name=settings.AWS_REGION)
100
+ s3_client.put_object(Bucket=settings.AWS_S3_BUCKET, Key=s3_key, Body=content)
101
+ return content
102
+
103
+ await asyncio.to_thread(upload_audio)
104
 
105
+ public_url = s3_service.get_public_url(s3_key)
106
+
107
+ # 5. Final update to DB
108
+ podcast.s3_key = s3_key
109
+ podcast.s3_url = public_url
110
+ podcast.script = script
111
+ podcast.status = "completed"
 
 
 
112
  db.commit()
 
113
 
114
+ # Notify completion
115
+ await manager.send_result(connection_id, {
116
+ "id": podcast.id,
117
+ "status": "completed",
118
+ "title": podcast.title,
119
+ "public_url": public_url
120
+ })
121
 
122
+ # Clean up
123
+ if os.path.exists(audio_path):
124
+ os.remove(audio_path)
 
 
 
 
125
 
 
 
126
  except Exception as e:
127
+ logger.error(f"Background podcast generation failed for ID {podcast_id}: {e}")
128
+ podcast = db.query(db_models.Podcast).filter(db_models.Podcast.id == podcast_id).first()
129
+ if podcast:
130
+ podcast.status = "failed"
131
+ podcast.error_message = str(e)
132
+ db.commit()
133
+
134
+ connection_id = f"user_{user_id}"
135
+ await manager.send_error(connection_id, f"Generation failed: {str(e)}")
136
+ finally:
137
+ db.close()
138
 
139
@router.post("/generate", response_model=PodcastResponse)
async def generate_podcast(
    request: PodcastGenerateRequest,
    background_tasks: BackgroundTasks,
    current_user: db_models.User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """
    Start podcast generation in the background and return immediately.

    When ``request.file_key`` is set, the file must be a source owned by the
    caller (403 otherwise). A ``Podcast`` row is stored with status
    "processing", ``run_podcast_generation`` is queued as a background task,
    and the stored row is returned so the client has an id to track.
    """
    # Ownership check: only sources registered to the caller may be used.
    owned_source = None
    if request.file_key:
        owned_source = db.query(db_models.Source).filter(
            db_models.Source.s3_key == request.file_key,
            db_models.Source.user_id == current_user.id
        ).first()
        if not owned_source:
            raise HTTPException(status_code=403, detail="Not authorized to access this file")

    # Persist the placeholder row before any long-running work starts.
    placeholder_title = f"Podcast {datetime.utcnow().strftime('%Y-%m-%d %H:%M')}"
    new_podcast = db_models.Podcast(
        title=placeholder_title,
        user_id=current_user.id,
        source_id=owned_source.id if owned_source else None,
        status="processing"
    )
    db.add(new_podcast)
    db.commit()
    db.refresh(new_podcast)

    # The heavy lifting runs after the response has been sent.
    background_tasks.add_task(run_podcast_generation, new_podcast.id, request, current_user.id)
    return new_podcast
176
+
177
@router.get("/list", response_model=List[PodcastResponse])
async def list_podcasts(
    current_user: db_models.User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """
    Return every podcast owned by the current user, newest first.

    Each row is serialized through ``PodcastResponse``, which includes the
    generation status. Any failure is reported as a 500 with the exception text.
    """
    try:
        owned = db.query(db_models.Podcast).filter(
            db_models.Podcast.user_id == current_user.id
        )
        newest_first = owned.order_by(db_models.Podcast.created_at.desc()).all()

        serialized = []
        for row in newest_first:
            serialized.append(PodcastResponse.model_validate(row))
        return serialized
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
193
 
 
209
  raise HTTPException(status_code=404, detail="Podcast not found")
210
 
211
  try:
212
+ # 1. Delete from S3 if it exists
213
+ if podcast.s3_key:
214
+ await s3_service.delete_file(podcast.s3_key)
215
 
216
  # 2. Delete from DB
217
  db.delete(podcast)
api/sources.py CHANGED
@@ -130,8 +130,9 @@ async def delete_source(
130
 
131
  db.commit() # Commit deletions
132
 
133
- # 3. Delete from S3
134
- await s3_service.delete_file(source.s3_key)
 
135
 
136
  # 4. Delete the Source itself from Database
137
  db.delete(source)
 
130
 
131
  db.commit() # Commit deletions
132
 
133
+ # 3. Delete from S3 if it exists
134
+ if source.s3_key:
135
+ await s3_service.delete_file(source.s3_key)
136
 
137
  # 4. Delete the Source itself from Database
138
  db.delete(source)
api/video_generator.py CHANGED
@@ -123,8 +123,9 @@ async def delete_video_summary(
123
  raise HTTPException(status_code=404, detail="Video summary not found")
124
 
125
  try:
126
- # 1. Delete from S3
127
- await s3_service.delete_file(summary.s3_key)
 
128
 
129
  # 2. Delete from DB
130
  db.delete(summary)
 
123
  raise HTTPException(status_code=404, detail="Video summary not found")
124
 
125
  try:
126
+ # 1. Delete from S3 if it exists
127
+ if summary.s3_key:
128
+ await s3_service.delete_file(summary.s3_key)
129
 
130
  # 2. Delete from DB
131
  db.delete(summary)
api/websocket_routes.py ADDED
@@ -0,0 +1,354 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import asyncio
3
+ from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Depends
4
+ from sqlalchemy.orm import Session
5
+ from datetime import datetime
6
+ from typing import Dict, Any
7
+
8
+ from api.auth import get_current_user_ws
9
+ from models import db_models
10
+ from core.database import get_db
11
+ from services.flashcard_service import flashcard_service
12
+ from services.quiz_service import quiz_service
13
+ from services.report_service import report_service
14
+ from services.mindmap_service import mindmap_service
15
+ from services.podcast_service import podcast_service
16
+ from services.s3_service import s3_service
17
+ from services.video_generator_service import video_generator_service
18
+
19
+ router = APIRouter(prefix="/ws", tags=["websockets"])
20
+ logger = logging.getLogger(__name__)
21
+
22
+
23
class ConnectionManager:
    """Registry of open WebSocket connections, keyed by connection id.

    Frames are only delivered when the id is currently registered; a failed
    send is logged and never propagated to the caller.
    """

    def __init__(self):
        self.active_connections: Dict[str, WebSocket] = {}

    async def connect(self, websocket: WebSocket, connection_id: str):
        """Accept the handshake and register the socket under ``connection_id``."""
        await websocket.accept()
        self.active_connections[connection_id] = websocket
        logger.info(f"WebSocket connected: {connection_id}")

    def disconnect(self, connection_id: str):
        """Drop ``connection_id`` from the registry; unknown ids are ignored."""
        if connection_id not in self.active_connections:
            return
        del self.active_connections[connection_id]
        logger.info(f"WebSocket disconnected: {connection_id}")

    async def _deliver(self, connection_id: str, label: str, frame: Dict[str, Any]):
        """Send ``frame`` as JSON if the id is registered; log (not raise) on failure."""
        if connection_id not in self.active_connections:
            return
        try:
            await self.active_connections[connection_id].send_json(frame)
        except Exception as e:
            logger.error(f"Error sending {label} to {connection_id}: {e}")

    async def send_progress(self, connection_id: str, progress: int, status: str, message: str = ""):
        """Push a ``progress`` frame with the given percentage, status and message."""
        frame = {
            "type": "progress",
            "progress": progress,
            "status": status,
            "message": message
        }
        await self._deliver(connection_id, "progress", frame)

    async def send_result(self, connection_id: str, data: Any):
        """Push the final ``result`` frame (status "complete", progress 100)."""
        frame = {
            "type": "result",
            "status": "complete",
            "progress": 100,
            "data": data
        }
        await self._deliver(connection_id, "result", frame)

    async def send_error(self, connection_id: str, error: str):
        """Push an ``error`` frame carrying ``error`` as its message."""
        frame = {
            "type": "error",
            "status": "error",
            "message": error
        }
        await self._deliver(connection_id, "error", frame)
73
+
74
+
75
+ manager = ConnectionManager()
76
+
77
+
78
@router.websocket("/generate")
async def unified_generate_ws(
    websocket: WebSocket,
    token: str,
    db: Session = Depends(get_db)):
    """
    Unified WebSocket gateway for generation tasks.

    The client connects with a JWT in the ``token`` query parameter and then
    sends one JSON message: { "type": "podcast|flashcards|quiz", "data": { ... } }
    Progress, result and error frames are pushed back through ``manager``
    under the connection id ``user_<id>``.

    Only the three task types above are routed by this endpoint; any other
    ``type`` is answered with an "Unsupported task type" error frame.
    """
    await websocket.accept()

    # Stays None until authentication succeeds, so the error and cleanup paths
    # can tell "never registered" apart from "registered under this id".
    connection_id = None
    try:
        current_user = await get_current_user_ws(token, db)
        connection_id = f"user_{current_user.id}"
        manager.active_connections[connection_id] = websocket

        # Receive the task specification
        message = await websocket.receive_json()
        task_type = message.get("type")
        data = message.get("data", {})

        if not task_type:
            await manager.send_error(connection_id, "Missing 'type' in request")
            return

        await manager.send_progress(connection_id, 2, "processing", f"Initializing {task_type} task...")

        # --- ROUTING LOGIC ---
        if task_type == "podcast":
            await handle_podcast_task(connection_id, data, current_user, db)
        elif task_type == "flashcards":
            await handle_flashcards_task(connection_id, data, current_user, db)
        elif task_type == "quiz":
            await handle_quiz_task(connection_id, data, current_user, db)
        else:
            await manager.send_error(connection_id, f"Unsupported task type: {task_type}")

    except WebSocketDisconnect:
        logger.info("Client disconnected")
    except Exception as e:
        logger.error(f"Unified WebSocket error: {e}")
        try:
            if connection_id is not None:
                await manager.send_error(connection_id, str(e))
            else:
                # Failure happened before registration (e.g. invalid token):
                # the manager has no entry for this socket, so answer on the
                # socket directly, using the same frame shape as send_error.
                await websocket.send_json({
                    "type": "error",
                    "status": "error",
                    "message": str(e)
                })
        except Exception as send_exc:
            # Best effort only: the peer may already be gone.
            logger.error(f"Could not report WebSocket error to client: {send_exc}")
    finally:
        # Only unregister our own socket; a newer connection from the same
        # user may have replaced this entry under the same id.
        if connection_id is not None and manager.active_connections.get(connection_id) is websocket:
            manager.disconnect(connection_id)
129
+
130
async def handle_podcast_task(connection_id: str, data: Dict, current_user: db_models.User, db: Session):
    """
    Internal handler for podcast generation.

    Steps: check that ``data["file_key"]`` (if given) is a source owned by the
    user, store a ``Podcast`` row with status "processing", optionally analyze
    the source, generate the script and the audio, upload the audio to S3, then
    mark the row "completed" and push the result frame.

    Any failure is logged, recorded on the row as status "failed" with the
    error text (when the row exists), and reported as an error frame. The
    local audio file is removed whether or not the task succeeded.
    """
    import os  # function-scope import, as in the original body

    db_podcast = None
    audio_path = None
    try:
        # Reject the request up front instead of failing with a bare KeyError
        # after the source analysis has already been paid for.
        if "user_prompt" not in data:
            await manager.send_error(connection_id, "Missing 'user_prompt' in request data")
            return

        source_id = None
        if data.get("file_key"):
            source = db.query(db_models.Source).filter(
                db_models.Source.s3_key == data["file_key"],
                db_models.Source.user_id == current_user.id
            ).first()
            if not source:
                await manager.send_error(connection_id, "Not authorized to access this file")
                return
            source_id = source.id

        db_podcast = db_models.Podcast(
            title=f"Podcast {datetime.utcnow().strftime('%Y-%m-%d %H:%M')}",
            user_id=current_user.id,
            source_id=source_id,
            status="processing"
        )
        db.add(db_podcast)
        db.commit()
        db.refresh(db_podcast)

        analysis_report = ""
        if data.get("file_key"):
            await manager.send_progress(connection_id, 10, "processing", "Analyzing source file...")
            analysis_report = await podcast_service.analyze_pdf(
                file_key=data["file_key"],
                duration_minutes=data.get("duration_minutes", 10)
            )

        await manager.send_progress(connection_id, 15, "processing", "Generating podcast script...")
        script = await podcast_service.generate_script(
            user_prompt=data["user_prompt"],
            model=data.get("model", "gpt-4o"),
            duration_minutes=data.get("duration_minutes", 10),
            podcast_format=data.get("podcast_format", "conversational"),
            pdf_suggestions=analysis_report,
            file_key=data.get("file_key")
        )

        if not script:
            raise Exception("Failed to generate script")

        await manager.send_progress(connection_id, 45, "processing", "Generating audio...")
        audio_path = await podcast_service.generate_full_audio(
            script=script,
            tts_model=data.get("tts_model", "gemini-2.0-flash-exp"),
            spk1_voice=data.get("spk1_voice", "Puck"),
            spk2_voice=data.get("spk2_voice", "Charon"),
            temperature=data.get("temperature", 1.0),
            bgm_choice=data.get("bgm_choice", "No BGM")
        )

        if not audio_path:
            raise Exception("Failed to generate audio")

        await manager.send_progress(connection_id, 90, "processing", "Uploading to S3...")
        filename = os.path.basename(audio_path)
        s3_key = f"users/{current_user.id}/outputs/podcasts/{filename}"

        def upload_audio_sync():
            # Blocking file read + S3 upload; run in a worker thread below so
            # the event loop is not stalled.
            with open(audio_path, "rb") as f:
                content = f.read()
            import boto3
            from core.config import settings
            boto3.client('s3',
                aws_access_key_id=settings.AWS_ACCESS_KEY_ID,
                aws_secret_access_key=settings.AWS_SECRET_ACCESS_KEY,
                region_name=settings.AWS_REGION).put_object(Bucket=settings.AWS_S3_BUCKET, Key=s3_key, Body=content)

        await asyncio.to_thread(upload_audio_sync)

        public_url = s3_service.get_public_url(s3_key)

        db_podcast.s3_key = s3_key
        db_podcast.s3_url = public_url
        db_podcast.script = script
        db_podcast.status = "completed"
        db.commit()

        await manager.send_result(connection_id, {
            "id": db_podcast.id,
            "status": "completed",
            "message": "Podcast generated successfully",
            "public_url": public_url
        })

    except Exception as e:
        logger.error(f"Podcast task failed: {e}")
        if db_podcast is not None:
            try:
                # Discard any half-applied changes so the failure update
                # starts from a clean transaction.
                db.rollback()
                db_podcast.status = "failed"
                db_podcast.error_message = str(e)
                db.commit()
            except Exception as db_exc:
                logger.error(f"Could not record podcast failure: {db_exc}")
        await manager.send_error(connection_id, str(e))
    finally:
        # Remove the temporary audio file on every path; a cleanup problem
        # must not turn a completed podcast into a failed one.
        if audio_path:
            try:
                if os.path.exists(audio_path):
                    os.remove(audio_path)
            except OSError as cleanup_exc:
                logger.error(f"Could not remove temporary audio file {audio_path}: {cleanup_exc}")
230
+
231
async def handle_flashcards_task(connection_id: str, data: Dict, current_user: db_models.User, db: Session):
    """
    Internal handler for flashcard generation.

    Checks that ``data["file_key"]`` (if given) is a source owned by the user,
    asks ``flashcard_service`` for cards while relaying its progress callbacks
    to the client (scaled into the 10-80% range), stores the set and its cards
    in a single transaction, and pushes a result frame with the set id, title
    and card count. Failures are logged, rolled back and sent as an error frame.
    """
    # asyncio keeps only weak references to tasks, so hold the fire-and-forget
    # progress sends here until they finish.
    progress_tasks = set()

    def report_progress(p, m):
        task = asyncio.create_task(
            manager.send_progress(connection_id, 10 + int(p * 0.7), "processing", m)
        )
        progress_tasks.add(task)
        task.add_done_callback(progress_tasks.discard)
        return task

    try:
        source_id = None
        if data.get("file_key"):
            source = db.query(db_models.Source).filter(
                db_models.Source.s3_key == data["file_key"],
                db_models.Source.user_id == current_user.id
            ).first()
            if not source:
                await manager.send_error(connection_id, "Not authorized to access this file")
                return
            source_id = source.id

        await manager.send_progress(connection_id, 10, "processing", "Generating flashcards...")

        cards_data = await flashcard_service.generate_flashcards(
            file_key=data.get("file_key"),
            text_input=data.get("text_input"),
            difficulty=data.get("difficulty", "medium"),
            quantity=data.get("quantity", "standard"),
            topic=data.get("topic"),
            language=data.get("language", "English"),
            progress_callback=report_progress
        )

        if not cards_data:
            await manager.send_error(connection_id, "AI returned an empty response")
            return

        await manager.send_progress(connection_id, 85, "processing", "Saving to database...")

        # `or` (not a .get default) so an explicit null/empty topic still
        # falls back to a generated title.
        title = data.get("topic") or f"Flashcards {len(cards_data)}"
        db_set = db_models.FlashcardSet(
            title=title,
            difficulty=data.get("difficulty", "medium"),
            user_id=current_user.id,
            source_id=source_id
        )
        db.add(db_set)
        # Flush to obtain db_set.id without committing, so the set and its
        # cards are saved (or rolled back) together.
        db.flush()

        for item in cards_data:
            db.add(db_models.Flashcard(
                flashcard_set_id=db_set.id,
                question=item.get("question", ""),
                answer=item.get("answer", "")
            ))

        db.commit()
        db.refresh(db_set)

        await manager.send_result(connection_id, {
            "id": db_set.id,
            "title": db_set.title,
            "flashcards_count": len(db_set.flashcards)
        })
    except Exception as e:
        logger.error(f"Flashcard task failed: {e}")
        try:
            db.rollback()  # do not leave a half-written set behind
        except Exception as db_exc:
            logger.error(f"Rollback after flashcard failure failed: {db_exc}")
        await manager.send_error(connection_id, str(e))
296
+
297
async def handle_quiz_task(connection_id: str, data: Dict, current_user: db_models.User, db: Session):
    """
    Internal handler for quiz generation.

    Checks that ``data["file_key"]`` (if given) is a source owned by the user,
    asks ``quiz_service`` for questions while relaying its progress callbacks
    to the client (scaled into the 10-80% range), stores the quiz set and its
    questions in a single transaction, and pushes a result frame with the set
    id and title. Failures are logged, rolled back and sent as an error frame.
    """
    # asyncio keeps only weak references to tasks, so hold the fire-and-forget
    # progress sends here until they finish.
    progress_tasks = set()

    def report_progress(p, m):
        task = asyncio.create_task(
            manager.send_progress(connection_id, 10 + int(p * 0.7), "processing", m)
        )
        progress_tasks.add(task)
        task.add_done_callback(progress_tasks.discard)
        return task

    try:
        source_id = None
        if data.get("file_key"):
            source = db.query(db_models.Source).filter(
                db_models.Source.s3_key == data["file_key"],
                db_models.Source.user_id == current_user.id
            ).first()
            if not source:
                await manager.send_error(connection_id, "Not authorized to access this file")
                return
            source_id = source.id

        await manager.send_progress(connection_id, 10, "processing", "Generating quiz...")

        quizzes_data = await quiz_service.generate_quiz(
            file_key=data.get("file_key"),
            text_input=data.get("text_input"),
            difficulty=data.get("difficulty", "medium"),
            topic=data.get("topic"),
            language=data.get("language", "English"),
            count_mode=data.get("count", "STANDARD"),
            progress_callback=report_progress
        )

        if not quizzes_data:
            await manager.send_error(connection_id, "Failed to generate quiz")
            return

        db_set = db_models.QuizSet(
            # `or` (not a .get default) so an explicit null/empty topic still
            # falls back to the generic title.
            title=data.get("topic") or "Quiz",
            difficulty=data.get("difficulty", "medium"),
            user_id=current_user.id,
            source_id=source_id
        )
        db.add(db_set)
        # Flush to obtain db_set.id without committing, so the set and its
        # questions are saved (or rolled back) together.
        db.flush()

        for item in quizzes_data:
            db.add(db_models.QuizQuestion(
                quiz_set_id=db_set.id,
                question=item.get("question", ""),
                choices=item.get("choices", {}),
                answer=item.get("answer", "1"),
                explanation=item.get("explanation", "")
            ))

        db.commit()
        await manager.send_result(connection_id, {"id": db_set.id, "title": db_set.title})
    except Exception as e:
        logger.error(f"Quiz task failed: {e}")
        try:
            db.rollback()  # do not leave a half-written quiz set behind
        except Exception as db_exc:
            logger.error(f"Rollback after quiz failure failed: {db_exc}")
        await manager.send_error(connection_id, str(e))
354
+
main.py CHANGED
@@ -1,7 +1,7 @@
1
  from fastapi import FastAPI
2
  from fastapi.middleware.cors import CORSMiddleware
3
  from core.database import init_db
4
- from api import auth, sources, podcast, flashcards, mindmaps, quizzes, reports, video_generator, rag, chat
5
 
6
  # Initialize Database Tables
7
  init_db()
@@ -32,6 +32,7 @@ app.include_router(reports.router)
32
  app.include_router(video_generator.router)
33
  app.include_router(rag.router)
34
  app.include_router(chat.router)
 
35
 
36
  @app.get("/")
37
  async def root():
 
1
  from fastapi import FastAPI
2
  from fastapi.middleware.cors import CORSMiddleware
3
  from core.database import init_db
4
+ from api import auth, sources, podcast, flashcards, mindmaps, quizzes, reports, video_generator, rag, chat, websocket_routes
5
 
6
  # Initialize Database Tables
7
  init_db()
 
32
  app.include_router(video_generator.router)
33
  app.include_router(rag.router)
34
  app.include_router(chat.router)
35
+ app.include_router(websocket_routes.router) # WebSocket endpoints for real-time progress
36
 
37
  @app.get("/")
38
  async def root():
models/db_models.py CHANGED
@@ -47,9 +47,11 @@ class Podcast(Base):
47
 
48
  id = Column(Integer, primary_key=True, index=True)
49
  title = Column(Unicode(255))
50
- s3_key = Column(String(512), nullable=False)
51
- s3_url = Column(String(1024), nullable=False)
52
- script = Column(UnicodeText)
 
 
53
  user_id = Column(Integer, ForeignKey("users.id"))
54
  source_id = Column(Integer, ForeignKey("sources.id"), nullable=True)
55
  created_at = Column(DateTime(timezone=True), server_default=func.now())
@@ -57,6 +59,11 @@ class Podcast(Base):
57
  owner = relationship("User", back_populates="podcasts")
58
  source = relationship("Source", back_populates="podcasts")
59
 
 
 
 
 
 
60
  class FlashcardSet(Base):
61
  __tablename__ = "flashcard_sets"
62
 
 
47
 
48
  id = Column(Integer, primary_key=True, index=True)
49
  title = Column(Unicode(255))
50
+ s3_key = Column(String(512), nullable=True)
51
+ s3_url = Column(String(1024), nullable=True)
52
+ script = Column(UnicodeText, nullable=True)
53
+ status = Column(String(50), default="processing") # pending, processing, completed, failed
54
+ error_message = Column(UnicodeText, nullable=True)
55
  user_id = Column(Integer, ForeignKey("users.id"))
56
  source_id = Column(Integer, ForeignKey("sources.id"), nullable=True)
57
  created_at = Column(DateTime(timezone=True), server_default=func.now())
 
59
  owner = relationship("User", back_populates="podcasts")
60
  source = relationship("Source", back_populates="podcasts")
61
 
62
+ @property
63
+ def parent_file_key(self):
64
+ return self.source.s3_key if self.source else None
65
+
66
+
67
  class FlashcardSet(Base):
68
  __tablename__ = "flashcard_sets"
69
 
models/schemas.py CHANGED
@@ -44,16 +44,11 @@ class SourceFileResponse(BaseModel):
44
  from_attributes = True
45
 
46
  # Podcast Schemas
47
- class PodcastAnalyzeRequest(BaseModel):
48
- file_key: str
49
- duration_minutes: int = 10
50
-
51
  class PodcastGenerateRequest(BaseModel):
52
  user_prompt: str
53
  model: str = "gpt-4o"
54
  duration_minutes: int = 10
55
  podcast_format: str = "deep dive"
56
- pdf_suggestions: str = ""
57
  file_key: Optional[str] = None
58
  tts_model: str = "gemini-2.5-flash-preview-tts"
59
  spk1_voice: str = "Zephyr"
@@ -61,6 +56,22 @@ class PodcastGenerateRequest(BaseModel):
61
  bgm_choice: str = "No BGM"
62
  temperature: float = 1.0
63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  # Flashcard Schemas
65
  class FlashcardItem(BaseModel):
66
  question: str
 
44
  from_attributes = True
45
 
46
  # Podcast Schemas
 
 
 
 
47
  class PodcastGenerateRequest(BaseModel):
48
  user_prompt: str
49
  model: str = "gpt-4o"
50
  duration_minutes: int = 10
51
  podcast_format: str = "deep dive"
 
52
  file_key: Optional[str] = None
53
  tts_model: str = "gemini-2.5-flash-preview-tts"
54
  spk1_voice: str = "Zephyr"
 
56
  bgm_choice: str = "No BGM"
57
  temperature: float = 1.0
58
 
59
class PodcastResponse(BaseModel):
    """API representation of a podcast row, including its generation status."""

    id: int
    title: Optional[str]
    # Storage location, public URL and script are optional because a podcast
    # row is created before generation has produced them.
    s3_key: Optional[str]
    s3_url: Optional[str]
    script: Optional[str]
    # Used only when the source object supplies no status value.
    status: str = "completed"
    error_message: Optional[str]
    # NOTE(review): no `parent_file_id` attribute is visible on the Podcast ORM
    # model in this change (its column is `source_id`), so this presumably
    # stays None when built from ORM rows — confirm.
    parent_file_id: Optional[int] = None
    # Filled from the ORM model's `parent_file_key` property (the S3 key of
    # the linked source, or None when there is no source).
    parent_file_key: Optional[str] = None
    created_at: datetime

    class Config:
        # Allow building the schema from ORM objects via attribute access.
        from_attributes = True
73
+
74
+
75
  # Flashcard Schemas
76
  class FlashcardItem(BaseModel):
77
  question: str
services/flashcard_service.py CHANGED
@@ -1,8 +1,9 @@
1
  import json
2
  import logging
3
  import os
 
4
  import tempfile
5
- from typing import List, Dict, Optional, Any
6
  import openai
7
  from botocore.exceptions import ClientError
8
 
@@ -23,34 +24,57 @@ class FlashcardService:
23
  difficulty: str = "medium",
24
  quantity: str = "standard",
25
  topic: Optional[str] = None,
26
- language: str = "English"
 
27
  ) -> List[Dict[str, str]]:
28
  """
29
- Generates flashcards from either an S3 PDF or direct text input (Original File-ID Method).
 
 
 
 
30
  """
31
  try:
 
 
 
32
  system_prompt = get_flashcard_system_prompt(difficulty, quantity, language)
33
  if topic:
34
  system_prompt += get_flashcard_topic_prompt(topic)
35
 
36
  if file_key:
37
- # Download PDF from S3
 
 
 
38
  tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".pdf")
39
  tmp_path = tmp.name
40
  tmp.close()
41
 
42
  try:
43
- s3_service.s3_client.download_file(
 
 
44
  settings.AWS_S3_BUCKET,
45
  file_key,
46
  tmp_path
47
  )
48
 
49
- with open(tmp_path, "rb") as f:
50
- uploaded_file = self.openai_client.files.create(
51
- file=f,
52
- purpose="assistants"
53
- )
 
 
 
 
 
 
 
 
 
 
54
 
55
  messages = [
56
  {"role": "system", "content": system_prompt},
@@ -65,26 +89,41 @@ class FlashcardService:
65
  }
66
  ]
67
 
68
- response = self.openai_client.chat.completions.create(
 
 
69
  model="gpt-4o-mini",
70
  messages=messages,
71
  temperature=0.7
72
  )
73
 
74
- # Clean up OpenAI file
75
- self.openai_client.files.delete(uploaded_file.id)
 
 
 
 
 
 
76
  raw_content = response.choices[0].message.content
77
 
78
  finally:
 
79
  if os.path.exists(tmp_path):
80
- os.remove(tmp_path)
81
 
82
  elif text_input:
 
 
 
83
  messages = [
84
  {"role": "system", "content": system_prompt},
85
  {"role": "user", "content": text_input}
86
  ]
87
- response = self.openai_client.chat.completions.create(
 
 
 
88
  model="gpt-4o-mini",
89
  messages=messages,
90
  temperature=0.7
@@ -94,6 +133,9 @@ class FlashcardService:
94
  else:
95
  raise ValueError("Either file_key or text_input must be provided")
96
 
 
 
 
97
  # Parse JSON
98
  if "```json" in raw_content:
99
  raw_content = raw_content.split("```json")[1].split("```")[0].strip()
@@ -109,6 +151,7 @@ class FlashcardService:
109
  async def generate_explanation(self, question: str, file_key: Optional[str] = None, language: str = "English") -> str:
110
  """
111
  Generates a detailed explanation for a flashcard question.
 
112
  """
113
  try:
114
  explanation_prompt = get_flashcard_explanation_prompt(question, language)
@@ -119,33 +162,53 @@ class FlashcardService:
119
  tmp.close()
120
 
121
  try:
122
- s3_service.s3_client.download_file(
 
 
123
  settings.AWS_S3_BUCKET,
124
  file_key,
125
  tmp_path
126
  )
127
- with open(tmp_path, "rb") as f:
128
- uploaded_file = self.openai_client.files.create(file=f, purpose="assistants")
 
 
 
 
 
129
 
130
  messages = [
131
  {"role": "system", "content": explanation_prompt},
132
  {"role": "user", "content": [{"type": "file", "file": {"file_id": uploaded_file.id}}]}
133
  ]
134
- response = self.openai_client.chat.completions.create(
 
 
 
135
  model="gpt-4o-mini",
136
  messages=messages
137
  )
138
- self.openai_client.files.delete(uploaded_file.id)
 
 
 
 
 
 
139
  return response.choices[0].message.content
140
  finally:
 
141
  if os.path.exists(tmp_path):
142
- os.remove(tmp_path)
143
  else:
144
  messages = [
145
  {"role": "system", "content": explanation_prompt},
146
  {"role": "user", "content": f"Please explain the question: {question}"}
147
  ]
148
- response = self.openai_client.chat.completions.create(
 
 
 
149
  model="gpt-4o-mini",
150
  messages=messages
151
  )
 
1
  import json
2
  import logging
3
  import os
4
+ import asyncio
5
  import tempfile
6
+ from typing import List, Dict, Optional, Any, Callable
7
  import openai
8
  from botocore.exceptions import ClientError
9
 
 
24
  difficulty: str = "medium",
25
  quantity: str = "standard",
26
  topic: Optional[str] = None,
27
+ language: str = "English",
28
+ progress_callback: Optional[Callable[[int, str], None]] = None
29
  ) -> List[Dict[str, str]]:
30
  """
31
+ Generates flashcards from either an S3 PDF or direct text input.
32
+ Uses asyncio.to_thread for all blocking I/O operations to enable parallel execution.
33
+
34
+ Args:
35
+ progress_callback: Optional callback function(progress: int, message: str) for progress updates
36
  """
37
  try:
38
+ if progress_callback:
39
+ progress_callback(5, "Preparing prompts...")
40
+
41
  system_prompt = get_flashcard_system_prompt(difficulty, quantity, language)
42
  if topic:
43
  system_prompt += get_flashcard_topic_prompt(topic)
44
 
45
  if file_key:
46
+ if progress_callback:
47
+ progress_callback(15, "Downloading file from S3...")
48
+
49
+ # Download PDF from S3 (non-blocking)
50
  tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".pdf")
51
  tmp_path = tmp.name
52
  tmp.close()
53
 
54
  try:
55
+ # Use asyncio.to_thread for S3 download
56
+ await asyncio.to_thread(
57
+ s3_service.s3_client.download_file,
58
  settings.AWS_S3_BUCKET,
59
  file_key,
60
  tmp_path
61
  )
62
 
63
+ if progress_callback:
64
+ progress_callback(30, "Uploading to OpenAI...")
65
+
66
+ # Read file and upload to OpenAI (non-blocking)
67
+ def upload_to_openai():
68
+ with open(tmp_path, "rb") as f:
69
+ return self.openai_client.files.create(
70
+ file=f,
71
+ purpose="assistants"
72
+ )
73
+
74
+ uploaded_file = await asyncio.to_thread(upload_to_openai)
75
+
76
+ if progress_callback:
77
+ progress_callback(45, "Generating flashcards with AI...")
78
 
79
  messages = [
80
  {"role": "system", "content": system_prompt},
 
89
  }
90
  ]
91
 
92
+ # Call OpenAI API (non-blocking)
93
+ response = await asyncio.to_thread(
94
+ self.openai_client.chat.completions.create,
95
  model="gpt-4o-mini",
96
  messages=messages,
97
  temperature=0.7
98
  )
99
 
100
+ if progress_callback:
101
+ progress_callback(75, "Cleaning up...")
102
+
103
+ # Clean up OpenAI file (non-blocking)
104
+ await asyncio.to_thread(
105
+ self.openai_client.files.delete,
106
+ uploaded_file.id
107
+ )
108
  raw_content = response.choices[0].message.content
109
 
110
  finally:
111
+ # Remove temp file (non-blocking)
112
  if os.path.exists(tmp_path):
113
+ await asyncio.to_thread(os.remove, tmp_path)
114
 
115
  elif text_input:
116
+ if progress_callback:
117
+ progress_callback(20, "Generating flashcards with AI...")
118
+
119
  messages = [
120
  {"role": "system", "content": system_prompt},
121
  {"role": "user", "content": text_input}
122
  ]
123
+
124
+ # Call OpenAI API (non-blocking)
125
+ response = await asyncio.to_thread(
126
+ self.openai_client.chat.completions.create,
127
  model="gpt-4o-mini",
128
  messages=messages,
129
  temperature=0.7
 
133
  else:
134
  raise ValueError("Either file_key or text_input must be provided")
135
 
136
+ if progress_callback:
137
+ progress_callback(85, "Parsing results...")
138
+
139
  # Parse JSON
140
  if "```json" in raw_content:
141
  raw_content = raw_content.split("```json")[1].split("```")[0].strip()
 
151
  async def generate_explanation(self, question: str, file_key: Optional[str] = None, language: str = "English") -> str:
152
  """
153
  Generates a detailed explanation for a flashcard question.
154
+ Uses asyncio.to_thread for all blocking I/O operations.
155
  """
156
  try:
157
  explanation_prompt = get_flashcard_explanation_prompt(question, language)
 
162
  tmp.close()
163
 
164
  try:
165
+ # Download from S3 (non-blocking)
166
+ await asyncio.to_thread(
167
+ s3_service.s3_client.download_file,
168
  settings.AWS_S3_BUCKET,
169
  file_key,
170
  tmp_path
171
  )
172
+
173
+ # Upload to OpenAI (non-blocking)
174
+ def upload_to_openai():
175
+ with open(tmp_path, "rb") as f:
176
+ return self.openai_client.files.create(file=f, purpose="assistants")
177
+
178
+ uploaded_file = await asyncio.to_thread(upload_to_openai)
179
 
180
  messages = [
181
  {"role": "system", "content": explanation_prompt},
182
  {"role": "user", "content": [{"type": "file", "file": {"file_id": uploaded_file.id}}]}
183
  ]
184
+
185
+ # Call OpenAI API (non-blocking)
186
+ response = await asyncio.to_thread(
187
+ self.openai_client.chat.completions.create,
188
  model="gpt-4o-mini",
189
  messages=messages
190
  )
191
+
192
+ # Clean up OpenAI file (non-blocking)
193
+ await asyncio.to_thread(
194
+ self.openai_client.files.delete,
195
+ uploaded_file.id
196
+ )
197
+
198
  return response.choices[0].message.content
199
  finally:
200
+ # Remove temp file (non-blocking)
201
  if os.path.exists(tmp_path):
202
+ await asyncio.to_thread(os.remove, tmp_path)
203
  else:
204
  messages = [
205
  {"role": "system", "content": explanation_prompt},
206
  {"role": "user", "content": f"Please explain the question: {question}"}
207
  ]
208
+
209
+ # Call OpenAI API (non-blocking)
210
+ response = await asyncio.to_thread(
211
+ self.openai_client.chat.completions.create,
212
  model="gpt-4o-mini",
213
  messages=messages
214
  )
services/mindmap_service.py CHANGED
@@ -1,5 +1,6 @@
1
  import logging
2
  import os
 
3
  import tempfile
4
  from typing import Optional
5
  import openai
@@ -20,28 +21,34 @@ class MindMapService:
20
  ) -> str:
21
  """
22
  Generates a Mermaid mindmap from either an S3 PDF or direct text input.
 
23
  """
24
  try:
25
  system_prompt = get_mindmap_system_prompt()
26
 
27
  if file_key:
28
- # Download PDF from S3
29
  tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".pdf")
30
  tmp_path = tmp.name
31
  tmp.close()
32
 
33
  try:
34
- s3_service.s3_client.download_file(
 
35
  settings.AWS_S3_BUCKET,
36
  file_key,
37
  tmp_path
38
  )
39
 
40
- with open(tmp_path, "rb") as f:
41
- uploaded_file = self.openai_client.files.create(
42
- file=f,
43
- purpose="assistants"
44
- )
 
 
 
 
45
 
46
  messages = [
47
  {"role": "system", "content": system_prompt},
@@ -56,27 +63,35 @@ class MindMapService:
56
  }
57
  ]
58
 
59
- response = self.openai_client.chat.completions.create(
 
 
60
  model="gpt-4o-mini",
61
  messages=messages,
62
  temperature=0.7
63
  )
64
 
65
- # Clean up OpenAI file
66
- self.openai_client.files.delete(uploaded_file.id)
 
 
 
67
 
68
  raw_content = response.choices[0].message.content
69
 
70
  finally:
71
  if os.path.exists(tmp_path):
72
- os.remove(tmp_path)
73
 
74
  elif text_input:
75
  messages = [
76
  {"role": "system", "content": system_prompt},
77
  {"role": "user", "content": text_input}
78
  ]
79
- response = self.openai_client.chat.completions.create(
 
 
 
80
  model="gpt-4o-mini",
81
  messages=messages,
82
  temperature=0.7
@@ -92,12 +107,6 @@ class MindMapService:
92
  elif "```" in raw_content:
93
  raw_content = raw_content.split("```")[1].split("```")[0].strip()
94
 
95
- # Ensure it starts with 'mindmap'
96
- if "mindmap" not in raw_content.lower():
97
- # If the AI missed the header, we might need to handle it,
98
- # but usually the prompt is strong.
99
- pass
100
-
101
  return raw_content.strip()
102
 
103
  except Exception as e:
 
1
  import logging
2
  import os
3
+ import asyncio
4
  import tempfile
5
  from typing import Optional
6
  import openai
 
21
  ) -> str:
22
  """
23
  Generates a Mermaid mindmap from either an S3 PDF or direct text input.
24
+ Uses asyncio.to_thread for all blocking I/O operations.
25
  """
26
  try:
27
  system_prompt = get_mindmap_system_prompt()
28
 
29
  if file_key:
30
+ # Download PDF from S3 (non-blocking)
31
  tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".pdf")
32
  tmp_path = tmp.name
33
  tmp.close()
34
 
35
  try:
36
+ await asyncio.to_thread(
37
+ s3_service.s3_client.download_file,
38
  settings.AWS_S3_BUCKET,
39
  file_key,
40
  tmp_path
41
  )
42
 
43
+ # Upload to OpenAI (non-blocking)
44
+ def upload_to_openai():
45
+ with open(tmp_path, "rb") as f:
46
+ return self.openai_client.files.create(
47
+ file=f,
48
+ purpose="assistants"
49
+ )
50
+
51
+ uploaded_file = await asyncio.to_thread(upload_to_openai)
52
 
53
  messages = [
54
  {"role": "system", "content": system_prompt},
 
63
  }
64
  ]
65
 
66
+ # Call OpenAI (non-blocking)
67
+ response = await asyncio.to_thread(
68
+ self.openai_client.chat.completions.create,
69
  model="gpt-4o-mini",
70
  messages=messages,
71
  temperature=0.7
72
  )
73
 
74
+ # Clean up OpenAI file (non-blocking)
75
+ await asyncio.to_thread(
76
+ self.openai_client.files.delete,
77
+ uploaded_file.id
78
+ )
79
 
80
  raw_content = response.choices[0].message.content
81
 
82
  finally:
83
  if os.path.exists(tmp_path):
84
+ await asyncio.to_thread(os.remove, tmp_path)
85
 
86
  elif text_input:
87
  messages = [
88
  {"role": "system", "content": system_prompt},
89
  {"role": "user", "content": text_input}
90
  ]
91
+
92
+ # Call OpenAI (non-blocking)
93
+ response = await asyncio.to_thread(
94
+ self.openai_client.chat.completions.create,
95
  model="gpt-4o-mini",
96
  messages=messages,
97
  temperature=0.7
 
107
  elif "```" in raw_content:
108
  raw_content = raw_content.split("```")[1].split("```")[0].strip()
109
 
 
 
 
 
 
 
110
  return raw_content.strip()
111
 
112
  except Exception as e:
services/podcast_service.py CHANGED
@@ -3,6 +3,7 @@ import os
3
  import json
4
  import time
5
  import struct
 
6
  import logging
7
  import mimetypes
8
  from datetime import datetime
@@ -50,26 +51,35 @@ class PodcastService:
50
  elif duration_minutes <= 15: return 4000
51
  else: return 5000
52
 
53
- async def analyze_pdf(self, file_key: str, duration_minutes: int, model: str = "gpt-4o"):
54
  # 1. Get file from S3
55
  # Since openai files.create needs a file, we download it temporarily
56
  temp_path = f"temp_{int(time.time())}.pdf"
57
  try:
58
  import boto3
59
- s3 = boto3.client('s3',
60
- aws_access_key_id=settings.AWS_ACCESS_KEY_ID,
61
- aws_secret_access_key=settings.AWS_SECRET_ACCESS_KEY,
62
- region_name=settings.AWS_REGION)
63
- s3.download_file(settings.AWS_S3_BUCKET, file_key, temp_path)
 
 
 
 
 
64
 
65
- # 2. Upload to OpenAI
66
- with open(temp_path, "rb") as f:
67
- file_response = self.openai_client.files.create(file=f, purpose="assistants")
 
68
 
69
- # 3. Analyze
 
 
70
  formatted_prompt = ANALYSIS_PROMPT.format(duration_minutes=duration_minutes)
71
 
72
- response = self.openai_client.chat.completions.parse(
 
73
  model=model,
74
  messages=[
75
  {"role": "system", "content": formatted_prompt},
@@ -81,10 +91,11 @@ class PodcastService:
81
  return response.choices[0].message.content
82
  finally:
83
  if os.path.exists(temp_path):
84
- os.remove(temp_path)
 
85
 
86
  async def generate_script(self, user_prompt: str, model: str, duration_minutes: int,
87
- podcast_format: str, pdf_suggestions: str, file_key: Optional[str] = None):
88
  target_words = self.compute_script_targets(duration_minutes)
89
  formatted_system = SYSTEM_PROMPT.format(
90
  target_words=target_words,
@@ -97,15 +108,24 @@ class PodcastService:
97
  temp_path = None
98
  if file_key:
99
  temp_path = f"temp_gen_{int(time.time())}.pdf"
100
- import boto3
101
- s3 = boto3.client('s3',
102
- aws_access_key_id=settings.AWS_ACCESS_KEY_ID,
103
- aws_secret_access_key=settings.AWS_SECRET_ACCESS_KEY,
104
- region_name=settings.AWS_REGION)
105
- s3.download_file(settings.AWS_S3_BUCKET, file_key, temp_path)
106
 
107
- with open(temp_path, "rb") as f:
108
- file_response = self.openai_client.files.create(file=f, purpose="assistants")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
 
110
  messages.append({
111
  "role": "user",
@@ -118,7 +138,9 @@ class PodcastService:
118
  messages.append({"role": "user", "content": user_prompt})
119
 
120
  try:
121
- response = self.openai_client.chat.completions.create(
 
 
122
  model=model,
123
  messages=messages,
124
  temperature=1.0,
@@ -127,14 +149,37 @@ class PodcastService:
127
  return response.choices[0].message.content
128
  finally:
129
  if temp_path and os.path.exists(temp_path):
130
- os.remove(temp_path)
 
131
 
132
  def parse_script(self, script: str) -> List[Tuple[str, str]]:
133
  dialogs = []
134
- pattern = re.compile(r"^(Speaker [12])[::]\s*(.*)$", re.MULTILINE)
135
- for match in pattern.finditer(script):
136
- speaker, text = match.groups()
137
- dialogs.append((speaker, text))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
  return dialogs
139
 
140
  def split_script(self, dialogs: List[Tuple[str, str]], chunk_size=20) -> List[str]:
@@ -144,9 +189,12 @@ class PodcastService:
144
  chunks.append("\n".join([f"{s}: {t}" for s, t in chunk]))
145
  return chunks
146
 
147
- def generate_audio_chunk(self, chunk_script: str, tts_model: str, spk1_voice: str,
148
  spk2_voice: str, temperature: float, index: int) -> Optional[str]:
149
  try:
 
 
 
150
  contents = [types.Content(role="user", parts=[types.Part.from_text(text=chunk_script)])]
151
  config = types.GenerateContentConfig(
152
  temperature=temperature,
@@ -163,27 +211,37 @@ class PodcastService:
163
  )
164
  )
165
 
 
166
  audio_data = None
167
  mime_type = "audio/wav"
168
- for chunk in self.genai_client.models.generate_content_stream(model=tts_model, contents=contents, config=config):
 
 
169
  if chunk.candidates and chunk.candidates[0].content.parts:
170
  part = chunk.candidates[0].content.parts[0]
171
  if part.inline_data:
172
  audio_data = part.inline_data.data
173
  mime_type = part.inline_data.mime_type
 
174
  break
175
 
176
  if audio_data:
177
  # Basic WAV conversion if needed (simplified from original)
178
  if "wav" not in mime_type.lower():
 
 
179
  # We usually get raw PCM or similar, need header
180
  audio_data = self._convert_to_wav(audio_data, mime_type)
181
 
182
  path = f"chunk_{index}_{int(time.time())}.wav"
183
  with open(path, "wb") as f:
184
  f.write(audio_data)
 
185
  return path
 
 
186
  except Exception as e:
 
187
  logger.error(f"Error generating chunk {index}: {e}")
188
  return None
189
 
@@ -201,49 +259,72 @@ class PodcastService:
201
  return header + audio_data
202
 
203
  async def generate_full_audio(self, script: str, tts_model: str, spk1_voice: str,
204
- spk2_voice: str, temperature: float, bgm_choice: str):
 
 
205
  dialogs = self.parse_script(script)
 
 
206
  chunks = self.split_script(dialogs)
 
207
 
208
- chunk_paths = [None] * len(chunks)
209
- with ThreadPoolExecutor(max_workers=4) as executor:
210
- futures = {executor.submit(self.generate_audio_chunk, chunks[i], tts_model, spk1_voice, spk2_voice, temperature, i): i for i in range(len(chunks))}
211
- for future in as_completed(futures):
212
- idx = futures[future]
213
- chunk_paths[idx] = future.result()
 
 
214
 
215
- valid_paths = [p for p in chunk_paths if p]
216
- if not valid_paths: return None
217
 
218
- # Combine
219
- combined = AudioSegment.empty()
220
- for p in valid_paths:
221
- combined += AudioSegment.from_file(p)
222
- combined += AudioSegment.silent(duration=500)
223
- os.remove(p)
224
 
225
- final_path = f"final_podcast_{int(time.time())}.wav"
 
 
226
 
227
- # Mix BGM
228
- bgm_path = BGM_CHOICES.get(bgm_choice)
229
- if bgm_path and os.path.exists(bgm_path):
230
- bgm = AudioSegment.from_file(bgm_path)
231
- # Simple mix: loop BGM, fade in/out
232
- if len(bgm) < len(combined) + 10000:
233
- bgm = bgm * ( (len(combined) + 10000) // len(bgm) + 1 )
234
-
235
- bgm = bgm[:len(combined) + 10000]
236
- bgm_main = bgm[5000:5000+len(combined)] - 16
237
- bgm_intro = bgm[:5000]
238
- bgm_outro = bgm[5000+len(combined):].fade_out(5000) - 16
239
-
240
- bgm_processed = bgm_intro + bgm_main + bgm_outro
241
- combined_with_intro = AudioSegment.silent(duration=5000) + combined + AudioSegment.silent(duration=5000)
242
- final_audio = combined_with_intro.overlay(bgm_processed)
243
- final_audio.export(final_path, format="wav")
244
- else:
245
- combined.export(final_path, format="wav")
246
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
247
  return final_path
248
 
 
 
249
  podcast_service = PodcastService()
 
3
  import json
4
  import time
5
  import struct
6
+ import asyncio
7
  import logging
8
  import mimetypes
9
  from datetime import datetime
 
51
  elif duration_minutes <= 15: return 4000
52
  else: return 5000
53
 
54
+ async def analyze_pdf(self, file_key: str, duration_minutes: int, model: str = "gpt-4o", progress_callback=None):
55
  # 1. Get file from S3
56
  # Since openai files.create needs a file, we download it temporarily
57
  temp_path = f"temp_{int(time.time())}.pdf"
58
  try:
59
  import boto3
60
+
61
+ # Create S3 client and download (non-blocking)
62
+ def download_from_s3():
63
+ s3 = boto3.client('s3',
64
+ aws_access_key_id=settings.AWS_ACCESS_KEY_ID,
65
+ aws_secret_access_key=settings.AWS_SECRET_ACCESS_KEY,
66
+ region_name=settings.AWS_REGION)
67
+ s3.download_file(settings.AWS_S3_BUCKET, file_key, temp_path)
68
+
69
+ await asyncio.to_thread(download_from_s3)
70
 
71
+ # 2. Upload to OpenAI (non-blocking)
72
+ def upload_to_openai():
73
+ with open(temp_path, "rb") as f:
74
+ return self.openai_client.files.create(file=f, purpose="assistants")
75
 
76
+ file_response = await asyncio.to_thread(upload_to_openai)
77
+
78
+ # 3. Analyze (non-blocking)
79
  formatted_prompt = ANALYSIS_PROMPT.format(duration_minutes=duration_minutes)
80
 
81
+ response = await asyncio.to_thread(
82
+ self.openai_client.chat.completions.parse,
83
  model=model,
84
  messages=[
85
  {"role": "system", "content": formatted_prompt},
 
91
  return response.choices[0].message.content
92
  finally:
93
  if os.path.exists(temp_path):
94
+ await asyncio.to_thread(os.remove, temp_path)
95
+
96
 
97
  async def generate_script(self, user_prompt: str, model: str, duration_minutes: int,
98
+ podcast_format: str, pdf_suggestions: str, file_key: Optional[str] = None, progress_callback=None):
99
  target_words = self.compute_script_targets(duration_minutes)
100
  formatted_system = SYSTEM_PROMPT.format(
101
  target_words=target_words,
 
108
  temp_path = None
109
  if file_key:
110
  temp_path = f"temp_gen_{int(time.time())}.pdf"
 
 
 
 
 
 
111
 
112
+ # Download from S3 (non-blocking)
113
+ def download_from_s3():
114
+ import boto3
115
+ s3 = boto3.client('s3',
116
+ aws_access_key_id=settings.AWS_ACCESS_KEY_ID,
117
+ aws_secret_access_key=settings.AWS_SECRET_ACCESS_KEY,
118
+ region_name=settings.AWS_REGION)
119
+ s3.download_file(settings.AWS_S3_BUCKET, file_key, temp_path)
120
+
121
+ await asyncio.to_thread(download_from_s3)
122
+
123
+ # Upload to OpenAI (non-blocking)
124
+ def upload_to_openai():
125
+ with open(temp_path, "rb") as f:
126
+ return self.openai_client.files.create(file=f, purpose="assistants")
127
+
128
+ file_response = await asyncio.to_thread(upload_to_openai)
129
 
130
  messages.append({
131
  "role": "user",
 
138
  messages.append({"role": "user", "content": user_prompt})
139
 
140
  try:
141
+ # Call OpenAI API (non-blocking)
142
+ response = await asyncio.to_thread(
143
+ self.openai_client.chat.completions.create,
144
  model=model,
145
  messages=messages,
146
  temperature=1.0,
 
149
  return response.choices[0].message.content
150
  finally:
151
  if temp_path and os.path.exists(temp_path):
152
+ await asyncio.to_thread(os.remove, temp_path)
153
+
154
 
155
  def parse_script(self, script: str) -> List[Tuple[str, str]]:
156
  dialogs = []
157
+
158
+ # Try English format: "Speaker 1:" or "**Speaker 1**:"
159
+ pattern_en = re.compile(r"^\*?\*?(Speaker [12])\*?\*?[::]\s*(.*)$", re.MULTILINE)
160
+ matches = list(pattern_en.finditer(script))
161
+
162
+ if matches:
163
+ print(f"[DEBUG] Found {len(matches)} English patterns")
164
+ for match in matches:
165
+ speaker, text = match.groups()
166
+ dialogs.append((speaker, text))
167
+ else:
168
+ # Try Japanese format: "スピーカー1:"
169
+ pattern_jp = re.compile(r"^\*?\*?(スピーカー[12])\*?\*?[::]\s*(.*)$", re.MULTILINE)
170
+ matches = list(pattern_jp.finditer(script))
171
+
172
+ if matches:
173
+ print(f"[DEBUG] Found {len(matches)} Japanese patterns")
174
+ for match in matches:
175
+ speaker_jp, text = match.groups()
176
+ speaker_num = "1" if "1" in speaker_jp else "2"
177
+ speaker = f"Speaker {speaker_num}"
178
+ dialogs.append((speaker, text))
179
+ else:
180
+ print(f"[ERROR] No patterns found!")
181
+ print(f"[DEBUG] Preview: {script[:300]}")
182
+
183
  return dialogs
184
 
185
  def split_script(self, dialogs: List[Tuple[str, str]], chunk_size=20) -> List[str]:
 
189
  chunks.append("\n".join([f"{s}: {t}" for s, t in chunk]))
190
  return chunks
191
 
192
+ async def generate_audio_chunk(self, chunk_script: str, tts_model: str, spk1_voice: str,
193
  spk2_voice: str, temperature: float, index: int) -> Optional[str]:
194
  try:
195
+ print(f"[DEBUG] Chunk {index}: Starting generation")
196
+ print(f"[DEBUG] Chunk {index}: Script length: {len(chunk_script)} chars")
197
+
198
  contents = [types.Content(role="user", parts=[types.Part.from_text(text=chunk_script)])]
199
  config = types.GenerateContentConfig(
200
  temperature=temperature,
 
211
  )
212
  )
213
 
214
+ print(f"[DEBUG] Chunk {index}: Calling Gemini API (Async)...")
215
  audio_data = None
216
  mime_type = "audio/wav"
217
+
218
+ # Use client.aio for non-blocking network I/O
219
+ async for chunk in await self.genai_client.aio.models.generate_content_stream(model=tts_model, contents=contents, config=config):
220
  if chunk.candidates and chunk.candidates[0].content.parts:
221
  part = chunk.candidates[0].content.parts[0]
222
  if part.inline_data:
223
  audio_data = part.inline_data.data
224
  mime_type = part.inline_data.mime_type
225
+ print(f"[DEBUG] Chunk {index}: Received audio data, mime: {mime_type}")
226
  break
227
 
228
  if audio_data:
229
  # Basic WAV conversion if needed (simplified from original)
230
  if "wav" not in mime_type.lower():
231
+ print(f"[DEBUG] Chunk {index}: Converting to WAV")
232
+
233
  # We usually get raw PCM or similar, need header
234
  audio_data = self._convert_to_wav(audio_data, mime_type)
235
 
236
  path = f"chunk_{index}_{int(time.time())}.wav"
237
  with open(path, "wb") as f:
238
  f.write(audio_data)
239
+ print(f"[DEBUG] Chunk {index}: Saved to {path}")
240
  return path
241
+ else:
242
+ print(f"[ERROR] Chunk {index}: No audio data received from Gemini")
243
  except Exception as e:
244
+ print(f"[ERROR] Chunk {index}: Exception: {e}")
245
  logger.error(f"Error generating chunk {index}: {e}")
246
  return None
247
 
 
259
  return header + audio_data
260
 
261
  async def generate_full_audio(self, script: str, tts_model: str, spk1_voice: str,
262
+ spk2_voice: str, temperature: float, bgm_choice: str, progress_callback=None):
263
+ print(f"[DEBUG] Starting generate_full_audio")
264
+
265
  dialogs = self.parse_script(script)
266
+ print(f"[DEBUG] Parsed {len(dialogs)} dialogs")
267
+
268
  chunks = self.split_script(dialogs)
269
+ print(f"[DEBUG] Split into {len(chunks)} chunks")
270
 
271
+ # Run chunks in parallel using asyncio.gather
272
+ print(f"[DEBUG] Starting parallel chunk generation...")
273
+ tasks = []
274
+ for i, chunk_script in enumerate(chunks):
275
+ # Now calling the async method directly
276
+ tasks.append(self.generate_audio_chunk(
277
+ chunk_script, tts_model, spk1_voice, spk2_voice, temperature, i
278
+ ))
279
 
280
+ chunk_paths = await asyncio.gather(*tasks)
281
+
282
 
283
+ valid_paths = [p for p in chunk_paths if p]
284
+ print(f"[DEBUG] Valid chunks: {len(valid_paths)} out of {len(chunk_paths)}")
 
 
 
 
285
 
286
+ if not valid_paths:
287
+ print(f"[ERROR] No valid audio chunks generated!")
288
+ return None
289
 
290
+ # Combine - This is heavy processing, run in thread
291
+ def combine_audio():
292
+ print(f"[DEBUG] Starting audio combination in thread")
293
+ combined = AudioSegment.empty()
294
+ for i, p in enumerate(valid_paths):
295
+ combined += AudioSegment.from_file(p)
296
+ combined += AudioSegment.silent(duration=500)
297
+ try: os.remove(p)
298
+ except: pass
 
 
 
 
 
 
 
 
 
 
299
 
300
+ final_path = f"final_podcast_{int(time.time())}.wav"
301
+
302
+ # Mix BGM
303
+ bgm_path = BGM_CHOICES.get(bgm_choice)
304
+ if bgm_path and os.path.exists(bgm_path):
305
+ print(f"[DEBUG] Adding BGM: {bgm_choice}")
306
+ bgm = AudioSegment.from_file(bgm_path)
307
+ if len(bgm) < len(combined) + 10000:
308
+ bgm = bgm * ( (len(combined) + 10000) // len(bgm) + 1 )
309
+
310
+ bgm = bgm[:len(combined) + 10000]
311
+ bgm_main = bgm[5000:5000+len(combined)] - 16
312
+ bgm_intro = bgm[:5000]
313
+ bgm_outro = bgm[5000+len(combined):].fade_out(5000) - 16
314
+
315
+ bgm_processed = bgm_intro + bgm_main + bgm_outro
316
+ combined_with_intro = AudioSegment.silent(duration=5000) + combined + AudioSegment.silent(duration=5000)
317
+ final_audio = combined_with_intro.overlay(bgm_processed)
318
+ final_audio.export(final_path, format="wav")
319
+ else:
320
+ combined.export(final_path, format="wav")
321
+
322
+ return final_path
323
+
324
+ final_path = await asyncio.to_thread(combine_audio)
325
+ print(f"[DEBUG] Audio generation complete: {final_path}")
326
  return final_path
327
 
328
+
329
+
330
  podcast_service = PodcastService()
services/quiz_service.py CHANGED
@@ -1,8 +1,9 @@
1
  import json
2
  import logging
3
  import os
 
4
  import tempfile
5
- from typing import List, Dict, Optional, Any
6
  import openai
7
 
8
  from core.config import settings
@@ -22,12 +23,17 @@ class QuizService:
22
  difficulty: str = "medium",
23
  topic: Optional[str] = None,
24
  language: str = "English",
25
- count_mode: str = "STANDARD"
 
26
  ) -> List[Dict[str, Any]]:
27
  """
28
  Generates a quiz from either an S3 PDF or direct text input.
 
29
  """
30
  try:
 
 
 
31
  # Map count mode to actual numbers
32
  counts = {
33
  "FEWER": "5-10",
@@ -39,23 +45,37 @@ class QuizService:
39
  system_prompt = get_quiz_system_prompt(language).replace("{NUM_QUESTIONS}", num_range)
40
 
41
  if file_key:
42
- # Download PDF from S3
 
 
 
43
  tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".pdf")
44
  tmp_path = tmp.name
45
  tmp.close()
46
 
47
  try:
48
- s3_service.s3_client.download_file(
 
49
  settings.AWS_S3_BUCKET,
50
  file_key,
51
  tmp_path
52
  )
53
 
54
- with open(tmp_path, "rb") as f:
55
- uploaded_file = self.openai_client.files.create(
56
- file=f,
57
- purpose="assistants"
58
- )
 
 
 
 
 
 
 
 
 
 
59
 
60
  user_message = f"Analyze the PDF and create {num_range} questions. Difficulty: {difficulty}."
61
  if topic:
@@ -75,21 +95,33 @@ class QuizService:
75
  }
76
  ]
77
 
78
- response = self.openai_client.chat.completions.create(
 
 
79
  model="gpt-4o-mini",
80
  messages=messages,
81
  response_format={"type": "json_object"},
82
  temperature=0.7
83
  )
84
 
85
- self.openai_client.files.delete(uploaded_file.id)
 
 
 
 
 
 
 
86
  raw_content = response.choices[0].message.content
87
 
88
  finally:
89
  if os.path.exists(tmp_path):
90
- os.remove(tmp_path)
91
 
92
  elif text_input:
 
 
 
93
  user_message = f"Analyze the text and create {num_range} questions. Difficulty: {difficulty}."
94
  if topic:
95
  user_message += f" Topic: {topic}."
@@ -99,7 +131,10 @@ class QuizService:
99
  {"role": "system", "content": system_prompt},
100
  {"role": "user", "content": user_message}
101
  ]
102
- response = self.openai_client.chat.completions.create(
 
 
 
103
  model="gpt-4o-mini",
104
  messages=messages,
105
  response_format={"type": "json_object"},
@@ -110,6 +145,9 @@ class QuizService:
110
  else:
111
  raise ValueError("Either file_key or text_input must be provided")
112
 
 
 
 
113
  data = json.loads(raw_content)
114
  # The prompt asks for {"quizzes": [...]}
115
  return data.get("quizzes", [])
 
1
  import json
2
  import logging
3
  import os
4
+ import asyncio
5
  import tempfile
6
+ from typing import List, Dict, Optional, Any, Callable
7
  import openai
8
 
9
  from core.config import settings
 
23
  difficulty: str = "medium",
24
  topic: Optional[str] = None,
25
  language: str = "English",
26
+ count_mode: str = "STANDARD",
27
+ progress_callback: Optional[Callable[[int, str], None]] = None
28
  ) -> List[Dict[str, Any]]:
29
  """
30
  Generates a quiz from either an S3 PDF or direct text input.
31
+ Uses asyncio.to_thread for all blocking I/O operations.
32
  """
33
  try:
34
+ if progress_callback:
35
+ progress_callback(5, "Preparing quiz generation...")
36
+
37
  # Map count mode to actual numbers
38
  counts = {
39
  "FEWER": "5-10",
 
45
  system_prompt = get_quiz_system_prompt(language).replace("{NUM_QUESTIONS}", num_range)
46
 
47
  if file_key:
48
+ if progress_callback:
49
+ progress_callback(15, "Downloading file from S3...")
50
+
51
+ # Download PDF from S3 (non-blocking)
52
  tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".pdf")
53
  tmp_path = tmp.name
54
  tmp.close()
55
 
56
  try:
57
+ await asyncio.to_thread(
58
+ s3_service.s3_client.download_file,
59
  settings.AWS_S3_BUCKET,
60
  file_key,
61
  tmp_path
62
  )
63
 
64
+ if progress_callback:
65
+ progress_callback(30, "Uploading to OpenAI...")
66
+
67
+ # Upload to OpenAI (non-blocking)
68
+ def upload_to_openai():
69
+ with open(tmp_path, "rb") as f:
70
+ return self.openai_client.files.create(
71
+ file=f,
72
+ purpose="assistants"
73
+ )
74
+
75
+ uploaded_file = await asyncio.to_thread(upload_to_openai)
76
+
77
+ if progress_callback:
78
+ progress_callback(45, "Generating quiz questions...")
79
 
80
  user_message = f"Analyze the PDF and create {num_range} questions. Difficulty: {difficulty}."
81
  if topic:
 
95
  }
96
  ]
97
 
98
+ # Call OpenAI API (non-blocking)
99
+ response = await asyncio.to_thread(
100
+ self.openai_client.chat.completions.create,
101
  model="gpt-4o-mini",
102
  messages=messages,
103
  response_format={"type": "json_object"},
104
  temperature=0.7
105
  )
106
 
107
+ if progress_callback:
108
+ progress_callback(75, "Cleaning up...")
109
+
110
+ # Clean up (non-blocking)
111
+ await asyncio.to_thread(
112
+ self.openai_client.files.delete,
113
+ uploaded_file.id
114
+ )
115
  raw_content = response.choices[0].message.content
116
 
117
  finally:
118
  if os.path.exists(tmp_path):
119
+ await asyncio.to_thread(os.remove, tmp_path)
120
 
121
  elif text_input:
122
+ if progress_callback:
123
+ progress_callback(20, "Generating quiz questions...")
124
+
125
  user_message = f"Analyze the text and create {num_range} questions. Difficulty: {difficulty}."
126
  if topic:
127
  user_message += f" Topic: {topic}."
 
131
  {"role": "system", "content": system_prompt},
132
  {"role": "user", "content": user_message}
133
  ]
134
+
135
+ # Call OpenAI API (non-blocking)
136
+ response = await asyncio.to_thread(
137
+ self.openai_client.chat.completions.create,
138
  model="gpt-4o-mini",
139
  messages=messages,
140
  response_format={"type": "json_object"},
 
145
  else:
146
  raise ValueError("Either file_key or text_input must be provided")
147
 
148
+ if progress_callback:
149
+ progress_callback(85, "Parsing results...")
150
+
151
  data = json.loads(raw_content)
152
  # The prompt asks for {"quizzes": [...]}
153
  return data.get("quizzes", [])
services/report_service.py CHANGED
@@ -1,6 +1,7 @@
1
  import json
2
  import logging
3
  import os
 
4
  import tempfile
5
  from typing import List, Dict, Optional, Any
6
  import openai
@@ -23,28 +24,34 @@ class ReportService:
23
  ) -> List[Dict[str, str]]:
24
  """
25
  Generates 4 AI-suggested report formats based on the content.
 
26
  """
27
  try:
28
  system_prompt = get_report_suggestion_prompt(language)
29
 
30
  if file_key:
31
- # Download PDF from S3
32
  tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".pdf")
33
  tmp_path = tmp.name
34
  tmp.close()
35
 
36
  try:
37
- s3_service.s3_client.download_file(
 
38
  settings.AWS_S3_BUCKET,
39
  file_key,
40
  tmp_path
41
  )
42
 
43
- with open(tmp_path, "rb") as f:
44
- uploaded_file = self.openai_client.files.create(
45
- file=f,
46
- purpose="assistants"
47
- )
 
 
 
 
48
 
49
  messages = [
50
  {"role": "system", "content": system_prompt},
@@ -59,26 +66,36 @@ class ReportService:
59
  }
60
  ]
61
 
62
- response = self.openai_client.chat.completions.create(
 
 
63
  model="gpt-4o-mini",
64
  messages=messages,
65
  response_format={"type": "json_object"},
66
  temperature=0.7
67
  )
68
 
69
- self.openai_client.files.delete(uploaded_file.id)
 
 
 
 
 
70
  raw_content = response.choices[0].message.content
71
 
72
  finally:
73
  if os.path.exists(tmp_path):
74
- os.remove(tmp_path)
75
 
76
  elif text_input:
77
  messages = [
78
  {"role": "system", "content": system_prompt},
79
  {"role": "user", "content": f"Analyze this content:\n\n{text_input}"}
80
  ]
81
- response = self.openai_client.chat.completions.create(
 
 
 
82
  model="gpt-4o-mini",
83
  messages=messages,
84
  response_format={"type": "json_object"},
@@ -106,6 +123,7 @@ class ReportService:
106
  ) -> str:
107
  """
108
  Generates a full report based on the selected format.
 
109
  """
110
  try:
111
  base_prompt = get_report_prompt(format_key, custom_prompt or "", language)
@@ -115,7 +133,7 @@ class ReportService:
115
  system_prompt = (
116
  "あなたは日本語でレポートを作成するAIアシスタントです。すべての回答は日本語で書いてください。\n\n"
117
  f"{base_prompt}\n\n"
118
- "重要: レポート全体を日本語で書いてください。回答はマークダウン形式で、適切な見出し、箇書き、構造を使用して読みやすくフォーマットしてください。"
119
  )
120
  else:
121
  system_prompt = (
@@ -125,23 +143,28 @@ class ReportService:
125
  )
126
 
127
  if file_key:
128
- # Download PDF from S3
129
  tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".pdf")
130
  tmp_path = tmp.name
131
  tmp.close()
132
 
133
  try:
134
- s3_service.s3_client.download_file(
 
135
  settings.AWS_S3_BUCKET,
136
  file_key,
137
  tmp_path
138
  )
139
 
140
- with open(tmp_path, "rb") as f:
141
- uploaded_file = self.openai_client.files.create(
142
- file=f,
143
- purpose="assistants"
144
- )
 
 
 
 
145
 
146
  messages = [
147
  {"role": "system", "content": system_prompt},
@@ -156,25 +179,35 @@ class ReportService:
156
  }
157
  ]
158
 
159
- response = self.openai_client.chat.completions.create(
 
 
160
  model="gpt-4o-mini",
161
  messages=messages,
162
  temperature=0.7
163
  )
164
 
165
- self.openai_client.files.delete(uploaded_file.id)
 
 
 
 
 
166
  return response.choices[0].message.content
167
 
168
  finally:
169
  if os.path.exists(tmp_path):
170
- os.remove(tmp_path)
171
 
172
  elif text_input:
173
  messages = [
174
  {"role": "system", "content": system_prompt},
175
  {"role": "user", "content": f"Please analyze the following content and generate a report based on it:\n\n{text_input}"}
176
  ]
177
- response = self.openai_client.chat.completions.create(
 
 
 
178
  model="gpt-4o-mini",
179
  messages=messages,
180
  temperature=0.7
 
1
  import json
2
  import logging
3
  import os
4
+ import asyncio
5
  import tempfile
6
  from typing import List, Dict, Optional, Any
7
  import openai
 
24
  ) -> List[Dict[str, str]]:
25
  """
26
  Generates 4 AI-suggested report formats based on the content.
27
+ Uses asyncio.to_thread for all blocking I/O operations.
28
  """
29
  try:
30
  system_prompt = get_report_suggestion_prompt(language)
31
 
32
  if file_key:
33
+ # Download PDF from S3 (non-blocking)
34
  tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".pdf")
35
  tmp_path = tmp.name
36
  tmp.close()
37
 
38
  try:
39
+ await asyncio.to_thread(
40
+ s3_service.s3_client.download_file,
41
  settings.AWS_S3_BUCKET,
42
  file_key,
43
  tmp_path
44
  )
45
 
46
+ # Upload to OpenAI (non-blocking)
47
+ def upload_to_openai():
48
+ with open(tmp_path, "rb") as f:
49
+ return self.openai_client.files.create(
50
+ file=f,
51
+ purpose="assistants"
52
+ )
53
+
54
+ uploaded_file = await asyncio.to_thread(upload_to_openai)
55
 
56
  messages = [
57
  {"role": "system", "content": system_prompt},
 
66
  }
67
  ]
68
 
69
+ # Call OpenAI (non-blocking)
70
+ response = await asyncio.to_thread(
71
+ self.openai_client.chat.completions.create,
72
  model="gpt-4o-mini",
73
  messages=messages,
74
  response_format={"type": "json_object"},
75
  temperature=0.7
76
  )
77
 
78
+ # Clean up OpenAI file (non-blocking)
79
+ await asyncio.to_thread(
80
+ self.openai_client.files.delete,
81
+ uploaded_file.id
82
+ )
83
+
84
  raw_content = response.choices[0].message.content
85
 
86
  finally:
87
  if os.path.exists(tmp_path):
88
+ await asyncio.to_thread(os.remove, tmp_path)
89
 
90
  elif text_input:
91
  messages = [
92
  {"role": "system", "content": system_prompt},
93
  {"role": "user", "content": f"Analyze this content:\n\n{text_input}"}
94
  ]
95
+
96
+ # Call OpenAI (non-blocking)
97
+ response = await asyncio.to_thread(
98
+ self.openai_client.chat.completions.create,
99
  model="gpt-4o-mini",
100
  messages=messages,
101
  response_format={"type": "json_object"},
 
123
  ) -> str:
124
  """
125
  Generates a full report based on the selected format.
126
+ Uses asyncio.to_thread for all blocking I/O operations.
127
  """
128
  try:
129
  base_prompt = get_report_prompt(format_key, custom_prompt or "", language)
 
133
  system_prompt = (
134
  "あなたは日本語でレポートを作成するAIアシスタントです。すべての回答は日本語で書いてください。\n\n"
135
  f"{base_prompt}\n\n"
136
+ "重要: レポート全体を日本語で書いてください。回答はマークダウン形式で、適切な見出し、箇条書き、構造を使用して読みやすくフォーマットしてください。"
137
  )
138
  else:
139
  system_prompt = (
 
143
  )
144
 
145
  if file_key:
146
+ # Download PDF from S3 (non-blocking)
147
  tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".pdf")
148
  tmp_path = tmp.name
149
  tmp.close()
150
 
151
  try:
152
+ await asyncio.to_thread(
153
+ s3_service.s3_client.download_file,
154
  settings.AWS_S3_BUCKET,
155
  file_key,
156
  tmp_path
157
  )
158
 
159
+ # Upload to OpenAI (non-blocking)
160
+ def upload_to_openai():
161
+ with open(tmp_path, "rb") as f:
162
+ return self.openai_client.files.create(
163
+ file=f,
164
+ purpose="assistants"
165
+ )
166
+
167
+ uploaded_file = await asyncio.to_thread(upload_to_openai)
168
 
169
  messages = [
170
  {"role": "system", "content": system_prompt},
 
179
  }
180
  ]
181
 
182
+ # Call OpenAI (non-blocking)
183
+ response = await asyncio.to_thread(
184
+ self.openai_client.chat.completions.create,
185
  model="gpt-4o-mini",
186
  messages=messages,
187
  temperature=0.7
188
  )
189
 
190
+ # Clean up OpenAI (non-blocking)
191
+ await asyncio.to_thread(
192
+ self.openai_client.files.delete,
193
+ uploaded_file.id
194
+ )
195
+
196
  return response.choices[0].message.content
197
 
198
  finally:
199
  if os.path.exists(tmp_path):
200
+ await asyncio.to_thread(os.remove, tmp_path)
201
 
202
  elif text_input:
203
  messages = [
204
  {"role": "system", "content": system_prompt},
205
  {"role": "user", "content": f"Please analyze the following content and generate a report based on it:\n\n{text_input}"}
206
  ]
207
+
208
+ # Call OpenAI (non-blocking)
209
+ response = await asyncio.to_thread(
210
+ self.openai_client.chat.completions.create,
211
  model="gpt-4o-mini",
212
  messages=messages,
213
  temperature=0.7