# Source: Hugging Face Space file "tool/app.py"
# Last commit: 3081a67 "Update app.py" by Rajhuggingface4253 (verified)
import json
import random
import string
from typing import List, Optional, Any
from datetime import datetime, timedelta
from fastapi import FastAPI, HTTPException, Depends, Query, Request, status
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field, field_validator, model_validator, ConfigDict
from pydantic import HttpUrl, EmailStr
from sqlalchemy import create_engine, Column, String, Integer, Float, Boolean, Text, DateTime, func
from sqlalchemy.orm import declarative_base, sessionmaker, Session
import os
from fastapi.responses import JSONResponse
# =============================================================================
# DATABASE SETUP
# =============================================================================
# 1. Choose where the SQLite file lives: the persistent "/data" volume when it
#    is mounted, otherwise a "data" folder under the current working directory
#    (created on demand).
if not os.path.exists("/data"):
    DATABASE_DIR = os.path.join(os.getcwd(), "data")
    os.makedirs(DATABASE_DIR, exist_ok=True)
    print(f"⚠️ LOCAL MODE: Using local storage at {DATABASE_DIR}")
else:
    DATABASE_DIR = "/data"
    print("✅ PRODUCTION MODE: Using Persistent Storage at /data")

# 2. Set the Database URL
SQLALCHEMY_DATABASE_URL = f"sqlite:///{DATABASE_DIR}/tools.db"

# 3. Create Engine. check_same_thread=False allows the SQLite connection to be
#    used from threads other than the one that opened it.
engine = create_engine(
    SQLALCHEMY_DATABASE_URL,
    connect_args={"check_same_thread": False},
)
SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False)
Base = declarative_base()
# =============================================================================
# DATABASE MODELS
# =============================================================================
class Tool(Base):
    """ORM row for a submitted tool, stored in the ``tools`` table.

    The API schemas use camelCase names (e.g. ``pricingModel``,
    ``submissionDate``) and a nested ``developer`` object. The endpoints and
    ``map_db_to_response`` translate between those and these flat
    snake_case columns.
    """
    __tablename__ = "tools"
    # Integer primary key; not exposed as the public ``id`` in responses.
    id = Column(Integer, primary_key=True, index=True)
    # Public identifier produced by generate_slug(); returned as ``id``.
    slug = Column(String, unique=True, index=True)
    name = Column(String, index=True)
    description = Column(String)
    url = Column(String)
    category = Column(String)
    tags = Column(String) # JSON-encoded list of strings (json.dumps on write)
    # Flattened fields of the nested ``developer`` object in the API.
    developer_name = Column(String)
    developer_website = Column(String)
    developer_email = Column(String)
    price = Column(Float, nullable=True)
    pricing_model = Column(String, nullable=True)  # API name: pricingModel
    notes = Column(Text, nullable=True)
    # func.now() is evaluated by the database at INSERT time (on SQLite this
    # is CURRENT_TIMESTAMP, which is UTC).
    submission_date = Column(DateTime, default=func.now())
    # "pending" | "approved" | "rejected" when set through ToolUpdate.
    status = Column(String, default="pending")
    featured = Column(Boolean, default=False)
    rating = Column(Float, default=0.0)  # 0-5 when set through ToolUpdate
    thumbnail = Column(String, nullable=True)
    # Refreshed by the database whenever the ORM issues an UPDATE for the row.
    updated_at = Column(DateTime, default=func.now(), onupdate=func.now())
    promoted = Column(Boolean, default=False)
