| import os |
| import io |
| import time |
| import uuid |
| import json |
| import hmac |
| import hashlib |
| import secrets |
| from typing import List, Optional, Any, Dict |
|
|
| from fastapi import FastAPI, UploadFile, File, HTTPException, Request, Form, Body |
| from fastapi.middleware.cors import CORSMiddleware |
| from fastapi.staticfiles import StaticFiles |
| from pydantic import BaseModel, EmailStr, Field |
| from PIL import Image |
|
|
| from sqlalchemy import Column, String, Integer, Float, create_engine, text, select |
| from sqlalchemy.orm import declarative_base, Session |
|
|
| from email.message import EmailMessage |
| import smtplib |
|
|
| from model import load_model, segment_image, load_model_variant |
| from utils import colorize_mask, overlay_mask, build_legend_from_mask |
|
|
| |
# Filesystem locations and compute device; each value can be overridden
# through the environment variable of the same name.
MEDIA_ROOT = os.environ.get("MEDIA_ROOT", "/data/media")
UPLOAD_DIR = os.path.join(MEDIA_ROOT, "uploads")
DB_PATH = os.environ.get("DB_PATH", "/data/app.db")
DEVICE = os.environ.get("DEVICE", "cpu")


# Allowed CORS origins: a comma-separated list with blank entries dropped.
CORS_ORIGINS = [
    origin
    for origin in map(
        str.strip,
        os.environ.get(
            "CORS_ORIGINS",
            "https://coneimage.com,https://www.coneimage.com,http://localhost:3000,http://127.0.0.1:3000",
        ).split(","),
    )
    if origin
]
|
|
# Make sure the media tree exists before anything tries to write into it or
# mount it; exist_ok=True keeps this idempotent across restarts.
os.makedirs(MEDIA_ROOT, exist_ok=True)
os.makedirs(UPLOAD_DIR, exist_ok=True)
|
|
# Upload size cap (bytes) and lifetimes (seconds) for sessions and emailed
# codes. Environment values arrive as strings, hence the int() coercion.
MAX_BYTES = int(os.environ.get("MAX_UPLOAD_BYTES", 10 * 1024 * 1024))
SESSION_TTL_SECONDS = int(os.environ.get("SESSION_TTL_SECONDS", 60 * 60 * 24 * 7))
EMAIL_CODE_TTL_SECONDS = int(os.environ.get("EMAIL_CODE_TTL_SECONDS", 15 * 60))
PASSWORD_RESET_TTL_SECONDS = int(os.environ.get("PASSWORD_RESET_TTL_SECONDS", 15 * 60))
|
|
| |
# Outbound mail (SMTP) settings, read from the environment.
SMTP_HOST = os.getenv("SMTP_HOST", "smtp.gmail.com")
SMTP_PORT = int(os.getenv("SMTP_PORT", "587"))
SMTP_USER = os.getenv("SMTP_USER", "coneimage123@gmail.com")
# SECURITY: the mailbox password must never live in source control. It is now
# read from the environment only; deployments have to set SMTP_PASS
# explicitly. The credential that used to be embedded here should be treated
# as leaked and rotated.
SMTP_PASS = os.getenv("SMTP_PASS", "")
SMTP_FROM = os.getenv("SMTP_FROM", "coneimage123@gmail.com")
|
|
| |
# Addresses listed in CONTACT_RECIPIENTS (comma-separated); blank entries
# are dropped.
CONTACT_RECIPIENTS = [
    address
    for address in map(
        str.strip,
        os.environ.get(
            "CONTACT_RECIPIENTS",
            "mahmed10.umbc@gmail.com,smxahsan@gmail.com",
        ).split(","),
    )
    if address
]
|
|
| |
# ASGI application object; every route and middleware below is attached to it.
app = FastAPI(title="ConeImage API", version="1.5.0")


