truegleai commited on
Commit
4d8ba6d
·
verified ·
1 Parent(s): dc2716b

Clean rewrite: ensure_tables() before every DB op

Browse files
Files changed (1) hide show
  1. app.py +35 -57
app.py CHANGED
@@ -13,25 +13,21 @@ def install_deps():
13
  print("✅ Dependencies installed")
14
  install_deps()
15
 
16
- # Now import everything
17
  from fastapi import FastAPI, HTTPException
18
  from fastapi.responses import FileResponse
19
  from fastapi.middleware.cors import CORSMiddleware
20
  from contextlib import asynccontextmanager
21
  from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
22
  from sqlalchemy.orm import DeclarativeBase, relationship
23
- from sqlalchemy import Column, Integer, String, Text, Float, Boolean, DateTime, JSON, Enum, ForeignKey
24
- from sqlalchemy.sql import func, select
25
  from pydantic import BaseModel
26
- from typing import Optional, List
27
  import asyncio
28
 
29
- # === FastAPI App (must be defined before routes!) ===
30
  @asynccontextmanager
31
  async def lifespan(app: FastAPI):
32
- # Create tables at startup using the async engine
33
- async with engine.begin() as conn:
34
- await conn.run_sync(Base.metadata.create_all)
35
  for d in ["/data/frames", "/data/features", "/data/index", "/data/videos"]:
36
  os.makedirs(d, exist_ok=True)
37
  print("🚀 Eye-Dentify API ready!")
@@ -43,18 +39,15 @@ app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_credentials=True,
43
 
44
  # === Database ===
45
  DATABASE_URL = "sqlite+aiosqlite:///./videosearch.db"
46
- SYNC_DB_URL = "sqlite:///./videosearch.db"
 
47
 
48
  class Base(DeclarativeBase): pass
49
 
50
- # Create tables synchronously at import time
51
- from sqlalchemy import create_engine as create_sync_engine
52
- sync_engine = create_sync_engine(SYNC_DB_URL)
53
- Base.metadata.create_all(bind=sync_engine)
54
- print("✅ Database tables created (sync)")
55
-
56
- engine = create_async_engine(DATABASE_URL, echo=False)
57
- async_session_factory = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
58
 
59
  # === Models ===
60
  class VideoStatus(str, enum.Enum):
@@ -67,8 +60,8 @@ class AnalysisStatus(str, enum.Enum):
67
  class Video(Base):
68
  __tablename__ = "videos"
69
  id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
70
- youtube_id = Column(String(64), nullable=True)
71
- youtube_url = Column(String(512), nullable=True)
72
  title = Column(String(512), default="")
73
  description = Column(Text, default="")
74
  channel = Column(String(256), default="")
@@ -97,13 +90,15 @@ class Analysis(Base):
97
  completed_at = Column(DateTime, nullable=True)
98
  video = relationship("Video", back_populates="analyses")
99
 
100
- async def get_session():
101
- """Get async session, ensuring tables exist first"""
102
- async with engine.begin() as conn:
103
- await conn.run_sync(Base.metadata.create_all)
104
- async with engine.begin() as _c: await _c.run_sync(Base.metadata.create_all)
105
- async with async_session_factory() as session:
106
- yield session
 
 
107
 
108
  # === Serve Frontend ===
109
  @app.get("/")
@@ -117,16 +112,6 @@ class SubmitRequest(BaseModel):
117
  class SearchRequest(BaseModel):
118
  youtube_url: str; top_k: int = 10; threshold: float = 0.5
119
 
120
- def video_to_dict(v):
121
- return {
122
- "id": v.id, "youtube_id": v.youtube_id or "", "youtube_url": v.youtube_url or "",
123
- "title": v.title or "Untitled", "channel": v.channel or "",
124
- "duration": v.duration or 0, "thumbnail_url": v.thumbnail_url or "",
125
- "frames_count": v.frames_count or 0, "features_count": v.features_count or 0,
126
- "status": v.status or "pending",
127
- "created_at": v.created_at.isoformat() if v.created_at else None
128
- }
129
-
130
  # === Routes ===
131
  @app.get("/api/v1/health")
132
  def health():
@@ -134,19 +119,14 @@ def health():
134
 
135
  @app.post("/api/v1/videos/submit")
136
  async def submit_video(req: SubmitRequest):
137
- import traceback
138
  try:
139
- import yt_dlp
140
- except Exception as e:
141
- return {"error": f"yt-dlp import failed: {str(e)}", "trace": traceback.format_exc()}
142
- try:
143
- video_id = str(uuid.uuid4())
144
- video = Video(id=video_id, youtube_url=req.youtube_url, youtube_id="",
145
- title="Processing...", channel="", status=VideoStatus.downloading.value,
146
- description="", frames_count=0, features_count=0)
147
- async with engine.begin() as _c: await _c.run_sync(Base.metadata.create_all)
148
  async with async_session_factory() as session:
149
- session.add(video); await session.commit()
 
150
  with yt_dlp.YoutubeDL({'quiet': True, 'extract_flat': True}) as ydl:
151
  info = ydl.extract_info(req.youtube_url, download=False)
152
  video.title = info.get('title', 'Unknown')
@@ -157,17 +137,17 @@ async def submit_video(req: SubmitRequest):
157
  video.status = VideoStatus.completed.value
158
  video.frames_count = 100
159
  video.features_count = 100
160
- async with engine.begin() as _c: await _c.run_sync(Base.metadata.create_all)
161
  async with async_session_factory() as session:
162
- await session.merge(video); await session.commit()
 
163
  return video_to_dict(video)
164
  except Exception as e:
165
- return {"error": str(e), "trace": traceback.format_exc()}
166
 
167
  @app.get("/api/v1/videos/")
168
  async def list_videos(skip: int = 0, limit: int = 50, status: Optional[str] = None):
169
  try:
170
- async with engine.begin() as _c: await _c.run_sync(Base.metadata.create_all)
171
  async with async_session_factory() as session:
172
  stmt = select(Video).offset(skip).limit(limit)
173
  if status: stmt = stmt.where(Video.status == status)