# Create any missing tables at import time. create_all does not alter tables
# that already exist, so schema changes to an existing database need a
# manual migration.
Base.metadata.create_all(bind=engine)
# =============================================================================
# PYDANTIC V2 SCHEMAS
# =============================================================================
class Developer(BaseModel):
    """Contact details of the person or company behind a tool."""

    model_config = ConfigDict(from_attributes=True)

    name: str = Field(..., min_length=2, max_length=100)
    # Scheme-less websites are normalized by validate_website before this
    # pattern is checked.
    website: str = Field(..., pattern=r'^https?://')  # Simplified URL validation
    email: Optional[str] = Field(None, pattern=r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$')

    @field_validator('website', mode='before')
    @classmethod
    def validate_website(cls, v: Any) -> Any:
        """Prefix ``https://`` onto a website given without a scheme.

        Runs in ``before`` mode. An ``after`` validator only receives values
        that already matched ``^https?://``, so its prefixing branch could
        never fire and bare domains such as ``example.com`` were rejected.

        Surrounding whitespace is stripped. Empty and non-string values are
        passed through unchanged so normal validation reports them.
        """
        if isinstance(v, str):
            v = v.strip()
            # Case-insensitive so "HTTPS://x" is not turned into
            # "https://HTTPS://x"; it is left for the pattern to judge.
            if v and not v.lower().startswith(('http://', 'https://')):
                return f'https://{v}'
        return v
class ToolCreate(BaseModel):
    """Request body for ``POST /api/tools`` (a new tool submission)."""

    model_config = ConfigDict(from_attributes=True)

    name: str = Field(..., min_length=2, max_length=100)
    description: str = Field(..., min_length=10, max_length=500)
    # Scheme-less URLs are normalized by validate_url before this check.
    url: str = Field(..., pattern=r'^https?://')
    category: str = Field(...)
    tags: List[str] = Field(..., min_length=1)
    developer: Developer
    price: Optional[float] = Field(None, ge=0)
    pricingModel: Optional[str] = Field(None)
    notes: Optional[str] = Field(None)
    terms: bool = Field(..., description="Must agree to terms")

    @field_validator('url', mode='before')
    @classmethod
    def validate_url(cls, v: Any) -> Any:
        """Prefix ``https://`` onto a URL given without a scheme.

        Must run in ``before`` mode. As an ``after`` validator it only saw
        values that had already matched ``^https?://``, so the prefixing
        branch was unreachable. Empty and non-string values are passed
        through for normal validation.
        """
        if isinstance(v, str):
            v = v.strip()
            if v and not v.lower().startswith(('http://', 'https://')):
                return f'https://{v}'
        return v

    @field_validator('terms')
    @classmethod
    def validate_terms(cls, v: bool) -> bool:
        """Reject submissions that did not accept the terms."""
        if not v:
            raise ValueError('You must agree to the terms')
        return v
class ToolUpdate(BaseModel):
    """Request body for ``PATCH /api/tools/{slug}``; every field is optional.

    Only fields present in the request are applied. A field may be omitted,
    but fields that the stored row and ``ToolResponse`` require cannot be
    explicitly set to ``null`` (see ``reject_null_for_required_fields``).
    """

    model_config = ConfigDict(from_attributes=True)

    name: Optional[str] = Field(None, min_length=2, max_length=100)
    description: Optional[str] = Field(None, min_length=10, max_length=500)
    # Scheme-less URLs are normalized by validate_url before this check.
    url: Optional[str] = Field(None, pattern=r'^https?://')
    category: Optional[str] = None
    tags: Optional[List[str]] = None
    developer: Optional[Developer] = None
    price: Optional[float] = Field(None, ge=0)
    pricingModel: Optional[str] = None
    notes: Optional[str] = None
    status: Optional[str] = Field(None, pattern=r'^(pending|approved|rejected)$')
    featured: Optional[bool] = None
    rating: Optional[float] = Field(None, ge=0, le=5)
    thumbnail: Optional[str] = None
    promoted: Optional[bool] = None

    @field_validator('url', mode='before')
    @classmethod
    def validate_url(cls, v: Any) -> Any:
        """Prefix ``https://`` onto a URL given without a scheme.

        Runs in ``before`` mode so the normalization happens before the
        ``^https?://`` pattern is enforced. ``None``, empty strings and
        non-strings are passed through unchanged.
        """
        if isinstance(v, str):
            v = v.strip()
            if v and not v.lower().startswith(('http://', 'https://')):
                return f'https://{v}'
        return v

    @model_validator(mode='after')
    def reject_null_for_required_fields(self) -> 'ToolUpdate':
        """Reject an explicit ``null`` for fields the API cannot store as NULL.

        Without this check, e.g. ``{"name": null}`` was written to the row.
        The next ``ToolResponse`` built from that row then failed validation,
        so the tool could no longer be read.

        Raises:
            ValueError: listing every required field that was sent as null.
        """
        required = (
            'name', 'description', 'url', 'category', 'tags', 'developer',
            'status', 'featured', 'rating', 'promoted',
        )
        nulled = [
            field for field in required
            if field in self.model_fields_set and getattr(self, field) is None
        ]
        if nulled:
            raise ValueError(f"These fields cannot be null: {', '.join(nulled)}")
        return self
class ToolResponse(BaseModel):
    """Public JSON representation of a tool returned by every tool endpoint.

    Built from a ``Tool`` row by ``map_db_to_response``. Field names use the
    camelCase spelling exposed by the API, and field order is the key order
    of the emitted JSON.
    """
    model_config = ConfigDict(from_attributes=True)
    id: str # Slug (Tool.slug), not the integer primary key
    name: str
    description: str
    url: str
    category: str
    tags: List[str]  # decoded from the JSON string stored in Tool.tags
    developer: Developer  # rebuilt from the developer_* columns
    price: Optional[float] = None
    pricingModel: Optional[str] = None
    submissionDate: str  # submission_date rendered with datetime.isoformat()
    status: str
    featured: bool
    rating: float
    thumbnail: Optional[str] = None
    updatedAt: Optional[str] = None  # updated_at.isoformat(), or None
    promoted: bool = False
    notes: Optional[str] = None
# =============================================================================
# FASTAPI APP SETUP
# =============================================================================
app = FastAPI(title="Ahaa", description="Ahaa", version="2.0.0")

# CORS configuration: every origin, method and header is accepted
# (intended for development).
# NOTE(review): with allow_credentials=True and a "*" origin, Starlette
# echoes the caller's Origin back, so credentialed cross-site requests are
# accepted from any site. Restrict allow_origins before adding cookie/auth
# based features.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
# =============================================================================
# DEPENDENCIES & UTILITIES
# =============================================================================
def get_db():
    """FastAPI dependency yielding one SQLAlchemy session per request.

    The session is closed when the request finishes, including when the
    endpoint raises (Session's context manager calls ``close()`` on exit).
    """
    with SessionLocal() as session:
        yield session
def generate_slug():
    """Return a new random slug: ``"tool_"`` plus 9 lowercase letters/digits.

    Uses the ``random`` module, so slugs are not cryptographically
    unpredictable. Uniqueness is ultimately enforced by the ``Tool.slug``
    unique constraint.
    """
    alphabet = string.ascii_lowercase + string.digits
    suffix = "".join(random.choices(alphabet, k=9))
    return f"tool_{suffix}"
def map_db_to_response(db_tool: Tool) -> ToolResponse:
    """Convert a ``Tool`` ORM row into the public ``ToolResponse`` schema.

    ``tags`` is stored as a JSON string. A missing, malformed or non-list
    value is rendered as an empty list rather than failing the whole
    response. The ``developer_*`` columns are regrouped into a nested
    ``Developer``.
    """
    try:
        tags_list = json.loads(db_tool.tags) if db_tool.tags else []
    except (json.JSONDecodeError, TypeError):
        tags_list = []
    if not isinstance(tags_list, list):
        # e.g. the column holds "null" or a JSON object.
        tags_list = []

    return ToolResponse(
        id=db_tool.slug,
        name=db_tool.name,
        description=db_tool.description,
        url=db_tool.url,
        category=db_tool.category,
        tags=tags_list,
        developer=Developer(
            name=db_tool.developer_name,
            website=db_tool.developer_website,
            email=db_tool.developer_email
        ),
        price=db_tool.price,
        pricingModel=db_tool.pricing_model,
        submissionDate=db_tool.submission_date.isoformat(),
        status=db_tool.status,
        featured=db_tool.featured,
        rating=db_tool.rating,
        thumbnail=db_tool.thumbnail,
        updatedAt=db_tool.updated_at.isoformat() if db_tool.updated_at else None,
        # Must be passed explicitly. Otherwise ToolResponse falls back to its
        # default and every tool is reported as not promoted. bool() maps a
        # NULL column value to False.
        promoted=bool(db_tool.promoted),
        notes=db_tool.notes
    )
# =============================================================================
# HEALTH CHECK
# =============================================================================
@app.get("/", tags=["Health"])
async def root():
    """Root endpoint; always answers 200 with the constant ``{"message": "error"}``."""
    payload = {"message": "error"}
    return payload
@app.get("/api/health", tags=["Health"])
async def health_check():
    """Liveness probe returning the constant ``{"status": "healthy"}``."""
    return dict(status="healthy")
# =============================================================================
# API ENDPOINTS
# =============================================================================
@app.post("/api/tools",
          response_model=ToolResponse,
          status_code=status.HTTP_201_CREATED,
          tags=["Tools"])
def create_tool(tool: ToolCreate, db: Session = Depends(get_db)):
    """Store a new tool submission in the ``pending`` state.

    Raises:
        HTTPException: 400 when a tool with the same URL already exists;
            500 (after rolling back) when the insert fails.
    """
    duplicate = db.query(Tool).filter(Tool.url == tool.url).first()
    if duplicate is not None:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="A tool with this URL already exists.",
        )

    dev = tool.developer
    new_tool = Tool(
        slug=generate_slug(),
        name=tool.name,
        description=tool.description,
        url=tool.url,
        category=tool.category,
        tags=json.dumps(tool.tags),  # stored as a JSON string
        developer_name=dev.name,
        developer_website=dev.website,
        developer_email=dev.email,
        price=tool.price,
        pricing_model=tool.pricingModel,
        notes=tool.notes,
        status="pending",
        featured=False,
        rating=0.0,
    )

    try:
        db.add(new_tool)
        db.commit()
        db.refresh(new_tool)
        return map_db_to_response(new_tool)
    except Exception as exc:
        db.rollback()
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Database error: {str(exc)}",
        )