# Fall back to the wildcard origin when the configured origin list is empty.
app.add_middleware(
    CORSMiddleware,
    allow_origins=CORS_ORIGINS or ["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
    max_age=86400,
)
|
|
| |
# SQLite engine bound to the file at DB_PATH, plus the declarative base that
# every ORM model in this module inherits from.
engine = create_engine(f"sqlite:///{DB_PATH}", future=True)
Base = declarative_base()
|
|
class Job(Base):
    """One segmentation run together with the media files it produced."""
    __tablename__ = "jobs"
    id = Column(String, primary_key=True)  # UUID4 string generated per run
    created_at = Column(Float, index=True)  # time.time() when the job was stored
    input_path = Column(String)  # file name of the saved input under MEDIA_ROOT
    input_name = Column(String)  # original client-side file name
    mask_path = Column(String)  # file name of the colorized mask under MEDIA_ROOT
    overlay_path = Column(String)  # file name of the overlay image under MEDIA_ROOT
    model = Column(String)
    device = Column(String)
    width = Column(Integer)
    height = Column(Integer)
    user_id = Column(String, index=True, nullable=True)  # NULL for anonymous runs
|
|
class Upload(Base):
    """
    Keeps metadata for uploaded files so we can recover the original filename
    in the two-step flow (POST /api/upload -> POST /api/segment with upload_id).
    """
    __tablename__ = "uploads"
    upload_id = Column(String, primary_key=True)  # UUID4 string returned to the client
    created_at = Column(Float, index=True)  # time.time() at upload
    stored_name = Column(String)  # "<upload_id>_input.jpg" inside UPLOAD_DIR
    orig_name = Column(String)  # basename of the client file name, or "unknown"
    width = Column(Integer)
    height = Column(Integer)
    user_id = Column(String, index=True, nullable=True)  # NULL for anonymous uploads
|
|
class User(Base):
    """Registered account."""
    __tablename__ = "users"
    id = Column(String, primary_key=True)  # UUID4 string
    created_at = Column(Float, index=True)  # time.time() at registration
    email = Column(String, unique=True, index=True)  # stored stripped and lower-cased
    password_hash = Column(String)  # "<salt hex>:<PBKDF2 hex>" as built by _hash_password
    name = Column(String, nullable=True)
    is_verified = Column(Integer, default=0)  # 0/1 flag, set to 1 after a code is confirmed
    verified_at = Column(Float, nullable=True)  # time.time() of first verification
    preferred_model = Column(String, default="green")  # model key, see _normalize_model_key
|
|
class SessionToken(Base):
    """Bearer-token session; the token string itself is the primary key."""
    __tablename__ = "sessions"
    token = Column(String, primary_key=True)  # random URL-safe token from _create_session
    user_id = Column(String, index=True)
    created_at = Column(Float, index=True)
    expires_at = Column(Float, index=True)  # created_at + SESSION_TTL_SECONDS
    user_agent = Column(String, nullable=True)  # User-Agent header seen at login, if any
|
|
class Contact(Base):
    """Stored contact-form message.

    NOTE(review): the endpoint that writes these rows is not visible in this
    section; the fields mirror the ContactIn request schema.
    """
    __tablename__ = "contacts"
    id = Column(String, primary_key=True)
    created_at = Column(Float, index=True)
    name = Column(String)
    email = Column(String)
    subject = Column(String, nullable=True)  # also added by _maybe_migrate_contacts_add_subject
    message = Column(String)
|
|
class Feedback(Base):
    """Stored free-form feedback entry.

    NOTE(review): the endpoint that writes these rows is not visible in this
    section; the fields mirror the FeedbackIn request schema.
    """
    __tablename__ = "feedback"
    id = Column(String, primary_key=True)
    created_at = Column(Float, index=True)
    name = Column(String, nullable=True)
    email = Column(String, nullable=True)

    feedback = Column(String)
|
|
class Survey(Base):
    """
    Stores star-based site survey results (no notes).
    Ratings use a 1..5 scale where 1=worst, 5=best.
    """
    __tablename__ = "surveys"
    id = Column(String, primary_key=True)
    created_at = Column(Float, index=True)
    role = Column(String, nullable=True)
    use_case = Column(String, nullable=True)
    # One nullable integer per rated aspect; the names match the keys
    # produced by _try_parse_site_survey.
    model_works = Column(Integer, nullable=True)
    image_visualization = Column(Integer, nullable=True)
    prompt_input = Column(Integer, nullable=True)
    processing_speed = Column(Integer, nullable=True)
    accessibility = Column(Integer, nullable=True)
    overall_experience = Column(Integer, nullable=True)
|
|
class EmailCode(Base):
    """One-time numeric code emailed to a user; only a salted hash is stored."""
    __tablename__ = "email_codes"
    id = Column(String, primary_key=True)  # UUID4 string
    user_id = Column(String, index=True)
    purpose = Column(String, index=True)  # e.g. "verify_email" or "reset_password"
    code_hash = Column(String)  # SHA-256 hex of "<code>:<salt>", see _hash_code
    code_salt = Column(String)
    created_at = Column(Float, index=True)
    expires_at = Column(Float, index=True)
    consumed_at = Column(Float, nullable=True)  # set once the code is used or superseded
|
|
# Create any tables that do not exist yet. Existing tables are left as they
# are, which is why the column migrations below exist for older databases.
Base.metadata.create_all(engine)


def _add_missing_columns(table: str, columns: Dict[str, str], backfill: Optional[str] = None) -> None:
    """Add each column in *columns* to *table* when it is absent (SQLite).

    Args:
        table: Table name. Interpolated into SQL, so it must be a trusted
            constant - never user input.
        columns: Mapping of column name to its DDL type/default clause.
        backfill: Optional idempotent SQL statement executed after the
            columns exist (for example to replace NULLs with a default).

    Safe to call on every startup: nothing is altered when all columns exist.
    """
    with engine.connect() as conn:
        # PRAGMA table_info rows are (cid, name, type, ...); index 1 is the name.
        present = {row[1] for row in conn.execute(text(f"PRAGMA table_info({table})")).all()}
        dirty = False
        for column, ddl in columns.items():
            if column not in present:
                conn.execute(text(f"ALTER TABLE {table} ADD COLUMN {column} {ddl}"))
                dirty = True
        if backfill is not None:
            conn.execute(text(backfill))
            dirty = True
        if dirty:
            conn.commit()


def _maybe_migrate_jobs_add_input_name():
    """
    Add 'input_name' column to jobs if missing (SQLite).
    Safe to run on every startup.
    """
    _add_missing_columns("jobs", {"input_name": "TEXT"})


def _maybe_migrate_contacts_add_subject():
    """Add 'subject' column to contacts if missing (SQLite)."""
    _add_missing_columns("contacts", {"subject": "TEXT"})


def _maybe_migrate_jobs_add_user_id():
    """Add 'user_id' column to jobs if missing (SQLite)."""
    _add_missing_columns("jobs", {"user_id": "TEXT"})


def _maybe_migrate_uploads_add_user_id():
    """Add 'user_id' column to uploads if missing (SQLite)."""
    _add_missing_columns("uploads", {"user_id": "TEXT"})


def _maybe_migrate_users_add_verification():
    """Add 'is_verified' and 'verified_at' columns to users if missing (SQLite)."""
    _add_missing_columns(
        "users",
        {"is_verified": "INTEGER DEFAULT 0", "verified_at": "REAL"},
    )


def _maybe_migrate_users_add_preferred_model():
    """Add 'preferred_model' to users if missing and default NULL values to 'green'."""
    _add_missing_columns(
        "users",
        {"preferred_model": "TEXT"},
        backfill="UPDATE users SET preferred_model = 'green' WHERE preferred_model IS NULL",
    )


# Run the migrations at import time, in the same order as before. The
# contacts migration is invoked separately further down in this module.
_maybe_migrate_jobs_add_input_name()
_maybe_migrate_jobs_add_user_id()
_maybe_migrate_uploads_add_user_id()
_maybe_migrate_users_add_verification()
_maybe_migrate_users_add_preferred_model()
|
|
def _hash_password(password: str) -> str:
    """Derive a salted PBKDF2-HMAC-SHA256 hash of *password*.

    Returns:
        ``"<salt hex>:<hash hex>"`` using a fresh 16-byte random salt and
        100,000 iterations.
    """
    fresh_salt = secrets.token_bytes(16)
    derived = hashlib.pbkdf2_hmac("sha256", password.encode("utf-8"), fresh_salt, 100_000)
    return ":".join((fresh_salt.hex(), derived.hex()))
|
|
|
|
def _verify_password(password: str, stored: str) -> bool:
    """Check *password* against a ``"<salt hex>:<hash hex>"`` value.

    Args:
        password: Candidate plaintext password.
        stored: Value previously produced by ``_hash_password``.

    Returns:
        True only when the recomputed PBKDF2 digest matches. Any malformed
        *stored* value - including ``None``, e.g. a NULL database column -
        yields False instead of raising.
    """
    if not isinstance(stored, str):
        return False
    salt_hex, separator, hash_hex = stored.partition(":")
    if not separator:
        return False
    try:
        salt = bytes.fromhex(salt_hex)
        expected = bytes.fromhex(hash_hex)
    except ValueError:
        return False
    computed = hashlib.pbkdf2_hmac("sha256", password.encode("utf-8"), salt, 100_000)
    # Constant-time comparison so timing does not reveal how many bytes matched.
    return hmac.compare_digest(computed, expected)
|
|
|
|
def _create_session(user: User, ses: Session, user_agent: Optional[str] = None) -> str:
    """Queue a new session row for *user* and return its bearer token.

    The row is only added to *ses*; the caller is responsible for committing.
    """
    new_token = secrets.token_urlsafe(48)
    issued_at = time.time()
    session_row = SessionToken(
        token=new_token,
        user_id=user.id,
        created_at=issued_at,
        expires_at=issued_at + SESSION_TTL_SECONDS,
        user_agent=user_agent,
    )
    ses.add(session_row)
    return new_token
|
|
|
|
def _extract_bearer_token(request: Request) -> Optional[str]:
    """Return the token from an ``Authorization: Bearer <token>`` header.

    The scheme is matched case-insensitively. A missing header, a different
    scheme or an empty token all yield ``None``.
    """
    header = request.headers.get("Authorization")
    if not header or not isinstance(header, str):
        return None
    if not header.lower().startswith("bearer "):
        return None
    # len("bearer ") == 7; whatever follows is the token itself.
    return header[7:].strip() or None
|
|
|
|
def _get_user_from_token(token: Optional[str], ses: Session, *, require: bool = False) -> Optional[User]:
    """Resolve a bearer token to its user.

    Args:
        token: Raw session token, or None when the request carried none.
        ses: Open database session used for the lookups.
        require: When True, every failure raises HTTP 401 instead of
            returning None.

    Returns:
        The matching ``User``, or None when the token is missing, unknown,
        expired or points at a user that no longer exists (and *require* is
        False).
    """

    def _reject(detail: str) -> None:
        # Shared failure path: raise for mandatory auth, otherwise signal "anonymous".
        if require:
            raise HTTPException(status_code=401, detail=detail)
        return None

    if not token:
        return _reject("Missing authorization")

    session_row = ses.get(SessionToken, token)
    if not session_row:
        return _reject("Invalid or expired session")
    # A falsy expires_at is treated as "never expires".
    if session_row.expires_at and session_row.expires_at < time.time():
        return _reject("Invalid or expired session")

    owner = ses.get(User, session_row.user_id)
    if not owner:
        return _reject("Invalid session")
    return owner
|
|
|
|
def _serialize_user(user: User) -> Dict[str, Any]:
    """Build the plain-dict view of *user* that is expanded into ``UserOut``."""
    model_key = _normalize_model_key(getattr(user, "preferred_model", None))
    return dict(
        id=user.id,
        email=user.email,
        name=user.name,
        created_at=user.created_at,
        is_verified=bool(user.is_verified),  # stored as 0/1 in the database
        verified_at=user.verified_at,
        preferred_model=model_key,
    )
|
|
# Runs at import time; the migration checks for the column before altering.
_maybe_migrate_contacts_add_subject()
|
|
|
|
def _generate_numeric_code(length: int = 6) -> str:
    """Return *length* random decimal digits drawn with the ``secrets`` CSPRNG."""
    digits = [secrets.choice("0123456789") for _ in range(length)]
    return "".join(digits)
|
|
|
|
def _hash_code(code: str, salt: str) -> str:
    """Return the SHA-256 hex digest of ``"<code>:<salt>"`` (UTF-8 encoded)."""
    material = f"{code}:{salt}".encode("utf-8")
    digest = hashlib.sha256(material)
    return digest.hexdigest()
|
|
|
|
def _issue_email_code(ses: Session, user: User, purpose: str, ttl_seconds: int) -> str:
    """Create a fresh one-time code for *user* and invalidate older ones.

    Every unconsumed code with the same user and *purpose* is marked as
    consumed, then a new ``EmailCode`` row holding only the salted hash is
    queued on *ses*. The caller commits.

    Returns:
        The plaintext code, to be emailed to the user.
    """
    plain_code = _generate_numeric_code()
    code_salt = secrets.token_hex(8)
    issued_at = time.time()

    # Retire whatever is still outstanding so only the newest code is usable.
    outstanding_stmt = select(EmailCode).where(
        EmailCode.user_id == user.id,
        EmailCode.purpose == purpose,
        EmailCode.consumed_at.is_(None),
    )
    for stale in ses.scalars(outstanding_stmt).all():
        stale.consumed_at = stale.consumed_at or issued_at

    new_row = EmailCode(
        id=str(uuid.uuid4()),
        user_id=user.id,
        purpose=purpose,
        code_hash=_hash_code(plain_code, code_salt),
        code_salt=code_salt,
        created_at=issued_at,
        expires_at=issued_at + ttl_seconds,
        consumed_at=None,
    )
    ses.add(new_row)
    return plain_code
|
|
|
|
def _verify_email_code(ses: Session, user: User, purpose: str, code: str) -> Optional[EmailCode]:
    """Validate *code* against the user's live codes for *purpose*.

    On a match, the matching row and every other live row are marked as
    consumed (the caller commits) and the matching row is returned. Returns
    None when no unexpired, unconsumed code matches.
    """
    checked_at = time.time()
    live_stmt = select(EmailCode).where(
        EmailCode.user_id == user.id,
        EmailCode.purpose == purpose,
        EmailCode.consumed_at.is_(None),
        EmailCode.expires_at > checked_at,
    )
    live_rows = ses.scalars(live_stmt).all()

    # First row whose salted hash equals the hash of the submitted code.
    matched = next(
        (
            row
            for row in live_rows
            if hmac.compare_digest(_hash_code(code, row.code_salt), row.code_hash)
        ),
        None,
    )
    if matched is None:
        return None

    matched.consumed_at = checked_at
    for row in live_rows:
        if row.id != matched.id:
            row.consumed_at = row.consumed_at or checked_at
    return matched
|
|
|
|
def _send_verification_email(user: User, code: str) -> None:
    """Email the account-verification *code* to *user* (result of the send is ignored)."""
    minutes = max(1, EMAIL_CODE_TTL_SECONDS // 60)
    plural = "s" if minutes != 1 else ""
    lines = [
        "Hello!\n\n",
        "Use the verification code below to activate your ConeImage account.\n\n",
        f"Code: {code}\n\n",
        f"This code will expire in {minutes} minute{plural}. If you did not sign up, you can ignore this email.",
    ]
    _send_mail([user.email], "ConeImage verification code", "".join(lines))
|
|
|
|
def _send_password_reset_email(user: User, code: str) -> None:
    """Email the password-reset *code* to *user* (result of the send is ignored)."""
    minutes = max(1, PASSWORD_RESET_TTL_SECONDS // 60)
    plural = "s" if minutes != 1 else ""
    lines = [
        "We received a request to reset your ConeImage password.\n\n",
        f"Use this code to set a new password: {code}\n\n",
        f"This code will expire in {minutes} minute{plural}. If you did not request a reset, you can delete this email.",
    ]
    _send_mail([user.email], "ConeImage password reset", "".join(lines))
|
|
| |
# The default ("green") model is loaded eagerly at import time, along with
# the sentence model and PCA parameters that are passed to segment_image for
# prompt-guided runs. Other variants are loaded on demand and cached in
# MODEL_REGISTRY by _ensure_model_loaded.
DEFAULT_MODEL_KEY = "green"
DEFAULT_MODEL, SENTENCE_MODEL, PCA_PARAMS = load_model(device=DEVICE)
MODEL_REGISTRY: Dict[str, Any] = {DEFAULT_MODEL_KEY: DEFAULT_MODEL}
|
|
|
|
def _normalize_model_key(key: Optional[str]) -> str:
    """Map *key* to a supported model key: "orange" stays, anything else becomes the default."""
    if key == "orange":
        return "orange"
    return DEFAULT_MODEL_KEY
|
|
|
|
def _ensure_model_loaded(key: Optional[str]):
    """Return ``(model, normalized_key)``, loading and caching the variant on first use."""
    resolved_key = _normalize_model_key(key)
    if MODEL_REGISTRY.get(resolved_key) is None:
        # First request for this variant: load it once and keep it in the registry.
        MODEL_REGISTRY[resolved_key] = load_model_variant(resolved_key, device=DEVICE)
    return MODEL_REGISTRY[resolved_key], resolved_key
|
|
| |
| |
# Serve everything under MEDIA_ROOT at /media; the route name "media" is what
# request.url_for("media", path=...) resolves against.
app.mount("/media", StaticFiles(directory=MEDIA_ROOT), name="media")
|
|
| |
# Response body describing one segmentation job (used by /api/segment and
# /api/history). The *_url fields are absolute URLs into the /media mount.
class JobOut(BaseModel):
    id: str
    created_at: float
    input_url: str
    mask_url: str
    overlay_url: str
    model: str
    device: str
    width: int
    height: int
    input_name: str
    # Legend entries as returned by build_legend_from_mask; None when unavailable.
    legend: list[dict] | None = None
|
|
# Response body of POST /api/upload; upload_id is what the client passes on
# to POST /api/segment.
class UploadOut(BaseModel):
    upload_id: str
    created_at: float
    input_url: str
    width: int
    height: int
|
|
# Request body for a contact-form submission.
# NOTE(review): the endpoint consuming this model is outside this section.
class ContactIn(BaseModel):
    name: str = Field(min_length=1, max_length=120)
    email: EmailStr
    subject: Optional[str] = Field(default=None, max_length=200)
    message: str = Field(min_length=1, max_length=5000)
|
|
# Request body for feedback; name and email are optional.
# NOTE(review): the endpoint consuming this model is outside this section.
# _try_parse_site_survey suggests `feedback` may also carry a JSON-encoded
# site survey - confirm against the handler.
class FeedbackIn(BaseModel):

    name: Optional[str] = Field(default=None, max_length=120)
    email: Optional[EmailStr] = None
    feedback: str = Field(min_length=1, max_length=10000)
|
|
|
|
# Public view of a user account; built from the dict returned by _serialize_user.
class UserOut(BaseModel):
    id: str
    email: EmailStr
    name: Optional[str] = None
    created_at: float
    is_verified: bool
    verified_at: Optional[float] = None
    preferred_model: str
|
|
|
|
# Returned after a successful login or email verification: the new bearer
# token plus the account it belongs to.
class AuthResponse(BaseModel):
    token: str
    user: UserOut
|
|
|
|
# Request body of POST /api/auth/register.
class RegisterIn(BaseModel):
    email: EmailStr
    password: str = Field(min_length=8, max_length=256)
    name: str = Field(min_length=1, max_length=120)
|
|
|
|
# Response body of POST /api/auth/register; no session token is issued until
# the email address has been verified.
class RegisterOut(BaseModel):
    detail: str
    requires_verification: bool = True
|
|
|
|
# Request body of POST /api/auth/login. The password only needs to be
# non-empty here; the 8-character minimum is enforced at registration/reset.
class LoginIn(BaseModel):
    email: EmailStr
    password: str = Field(min_length=1, max_length=256)
|
|
|
|
# Response body of GET /api/auth/me.
class AuthMeOut(BaseModel):
    user: UserOut
|
|
|
|
# Request body of POST /api/auth/verify/resend.
class VerifyResendIn(BaseModel):
    email: EmailStr
|
|
|
|
# Request body of POST /api/auth/verify/confirm; the code must be exactly six digits.
class VerifyConfirmIn(BaseModel):
    email: EmailStr
    code: str = Field(min_length=6, max_length=6, pattern=r"^\d{6}$")
|
|
|
|
# Minimal response carrying only a human-readable status message.
class GenericMessageOut(BaseModel):
    detail: str
|
|
|
|
# Request body of POST /api/auth/password/forgot.
class PasswordForgotIn(BaseModel):
    email: EmailStr
|
|
|
|
# Request body of POST /api/auth/password/reset: the emailed six-digit code
# plus the replacement password.
class PasswordResetIn(BaseModel):
    email: EmailStr
    code: str = Field(min_length=6, max_length=6, pattern=r"^\d{6}$")
    new_password: str = Field(min_length=8, max_length=256)
|
|
| |
def _send_mail(to: list[str], subject: str, body: str, reply_to: Optional[str] = None) -> bool:
    """Best-effort SMTP email. Returns True on success, False otherwise.

    Args:
        to: Recipient addresses; an empty list is a no-op.
        subject: Subject line. Line breaks are collapsed to spaces because
            the email library rejects CR/LF in header values.
        body: Plain-text message body.
        reply_to: Optional Reply-To address, sanitized like *subject*.

    Never raises: message construction and delivery both happen inside the
    guarded section, and failures are logged rather than silently dropped.
    """
    import logging  # local import keeps this fix self-contained

    if not to or not SMTP_HOST:
        return False

    def _single_line(value: str) -> str:
        # Header values may come from user input (e.g. a contact subject).
        return " ".join(str(value).splitlines()).strip()

    try:
        msg = EmailMessage()
        msg["Subject"] = _single_line(subject)
        msg["From"] = SMTP_FROM
        msg["To"] = ", ".join(to)
        if reply_to:
            msg["Reply-To"] = _single_line(reply_to)
        msg.set_content(body)

        with smtplib.SMTP(SMTP_HOST, SMTP_PORT, timeout=15) as s:
            s.starttls()
            if SMTP_USER:
                s.login(SMTP_USER, SMTP_PASS)
            s.send_message(msg)
        return True
    except Exception:
        # Deliberately broad: callers rely on this never raising.
        logging.getLogger(__name__).exception(
            "Email delivery failed (%d recipient(s))", len(to)
        )
        return False
|
|
def _extract_int(v: Any) -> Optional[int]:
    """Coerce *v* to an int rating in 1..5; return None for anything else."""
    try:
        rating = int(v)
    except Exception:
        # Not convertible (None, text, infinities, ...): treat as "no rating".
        return None
    return rating if 1 <= rating <= 5 else None
|
|
def _try_parse_site_survey(raw_feedback: str) -> Optional[Dict[str, Any]]:
    """
    Decode a site-survey payload embedded in the 'feedback' field.

    Returns a dict with 'role', 'use_case' and one 1..5 rating (or None) per
    rated aspect when *raw_feedback* is a JSON object whose 'tag' equals
    'SITE_SURVEY' (case-insensitive). Any other input, including invalid
    JSON, yields None.
    """
    rating_fields = (
        "model_works",
        "image_visualization",
        "prompt_input",
        "processing_speed",
        "accessibility",
        "overall_experience",
    )
    try:
        payload = json.loads(raw_feedback)
        if not isinstance(payload, dict):
            return None
        if str(payload.get("tag")).upper() != "SITE_SURVEY":
            return None
        ratings = payload.get("ratings") or {}
        parsed: Dict[str, Any] = {
            "role": payload.get("role") or None,
            "use_case": payload.get("use_case") or None,
        }
        for field in rating_fields:
            parsed[field] = _extract_int(ratings.get(field))
        return parsed
    except Exception:
        # Malformed JSON or an unexpected shape: not a survey.
        return None
|
|
| |
@app.middleware("http")
async def add_cache_headers(request: Request, call_next):
    """Attach a 30-day immutable Cache-Control header to /media/ responses."""
    response = await call_next(request)
    serves_media = request.url.path.startswith("/media/")
    if serves_media:
        response.headers["Cache-Control"] = "public, max-age=2592000, immutable"
    return response
|
|
| |
@app.get("/api/health")
def health():
    """Liveness probe reporting the compute device and the default model key."""
    return dict(status="ok", device=DEVICE, model_key=DEFAULT_MODEL_KEY)
|
|
|
|
@app.post("/api/auth/register", response_model=RegisterOut, status_code=201)
def register_user(payload: RegisterIn):
    """Create an unverified account and email a verification code.

    Raises:
        HTTPException 422: the name is empty after trimming.
        HTTPException 409: the email is already registered - including the
            case where a concurrent request registers it between the lookup
            and the insert.
    """
    from sqlalchemy.exc import IntegrityError  # local import keeps this fix self-contained

    email_norm = str(payload.email).strip().lower()
    name_norm = payload.name.strip()
    if not name_norm:
        raise HTTPException(status_code=422, detail="Name must not be empty")

    with Session(engine) as ses:
        existing = ses.scalar(select(User).where(User.email == email_norm))
        if existing:
            raise HTTPException(status_code=409, detail="Email already registered")

        user = User(
            id=str(uuid.uuid4()),
            created_at=time.time(),
            email=email_norm,
            password_hash=_hash_password(payload.password),
            name=name_norm,
            is_verified=0,
            verified_at=None,
            preferred_model=DEFAULT_MODEL_KEY,
        )
        ses.add(user)
        try:
            ses.flush()
        except IntegrityError:
            # Lost a race against another registration for the same email:
            # the UNIQUE constraint on users.email fired. Report it like the
            # ordinary duplicate case instead of an internal error.
            ses.rollback()
            raise HTTPException(status_code=409, detail="Email already registered") from None

        code = _issue_email_code(ses, user, "verify_email", EMAIL_CODE_TTL_SECONDS)
        ses.commit()
        # Sent while the session is still open so `user` can be refreshed.
        _send_verification_email(user, code)

    return RegisterOut(
        detail="Account created. Check your email for a verification code to activate it.",
        requires_verification=True,
    )
|
|
|
|
@app.post("/api/auth/login", response_model=AuthResponse)
def login_user(payload: LoginIn, request: Request):
    """Authenticate with email and password and open a new session.

    Raises:
        HTTPException 401: unknown email or wrong password.
        HTTPException 403: credentials are valid but the email is unverified.
    """
    email_norm = str(payload.email).strip().lower()
    with Session(engine) as ses:
        account = ses.scalar(select(User).where(User.email == email_norm))
        credentials_ok = bool(account) and _verify_password(payload.password, account.password_hash)
        if not credentials_ok:
            raise HTTPException(status_code=401, detail="Invalid email or password")

        if not account.is_verified:
            raise HTTPException(
                status_code=403,
                detail="Email not verified. Use the code sent to your inbox to activate your account.",
            )

        agent = request.headers.get("User-Agent")
        token = _create_session(account, ses, agent)
        ses.commit()

        profile = UserOut(**_serialize_user(account))
        return AuthResponse(token=token, user=profile)
|
|
|
|
@app.post("/api/auth/verify/resend", response_model=GenericMessageOut)
def resend_verification(payload: VerifyResendIn):
    """Issue and email a fresh verification code for an unverified account."""
    email_norm = str(payload.email).strip().lower()
    with Session(engine) as ses:
        account = ses.scalar(select(User).where(User.email == email_norm))

        if not account:
            message = "If an account exists, a verification email has been sent."
        elif account.is_verified:
            message = "This email is already verified. You can sign in."
        else:
            code = _issue_email_code(ses, account, "verify_email", EMAIL_CODE_TTL_SECONDS)
            ses.commit()
            _send_verification_email(account, code)
            message = "Verification code sent. Check your email."

    return GenericMessageOut(detail=message)
|
|
|
|
@app.post("/api/auth/verify/confirm", response_model=AuthResponse)
def confirm_verification(payload: VerifyConfirmIn, request: Request):
    """Confirm an emailed verification code, mark the account verified and sign in.

    Raises:
        HTTPException 400: unknown email, or a wrong/expired/used code (the
            same message is used for all of these).
    """
    bad_code = "Invalid or expired verification code"
    email_norm = str(payload.email).strip().lower()
    submitted = payload.code.strip()

    with Session(engine) as ses:
        account = ses.scalar(select(User).where(User.email == email_norm))
        if not account:
            raise HTTPException(status_code=400, detail=bad_code)

        if not _verify_email_code(ses, account, "verify_email", submitted):
            raise HTTPException(status_code=400, detail=bad_code)

        account.is_verified = 1
        # Keep the first verification timestamp if one is already recorded.
        account.verified_at = account.verified_at or time.time()
        token = _create_session(account, ses, request.headers.get("User-Agent"))
        ses.commit()

        profile = UserOut(**_serialize_user(account))
        return AuthResponse(token=token, user=profile)
|
|
|
|
@app.post("/api/auth/password/forgot", response_model=GenericMessageOut)
def forgot_password(payload: PasswordForgotIn):
    """Email a password-reset code when the address belongs to a verified account.

    The response is identical whether or not a code was sent, so the
    endpoint cannot be used to probe which emails are registered.
    (Previously the success path returned a different message, which
    defeated the neutral wording of the "unknown account" path.)
    """
    neutral = GenericMessageOut(detail="If an account exists, you'll receive a reset code shortly.")

    email_norm = str(payload.email).strip().lower()
    with Session(engine) as ses:
        user = ses.scalar(select(User).where(User.email == email_norm))
        if not user or not user.is_verified:
            return neutral

        code = _issue_email_code(ses, user, "reset_password", PASSWORD_RESET_TTL_SECONDS)
        ses.commit()
        _send_password_reset_email(user, code)

    return neutral
|
|
|
|
@app.post("/api/auth/password/reset", response_model=GenericMessageOut)
def reset_password(payload: PasswordResetIn):
    """Set a new password using an emailed reset code and revoke all sessions.

    Raises:
        HTTPException 400: unknown email, or a wrong/expired/used code.
    """
    bad_code = "Invalid or expired reset code"
    email_norm = str(payload.email).strip().lower()
    submitted = payload.code.strip()

    with Session(engine) as ses:
        account = ses.scalar(select(User).where(User.email == email_norm))
        if not account:
            raise HTTPException(status_code=400, detail=bad_code)

        if not _verify_email_code(ses, account, "reset_password", submitted):
            raise HTTPException(status_code=400, detail=bad_code)

        account.password_hash = _hash_password(payload.new_password)
        if not account.is_verified:
            # A valid emailed code also proves control of the address.
            account.is_verified = 1
            account.verified_at = time.time()
        # Drop every existing session of this user.
        ses.execute(text("DELETE FROM sessions WHERE user_id = :uid"), {"uid": account.id})
        ses.commit()

    return GenericMessageOut(detail="Password updated. You can now sign in with your new password.")
|
|
|
|
@app.get("/api/auth/me", response_model=AuthMeOut)
def whoami(request: Request):
    """Return the account behind the bearer token; HTTP 401 when it is missing or invalid."""
    with Session(engine) as ses:
        account = _get_user_from_token(_extract_bearer_token(request), ses, require=True)
        profile = UserOut(**_serialize_user(account))
        return AuthMeOut(user=profile)
|
|
|
|
@app.post("/api/auth/logout")
def logout_user(request: Request):
    """Delete the session identified by the bearer token.

    Raises:
        HTTPException 401: no token was supplied, or it matches no session.
    """
    token = _extract_bearer_token(request)
    if token is None:
        raise HTTPException(status_code=401, detail="Missing authorization")

    with Session(engine) as ses:
        session_row = ses.get(SessionToken, token)
        if session_row is None:
            raise HTTPException(status_code=401, detail="Invalid session")
        ses.delete(session_row)
        ses.commit()

    return {"status": "ok"}
|
|
@app.post("/api/model/select")
def select_model(request: Request, model_key: str = Form(...)):
    """
    Store the preferred segmentation model of the authenticated user.
    'orange' is kept as-is; every other value is stored as the default key.
    The chosen model is loaded before the response is returned.
    """
    chosen = _normalize_model_key(model_key)

    with Session(engine) as ses:
        account = _get_user_from_token(_extract_bearer_token(request), ses, require=True)
        account.preferred_model = chosen
        ses.commit()

    _ensure_model_loaded(chosen)
    return {"status": "ok", "model_key": chosen, "device": DEVICE}
|
|
|
|
@app.get("/api/model/current")
def current_model(request: Request):
    """Report the caller's model key (the default for anonymous callers) and the device."""
    with Session(engine) as ses:
        account = _get_user_from_token(_extract_bearer_token(request), ses)
        preferred = account.preferred_model if account else None
        return {"model_key": _normalize_model_key(preferred), "device": DEVICE}
|
|
@app.post("/api/upload", response_model=UploadOut)
async def upload(request: Request, file: UploadFile = File(...)):
    """Validate and store an image so it can be segmented later by upload_id.

    The image is re-encoded to RGB and saved as "<upload_id>_input.jpg" in
    UPLOAD_DIR; a metadata row remembers the original file name and, when
    the request carries a valid session, the owner.

    Raises:
        HTTPException 413: the payload exceeds MAX_BYTES.
        HTTPException 400: the payload is not a decodable image.
    """
    # Read at most one byte past the limit: enough to detect an oversized
    # payload without pulling an arbitrarily large body into memory.
    raw = await file.read(MAX_BYTES + 1)
    if len(raw) > MAX_BYTES:
        # Report the configured limit (renders as "10MB" with the default).
        raise HTTPException(
            status_code=413,
            detail=f"Image too large (max {MAX_BYTES / (1024 * 1024):g}MB)",
        )
    try:
        # verify() leaves the image object unusable, so reopen for real use.
        with Image.open(io.BytesIO(raw)) as probe:
            probe.verify()
        img = Image.open(io.BytesIO(raw)).convert("RGB")
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid image")

    w, h = img.size
    upload_id = str(uuid.uuid4())
    ts = time.time()

    stored_name = f"{upload_id}_input.jpg"
    in_path = os.path.join(UPLOAD_DIR, stored_name)
    img.save(in_path, quality=92, optimize=True)

    with Session(engine) as ses:
        user = _get_user_from_token(_extract_bearer_token(request), ses)
        user_id = user.id if user else None
        ses.add(
            Upload(
                upload_id=upload_id,
                created_at=ts,
                stored_name=stored_name,
                # basename() strips any client-supplied directory components.
                orig_name=os.path.basename(file.filename) if file.filename else "unknown",
                width=w,
                height=h,
                user_id=user_id,
            )
        )
        ses.commit()

    input_url = request.url_for("media", path=f"uploads/{stored_name}")

    return UploadOut(
        upload_id=upload_id,
        created_at=ts,
        input_url=str(input_url),
        width=w,
        height=h,
    )
|
|
def _parse_prompt_list(prompt: Optional[str]) -> Optional[List[str]]:
    """Split a comma-separated prompt into class names; None when nothing usable remains.

    "background" is inserted as the first class unless the caller already
    put it there (compared case-insensitively).
    """
    if not prompt:
        return None
    names = [part.strip() for part in prompt.split(",") if part.strip()]
    if not names:
        return None
    if names[0].lower() != "background":
        names.insert(0, "background")
    return names


@app.post("/api/segment", response_model=JobOut)
async def segment(
    request: Request,
    file: Optional[UploadFile] = File(None),
    upload_id: Optional[str] = Form(None),
    prompt: Optional[str] = Form(None),
):
    """
    Preferred: pass upload_id (image already uploaded via /api/upload).
    Backward compatibility: can still post a 'file'.
    If 'prompt' is provided, segmentation will be prompt-guided.

    Raises:
        HTTPException 400: no input was supplied, or it is not a valid image.
        HTTPException 403: the upload belongs to a different account.
        HTTPException 404: upload_id is malformed or refers to no stored file.
        HTTPException 413: a directly posted file exceeds MAX_BYTES.
        HTTPException 500: model inference failed.
    """
    prompt_list = _parse_prompt_list(prompt)

    # Resolve the caller (if any) and their preferred model.
    token = _extract_bearer_token(request)
    user_id: Optional[str] = None
    user_model_key = DEFAULT_MODEL_KEY
    if token:
        with Session(engine) as ses:
            user = _get_user_from_token(token, ses)
            if user:
                user_id = user.id
                user_model_key = _normalize_model_key(getattr(user, "preferred_model", None))

    if upload_id and not file:
        # upload_id ends up in a filesystem path, so only accept a real UUID;
        # anything else (e.g. "../x") could escape UPLOAD_DIR.
        try:
            upload_key = str(uuid.UUID(upload_id))
        except ValueError:
            raise HTTPException(status_code=404, detail="Upload not found") from None

        # Check ownership before touching the file on disk.
        with Session(engine) as ses:
            upl = ses.get(Upload, upload_key)
            if upl and upl.user_id and upl.user_id != user_id:
                raise HTTPException(status_code=403, detail="Upload does not belong to this account")
            recorded_name = upl.orig_name if upl else None

        in_name = f"{upload_key}_input.jpg"
        in_path = os.path.join(UPLOAD_DIR, in_name)
        if not os.path.exists(in_path):
            raise HTTPException(status_code=404, detail="Upload not found")

        try:
            img = Image.open(in_path).convert("RGB")
        except Exception:
            raise HTTPException(status_code=400, detail="Invalid image")

        orig_name = recorded_name or in_name
    else:
        if not file:
            raise HTTPException(status_code=400, detail="No image or upload_id provided")
        # Bounded read: one byte past the limit is enough to detect oversize.
        raw = await file.read(MAX_BYTES + 1)
        if len(raw) > MAX_BYTES:
            raise HTTPException(
                status_code=413,
                detail=f"Image too large (max {MAX_BYTES / (1024 * 1024):g}MB)",
            )
        try:
            # verify() leaves the image object unusable, so reopen for real use.
            with Image.open(io.BytesIO(raw)) as probe:
                probe.verify()
            img = Image.open(io.BytesIO(raw)).convert("RGB")
        except Exception:
            raise HTTPException(status_code=400, detail="Invalid image")
        orig_name = os.path.basename(file.filename) if file.filename else "unknown"

    w, h = img.size

    model, user_model_key = _ensure_model_loaded(user_model_key)

    # NOTE(review): inference is called synchronously inside this async
    # handler; if it is slow, other requests on the same event loop wait.
    try:
        if prompt_list:
            mask = segment_image(model, img, prompt_list, SENTENCE_MODEL, PCA_PARAMS, device=DEVICE)
        else:
            mask = segment_image(model, img, device=DEVICE)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Inference failed: {e}")

    # The palette/legend dataset follows the model variant.
    dataset = "ade" if user_model_key == "orange" else "coco"

    if prompt_list:
        color = colorize_mask(mask, classes=prompt_list)
        legend = build_legend_from_mask(mask, classes=prompt_list)
    else:
        color = colorize_mask(mask, dataset=dataset)
        legend = build_legend_from_mask(mask, dataset=dataset)
    over = overlay_mask(img, color, alpha=0.45)

    # Persist all artefacts under a fresh job id.
    job_id = str(uuid.uuid4())
    ts = time.time()
    in_name_final = f"{job_id}_input.jpg"
    mask_name = f"{job_id}_mask.png"
    over_name = f"{job_id}_overlay.jpg"
    legend_name = f"{job_id}_legend.json"

    in_path_final = os.path.join(MEDIA_ROOT, in_name_final)
    mask_path = os.path.join(MEDIA_ROOT, mask_name)
    over_path = os.path.join(MEDIA_ROOT, over_name)
    legend_path = os.path.join(MEDIA_ROOT, legend_name)

    img.save(in_path_final, quality=92, optimize=True)
    color.save(mask_path)
    over.save(over_path, quality=92, optimize=True)

    # The legend sidecar is best-effort: the legend is returned in this
    # response either way, and history falls back to None when it is absent.
    try:
        with open(legend_path, "w", encoding="utf-8") as f:
            json.dump(legend, f)
    except Exception:
        pass

    with Session(engine) as ses:
        ses.add(
            Job(
                id=job_id,
                created_at=ts,
                input_path=in_name_final,
                input_name=orig_name or "unknown",
                mask_path=mask_name,
                overlay_path=over_name,
                model="Meruformerb4_entailonly",
                device=DEVICE,
                width=w,
                height=h,
                user_id=user_id,
            )
        )
        ses.commit()

    input_url = request.url_for("media", path=in_name_final)
    mask_url = request.url_for("media", path=mask_name)
    overlay_url = request.url_for("media", path=over_name)

    return JobOut(
        id=job_id,
        created_at=ts,
        input_url=str(input_url),
        mask_url=str(mask_url),
        overlay_url=str(overlay_url),
        model="Meruformerb4_entailonly",
        device=DEVICE,
        width=w,
        height=h,
        input_name=orig_name or "unknown",
        legend=legend,
    )
|
|
@app.get("/api/history", response_model=List[JobOut])
def history(request: Request, limit: int = 20, offset: int = 0):
    """List jobs, newest first.

    Authenticated callers get their own jobs; everyone else gets the jobs
    that have no owner. Each entry carries the legend from its JSON sidecar
    when that file exists and parses.
    """

    def _read_legend(job_id: str):
        # Sidecar written next to the job's media files; optional.
        sidecar = os.path.join(MEDIA_ROOT, f"{job_id}_legend.json")
        if not os.path.exists(sidecar):
            return None
        try:
            with open(sidecar, "r", encoding="utf-8") as fh:
                return json.load(fh)
        except Exception:
            return None

    bind: Dict[str, Any] = {"limit": limit, "offset": offset}
    token = _extract_bearer_token(request)
    with Session(engine) as ses:
        account = _get_user_from_token(token, ses)
        if account:
            owner_filter = "user_id = :user_id"
            bind["user_id"] = account.id
        else:
            owner_filter = "user_id IS NULL"

        # owner_filter is one of the two fixed literals above; all values
        # still travel as bound parameters.
        stmt = text(
            f"""
            SELECT id, created_at, input_path, mask_path, overlay_path, model, device, width, height, input_name
            FROM jobs
            WHERE {owner_filter}
            ORDER BY created_at DESC
            LIMIT :limit OFFSET :offset
            """
        )
        rows = ses.execute(stmt, bind).all()

    return [
        JobOut(
            id=job_id,
            created_at=created,
            input_url=str(request.url_for("media", path=input_rel)),
            mask_url=str(request.url_for("media", path=mask_rel)),
            overlay_url=str(request.url_for("media", path=overlay_rel)),
            model=model_name,
            device=device_name,
            width=width,
            height=height,
            input_name=input_name or "unknown",
            legend=_read_legend(job_id),
        )
        for (
            job_id,
            created,
            input_rel,
            mask_rel,
            overlay_rel,
            model_name,
            device_name,
            width,
            height,
            input_name,
        ) in rows
    ]
|
|
|
|
@app.delete("/api/history/{job_id}")
def delete_job(job_id: str, request: Request):
    """Delete one job row and its media files.

    A job that has an owner can only be deleted by that owner; a job with no
    owner can be deleted without authentication.

    Args:
        job_id: Primary key of the job to delete.
        request: Incoming request; supplies the bearer token.

    Returns:
        ``{"status": "deleted", "id": job_id}``.

    Raises:
        HTTPException: 404 if the job does not exist, 401 if the job has an
            owner and the caller is not authenticated, 403 if the caller is
            not the owner.
    """
    with Session(engine) as ses:
        user = _get_user_from_token(_extract_bearer_token(request), ses)
        row = ses.get(Job, job_id)
        if not row:
            raise HTTPException(status_code=404, detail="Job not found")
        if row.user_id:
            if not user:
                raise HTTPException(status_code=401, detail="Authentication required")
            if user.id != row.user_id:
                raise HTTPException(status_code=403, detail="Forbidden")

        # Read the file names before the row is deleted, then commit first:
        # if the commit fails the media is still on disk and the job stays
        # fully usable instead of pointing at files that no longer exist.
        stale_files = [
            row.input_path,
            row.mask_path,
            row.overlay_path,
            f"{job_id}_legend.json",
        ]
        ses.delete(row)
        ses.commit()

    for name in stale_files:
        if not name:
            continue  # column was NULL/empty; nothing to remove
        try:
            os.remove(os.path.join(MEDIA_ROOT, name))
        except (OSError, ValueError):
            # Best-effort cleanup: a file that is already gone (or cannot be
            # removed) must not fail a delete that has been committed.
            pass

    return {"status": "deleted", "id": job_id}
|
|
|
|
@app.delete("/api/history")
def clear_history(request: Request):
    """Delete every job in the caller's history, together with its media files.

    When the bearer token resolves to a user, that user's jobs are removed;
    otherwise all jobs stored without an owner are removed.

    Args:
        request: Incoming request; supplies the bearer token.

    Returns:
        ``{"status": "cleared"}``.
    """
    # NOTE(review): an unauthenticated caller clears *all* ownerless jobs,
    # not only ones it created - confirm this is intended.
    with Session(engine) as ses:
        user = _get_user_from_token(_extract_bearer_token(request), ses)
        if user:
            owner_filter = Job.user_id == user.id
        else:
            owner_filter = Job.user_id.is_(None)

        stale_files = []
        for row in ses.query(Job).filter(owner_filter).all():
            # Read the file names before the row is deleted.
            stale_files.extend(
                [
                    row.input_path,
                    row.mask_path,
                    row.overlay_path,
                    f"{row.id}_legend.json",
                ]
            )
            ses.delete(row)

        # Commit before touching the disk: if the commit fails, no media has
        # been removed and the history is still intact.
        ses.commit()

    for name in stale_files:
        if not name:
            continue  # column was NULL/empty; nothing to remove
        try:
            os.remove(os.path.join(MEDIA_ROOT, name))
        except (OSError, ValueError):
            # Best-effort cleanup: a missing or undeletable file must not
            # fail a clear that has already been committed.
            pass

    return {"status": "cleared"}
|
|
| |
|
|
@app.post("/api/contact")
def create_contact(payload: ContactIn):
    """Store a contact-form submission and email it to ``CONTACT_RECIPIENTS``.

    The submission is committed to the database first. Sending the
    notification email is best-effort: a mail failure is logged and does not
    change the response.

    Args:
        payload: Contact form data (name, email, subject, message).

    Returns:
        ``{"status": "ok", "id": <new row id>, "created_at": <unix time>}``.
    """
    import logging  # local import: used only on the mail-failure path

    ts = time.time()
    row_id = str(uuid.uuid4())
    with Session(engine) as ses:
        ses.add(
            Contact(
                id=row_id,
                created_at=ts,
                name=payload.name,
                email=str(payload.email),
                subject=payload.subject,
                message=payload.message,
            )
        )
        ses.commit()

    subj = (payload.subject or "").strip()
    # A mail header cannot carry line breaks. Collapse all whitespace runs so
    # a multi-line subject still produces a sendable header instead of an
    # error that would cost us the notification.
    # NOTE(review): assumes _send_mail puts `subject` into a mail header
    # as-is - confirm.
    header_subj = " ".join(subj.split())
    email_subject = f"ConeImage: {header_subj}" if header_subj else "ConeImage: No subject"

    try:
        _send_mail(
            to=CONTACT_RECIPIENTS,
            subject=email_subject,
            body=(
                f"Name: {payload.name}\n"
                f"Email: {payload.email}\n"
                f"Subject: {subj or 'No subject'}\n"
                f"Message:\n{payload.message}\n\n"
                f"Submitted at: {time.strftime('%Y-%m-%d %H:%M:%S', time.gmtime(ts))} UTC"
            ),
            reply_to=str(payload.email),
        )
    except Exception:
        # The submission is already stored, so the request still succeeds;
        # record the failure instead of dropping it silently.
        logging.getLogger(__name__).exception(
            "Contact notification email failed (contact id=%s)", row_id
        )

    return {"status": "ok", "id": row_id, "created_at": ts}
|
|
@app.post("/api/feedback")
def create_feedback(payload: FeedbackIn):
    """Store a feedback submission, or a site survey posted through the same form.

    ``payload.feedback`` is first passed to ``_try_parse_site_survey``. If
    that returns a survey, it is saved to the surveys table and no email is
    sent. Otherwise the payload is saved as plain feedback and a notification
    email containing the submitter's name, email and message is sent to
    ``CONTACT_RECIPIENTS``. Sending is best-effort: a mail failure is logged
    and does not change the response.

    Args:
        payload: Feedback form data (optional name and email, feedback text).

    Returns:
        ``{"status": "ok", "kind": "survey" | "feedback", "id": <new row id>,
        "created_at": <unix time>}``.
    """
    import logging  # local import: used only on the mail-failure path

    ts = time.time()
    row_id = str(uuid.uuid4())

    # Survey path: the feedback string carried a survey.
    survey = _try_parse_site_survey(payload.feedback)
    if survey is not None:
        with Session(engine) as ses:
            ses.add(
                Survey(
                    id=row_id,
                    created_at=ts,
                    role=survey["role"],
                    use_case=survey["use_case"],
                    model_works=survey["model_works"],
                    image_visualization=survey["image_visualization"],
                    prompt_input=survey["prompt_input"],
                    processing_speed=survey["processing_speed"],
                    accessibility=survey["accessibility"],
                    overall_experience=survey["overall_experience"],
                )
            )
            ses.commit()
        return {"status": "ok", "kind": "survey", "id": row_id, "created_at": ts}

    # Plain feedback path.
    with Session(engine) as ses:
        ses.add(
            Feedback(
                id=row_id,
                created_at=ts,
                name=payload.name,
                email=str(payload.email) if payload.email else None,
                feedback=payload.feedback,
            )
        )
        ses.commit()

    try:
        _send_mail(
            to=CONTACT_RECIPIENTS,
            subject="ConeImage: New Feedback",
            body=(
                f"Name: {payload.name or '(anonymous)'}\n"
                f"Email: {payload.email or '(none provided)'}\n"
                f"Message:\n{payload.feedback}\n\n"
                f"Submitted at: {time.strftime('%Y-%m-%d %H:%M:%S', time.gmtime(ts))} UTC"
            ),
            reply_to=None,
        )
    except Exception:
        # The feedback is already stored, so the request still succeeds;
        # record the failure instead of dropping it silently.
        logging.getLogger(__name__).exception(
            "Feedback notification email failed (feedback id=%s)", row_id
        )

    return {"status": "ok", "kind": "feedback", "id": row_id, "created_at": ts}
|
|