@@ -179,8 +159,8 @@ async def list_videos(skip: int = 0, limit: int = 50, status: Optional[str] = No
179
 
180
  @app.get("/api/v1/videos/{video_id}")
181
  async def get_video(video_id: str):
182
- async with engine.begin() as _c: await _c.run_sync(Base.metadata.create_all)
183
- async with async_session_factory() as session:
184
  result = await session.execute(select(Video).where(Video.id == video_id))
185
  video = result.scalar_one_or_none()
186
  if not video: raise HTTPException(status_code=404, detail="Not found")
@@ -188,8 +168,8 @@ async def get_video(video_id: str):
188
 
189
  @app.delete("/api/v1/videos/{video_id}")
190
  async def delete_video(video_id: str):
191
- async with engine.begin() as _c: await _c.run_sync(Base.metadata.create_all)
192
- async with async_session_factory() as session:
193
  result = await session.execute(select(Video).where(Video.id == video_id))
194
  video = result.scalar_one_or_none()
195
  if video: await session.delete(video); await session.commit()
@@ -201,7 +181,6 @@ async def search(req: SearchRequest):
201
 
202
  @app.get("/api/v1/index/stats")
203
  async def index_stats():
204
- import os
205
  idx_path = "/data/index/index.faiss"
206
  if os.path.exists(idx_path):
207
  size = os.path.getsize(idx_path)
@@ -224,7 +203,6 @@ async def create_analysis():
224
  async def list_analyses():
225
  return []
226
 
227
- # === Launch ===
228
  if __name__ == "__main__":
229
  import uvicorn
230
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
13
  print("✅ Dependencies installed")
14
  install_deps()
15
 
 
16
  from fastapi import FastAPI, HTTPException
17
  from fastapi.responses import FileResponse
18
  from fastapi.middleware.cors import CORSMiddleware
19
  from contextlib import asynccontextmanager
20
  from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
21
  from sqlalchemy.orm import DeclarativeBase, relationship
22
+ from sqlalchemy import Column, Integer, String, Text, Float, DateTime, JSON, Enum, ForeignKey, select
23
+ from sqlalchemy.sql import func
24
  from pydantic import BaseModel
25
+ from typing import Optional
26
  import asyncio
27
 
28
+ # === FastAPI App ===
29
  @asynccontextmanager
30
  async def lifespan(app: FastAPI):
 
 
 
31
  for d in ["/data/frames", "/data/features", "/data/index", "/data/videos"]:
32
  os.makedirs(d, exist_ok=True)
33
  print("🚀 Eye-Dentify API ready!")
 
39
 
40
  # === Database ===
41
  DATABASE_URL = "sqlite+aiosqlite:///./videosearch.db"
42
+ engine = create_async_engine(DATABASE_URL, echo=False)
43
+ async_session_factory = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
44
 
45
  class Base(DeclarativeBase): pass
46
 
47
+ async def ensure_tables():
48
+ """Ensure tables exist (idempotent, safe to call every time)"""
49
+ async with engine.begin() as conn:
50
+ await conn.run_sync(Base.metadata.create_all)
 
 
 
 
51
 
52
  # === Models ===
53
  class VideoStatus(str, enum.Enum):
 
60
  class Video(Base):
61
  __tablename__ = "videos"
62
  id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
63
+ youtube_id = Column(String(64), nullable=True, default="")
64
+ youtube_url = Column(String(512), nullable=True, default="")
65
  title = Column(String(512), default="")
66
  description = Column(Text, default="")
67
  channel = Column(String(256), default="")
 
90
  completed_at = Column(DateTime, nullable=True)
91
  video = relationship("Video", back_populates="analyses")
92
 
93
+ def video_to_dict(v):
94
+ return {
95
+ "id": v.id, "youtube_id": v.youtube_id or "", "youtube_url": v.youtube_url or "",
96
+ "title": v.title or "Untitled", "channel": v.channel or "",
97
+ "duration": v.duration or 0, "thumbnail_url": v.thumbnail_url or "",
98
+ "frames_count": v.frames_count or 0, "features_count": v.features_count or 0,
99
+ "status": v.status or "pending",
100
+ "created_at": v.created_at.isoformat() if v.created_at else None
101
+ }
102
 
103
  # === Serve Frontend ===
104
  @app.get("/")
 
112
  class SearchRequest(BaseModel):
113
  youtube_url: str; top_k: int = 10; threshold: float = 0.5
114
 
 
 
 
 
 
 
 
 
 
 
115
  # === Routes ===
116
  @app.get("/api/v1/health")
117
  def health():
 
119
 
120
  @app.post("/api/v1/videos/submit")
121
  async def submit_video(req: SubmitRequest):
122
+ import yt_dlp, traceback
123
  try:
124
+ await ensure_tables()
125
+ video = Video(id=str(uuid.uuid4()), youtube_url=req.youtube_url, youtube_id="",
126
+ title="Processing...", channel="", status=VideoStatus.downloading.value)
 
 
 
 
 
 
127
  async with async_session_factory() as session:
128
+ session.add(video)
129
+ await session.commit()
130
  with yt_dlp.YoutubeDL({'quiet': True, 'extract_flat': True}) as ydl:
131
  info = ydl.extract_info(req.youtube_url, download=False)
132
  video.title = info.get('title', 'Unknown')
 
137
  video.status = VideoStatus.completed.value
138
  video.frames_count = 100
139
  video.features_count = 100
 
140
  async with async_session_factory() as session:
141
+ await session.merge(video)
142
+ await session.commit()
143
  return video_to_dict(video)
144
  except Exception as e:
145
+ return {"error": str(e)}
146
 
147
  @app.get("/api/v1/videos/")
148
  async def list_videos(skip: int = 0, limit: int = 50, status: Optional[str] = None):
149
  try:
150
+ await ensure_tables()
151
  async with async_session_factory() as session:
152
  stmt = select(Video).offset(skip).limit(limit)
153
  if status: stmt = stmt.where(Video.status == status)
 
159
 
160
  @app.get("/api/v1/videos/{video_id}")
161
  async def get_video(video_id: str):
162
+ await ensure_tables()
163
+ async with async_session_factory() as session:
164
  result = await session.execute(select(Video).where(Video.id == video_id))
165
  video = result.scalar_one_or_none()
166
  if not video: raise HTTPException(status_code=404, detail="Not found")
 
168
 
169
  @app.delete("/api/v1/videos/{video_id}")
170
  async def delete_video(video_id: str):
171
+ await ensure_tables()
172
+ async with async_session_factory() as session:
173
  result = await session.execute(select(Video).where(Video.id == video_id))
174
  video = result.scalar_one_or_none()
175
  if video: await session.delete(video); await session.commit()
 
181
 
182
  @app.get("/api/v1/index/stats")
183
  async def index_stats():
 
184
  idx_path = "/data/index/index.faiss"
185
  if os.path.exists(idx_path):
186
  size = os.path.getsize(idx_path)
 
203
  async def list_analyses():
204
  return []
205
 
 
206
  if __name__ == "__main__":
207
  import uvicorn
208
  uvicorn.run(app, host="0.0.0.0", port=7860)