@app.get("/api/tools", response_model=List[ToolResponse], tags=["Tools"])
def get_tools(
    skip: int = Query(0, ge=0, description="Number of items to skip"),
    limit: int = Query(100, ge=1, le=1000, description="Number of items to return"),
    status: Optional[str] = Query(None, description="Filter by status"),
    category: Optional[str] = Query(None, description="Filter by category"),
    featured: Optional[bool] = Query(None, description="Filter by featured status"),
    search: Optional[str] = Query(None, description="Search in name and description"),
    db: Session = Depends(get_db)
):
    """Return a page of tools, newest first, with optional filters.

    ``status`` and ``category`` are exact matches (ignored when empty).
    ``search`` is a case-insensitive substring match on name or description.
    Its text is matched literally: ``%``, ``_`` and ``\\`` are escaped so they
    do not act as LIKE wildcards.

    Note: the ``status`` parameter shadows the ``fastapi.status`` module inside
    this function; here it is only used as a filter value.
    """
    query = db.query(Tool)

    if status:
        query = query.filter(Tool.status == status)
    if category:
        query = query.filter(Tool.category == category)
    if featured is not None:
        query = query.filter(Tool.featured == featured)
    if search:
        escaped = (
            search.replace("\\", "\\\\")
            .replace("%", "\\%")
            .replace("_", "\\_")
        )
        search_term = f"%{escaped}%"
        query = query.filter(
            Tool.name.ilike(search_term, escape="\\")
            | Tool.description.ilike(search_term, escape="\\")
        )

    # submission_date has one-second resolution, so many rows can tie. The id
    # tie-breaker makes the order (and therefore offset pagination)
    # deterministic.
    tools = (
        query.order_by(Tool.submission_date.desc(), Tool.id.desc())
        .offset(skip)
        .limit(limit)
        .all()
    )
    return [map_db_to_response(t) for t in tools]
@app.get("/api/tools/{slug}", response_model=ToolResponse, tags=["Tools"])
def get_tool(slug: str, db: Session = Depends(get_db)):
    """Return the tool whose slug equals ``slug``.

    Raises:
        HTTPException: 404 when no tool has that slug.
    """
    match = db.query(Tool).filter(Tool.slug == slug).first()
    if match is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Tool not found",
        )
    return map_db_to_response(match)
@app.patch("/api/tools/{slug}", response_model=ToolResponse, tags=["Tools"])
def update_tool(slug: str, tool_update: ToolUpdate, db: Session = Depends(get_db)):
    """Partially update the tool identified by ``slug``.

    Only fields present in the request body are written (``exclude_unset``).
    API names are translated to columns:

    - ``pricingModel`` goes to ``pricing_model``.
    - ``developer`` goes to the three ``developer_*`` columns.
    - ``tags`` is stored as a JSON string.

    Raises:
        HTTPException: 404 when no tool has that slug; 500 (after rolling
            back) when the commit or response mapping fails.
    """
    db_tool = db.query(Tool).filter(Tool.slug == slug).first()
    if not db_tool:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Tool not found"
        )

    update_data = tool_update.model_dump(exclude_unset=True)

    if 'pricingModel' in update_data:
        update_data['pricing_model'] = update_data.pop('pricingModel')

    # exclude_unset also applies to the nested Developer model. An omitted
    # optional ``email`` is therefore absent from this dict (indexing it
    # raised KeyError); in that case the stored email is left untouched.
    # name and website are required on Developer, so they are always present.
    developer = update_data.pop('developer', None)
    if developer is not None:
        db_tool.developer_name = developer['name']
        db_tool.developer_website = developer['website']
        if 'email' in developer:
            db_tool.developer_email = developer['email']

    if 'tags' in update_data:
        tags = update_data.pop('tags')
        if tags is not None:
            db_tool.tags = json.dumps(tags)

    # Remaining keys already match column names.
    for key, value in update_data.items():
        if hasattr(db_tool, key):
            setattr(db_tool, key, value)

    try:
        db.commit()
        db.refresh(db_tool)
        return map_db_to_response(db_tool)
    except Exception as e:
        db.rollback()
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Update failed: {str(e)}"
        )
@app.patch("/api/tools/{tool_id}", response_model=ToolResponse, tags=["Tools"])
def update_tool(tool_id: str, tool_update: ToolUpdate, db: Session = Depends(get_db)):
    """Partially update a tool looked up by its integer primary key.

    Only fields present in the request body are written. API names are
    translated to columns the same way as in the slug-based PATCH handler.
    Previously ``pricingModel``, ``developer`` and the raw ``tags`` list were
    set directly on the ORM object, which broke the commit.

    NOTE(review): this route has the same path shape as the slug-based PATCH
    route registered earlier in the file. Routes are matched in registration
    order, so HTTP requests never reach this handler. This definition also
    rebinds the module-level name ``update_tool``. Consider removing it or
    giving it a distinct path.

    Raises:
        HTTPException: 404 when ``tool_id`` is not an integer or no tool has it.
    """
    # 1. Find the tool
    try:
        numeric_id = int(tool_id)
    except ValueError:
        raise HTTPException(status_code=404, detail="Tool not found") from None
    db_tool = db.query(Tool).filter(Tool.id == numeric_id).first()
    if not db_tool:
        raise HTTPException(status_code=404, detail="Tool not found")

    # 2. Apply only the fields provided, mapped onto column names,
    #    e.g. {"promoted": true} updates just that flag.
    update_data = tool_update.model_dump(exclude_unset=True)
    if 'pricingModel' in update_data:
        update_data['pricing_model'] = update_data.pop('pricingModel')
    developer = update_data.pop('developer', None)
    if developer is not None:
        db_tool.developer_name = developer['name']
        db_tool.developer_website = developer['website']
        if 'email' in developer:  # omitted when unset (exclude_unset is recursive)
            db_tool.developer_email = developer['email']
    if 'tags' in update_data:
        tags = update_data.pop('tags')
        if tags is not None:
            db_tool.tags = json.dumps(tags)
    for key, value in update_data.items():
        if hasattr(db_tool, key):
            setattr(db_tool, key, value)

    # 3. Save (rolling back on failure) and refresh
    try:
        db.commit()
    except Exception:
        db.rollback()
        raise
    db.refresh(db_tool)

    # 4. Return the updated tool
    return map_db_to_response(db_tool)
@app.delete("/api/tools/{slug}", tags=["Tools"])
def delete_tool(slug: str, db: Session = Depends(get_db)):
    """Delete the tool with the given slug and echo the slug back.

    Raises:
        HTTPException: 404 when the slug is unknown; 500 (after rolling
            back) when the delete/commit fails.
    """
    target = db.query(Tool).filter(Tool.slug == slug).first()
    if target is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Tool not found",
        )

    try:
        db.delete(target)
        db.commit()
    except Exception as err:
        db.rollback()
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Deletion failed: {str(err)}",
        )
    return {"message": "Tool deleted successfully", "slug": slug}
@app.get("/api/stats", tags=["Statistics"])
def get_statistics(db: Session = Depends(get_db)):
    """Return aggregate counts over all tools.

    Keys in the response:

    - ``total``
    - per-status counts: ``pending``, ``approved``, ``rejected``
    - ``featured``
    - ``recent_submissions``: submitted in the last 7 days
    - ``categories``: a list of ``{"name", "count"}`` entries
    - ``updated``: the server's local time of the snapshot

    Raises:
        HTTPException: 500 when any query fails.
    """
    from datetime import timezone

    try:
        total = db.query(Tool).count()

        # One GROUP BY instead of a separate COUNT query per status.
        status_counts = dict(
            db.query(Tool.status, func.count(Tool.id)).group_by(Tool.status).all()
        )
        featured = db.query(Tool).filter(Tool.featured == True).count()  # noqa: E712

        # Get category distribution
        categories = db.query(
            Tool.category,
            func.count(Tool.id).label('count')
        ).group_by(Tool.category).all()

        # Recent submissions (last 7 days). submission_date is filled by
        # func.now(), i.e. SQLite's CURRENT_TIMESTAMP, which is naive UTC.
        # The cutoff must therefore also be naive UTC, not local time.
        week_ago = datetime.now(timezone.utc).replace(tzinfo=None) - timedelta(days=7)
        recent = db.query(Tool).filter(Tool.submission_date >= week_ago).count()

        return {
            "total": total,
            "pending": status_counts.get("pending", 0),
            "approved": status_counts.get("approved", 0),
            "rejected": status_counts.get("rejected", 0),
            "featured": featured,
            "recent_submissions": recent,
            "categories": [{"name": c[0], "count": c[1]} for c in categories],
            "updated": datetime.now().isoformat()
        }
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to fetch statistics: {str(e)}"
        )
# =============================================================================
# ERROR HANDLING
# =============================================================================
@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException):
    """Render an ``HTTPException`` as JSON with the request path and method.

    ``exc.headers`` is forwarded so that headers attached to the exception
    (e.g. ``WWW-Authenticate``) are not dropped, matching FastAPI's built-in
    handler.
    """
    return JSONResponse(
        status_code=exc.status_code,
        content={
            "detail": exc.detail,
            "path": request.url.path,
            "method": request.method
        },
        headers=getattr(exc, "headers", None),
    )
@app.exception_handler(Exception)
async def general_exception_handler(request: Request, exc: Exception):
    """Catch-all that turns any unhandled exception into a JSON 500 response.

    NOTE(review): ``error`` exposes the raw exception message to clients.
    """
    payload = {
        "detail": "Internal server error",
        "error": str(exc),
        "path": request.url.path,
    }
    return JSONResponse(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content=payload)
# =============================================================================
# MAIN EXECUTION
# =============================================================================
if __name__ == "__main__":
    # Local development entry point (auto-reload enabled).
    import uvicorn

    uvicorn.run(
        "app:app",  # import string, required by uvicorn when reload=True
        host="0.0.0.0",
        port=7860,
        log_level="info",
        reload=True,
    )