import contextlib
import hashlib
import hmac
import io
import json
import logging
import os
import secrets
import smtplib
import time
import uuid
from email.message import EmailMessage
from typing import Any, Dict, List, Optional

from fastapi import Body, FastAPI, File, Form, HTTPException, Request, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from PIL import Image
from pydantic import BaseModel, EmailStr, Field
from sqlalchemy import Column, Float, Integer, String, create_engine, select, text
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session, declarative_base

from model import load_model, load_model_variant, segment_image
from utils import build_legend_from_mask, colorize_mask, overlay_mask

logger = logging.getLogger(__name__)

# -------------------- Config --------------------
MEDIA_ROOT = os.getenv("MEDIA_ROOT", "/data/media")
UPLOAD_DIR = os.path.join(MEDIA_ROOT, "uploads")  # raw uploads go here
DB_PATH = os.getenv("DB_PATH", "/data/app.db")
DEVICE = os.getenv("DEVICE", "cpu")

# Comma-separated list of allowed browser origins; blank entries are dropped.
CORS_ORIGINS = [
    o.strip()
    for o in os.getenv(
        "CORS_ORIGINS",
        "https://coneimage.com,https://www.coneimage.com,http://localhost:3000,http://127.0.0.1:3000",
    ).split(",")
    if o.strip()
]

os.makedirs(MEDIA_ROOT, exist_ok=True)
os.makedirs(UPLOAD_DIR, exist_ok=True)

MAX_BYTES = int(os.getenv("MAX_UPLOAD_BYTES", str(10 * 1024 * 1024)))  # 10 MB
SESSION_TTL_SECONDS = int(os.getenv("SESSION_TTL_SECONDS", str(60 * 60 * 24 * 7)))
EMAIL_CODE_TTL_SECONDS = int(os.getenv("EMAIL_CODE_TTL_SECONDS", str(15 * 60)))
PASSWORD_RESET_TTL_SECONDS = int(os.getenv("PASSWORD_RESET_TTL_SECONDS", str(15 * 60)))

# Upper bound on /api/history page size. Without it a client could request an
# unbounded page (SQLite treats a negative LIMIT as "no limit").
HISTORY_MAX_LIMIT = 500

# Label recorded on every job row and returned to clients.
# NOTE(review): this is constant even when the "orange" variant ran — confirm
# whether the stored label should reflect the selected variant.
SEGMENT_MODEL_NAME = "Meruformerb4_entailonly"

# SMTP / Email config
SMTP_HOST = os.getenv("SMTP_HOST", "smtp.gmail.com")
SMTP_PORT = int(os.getenv("SMTP_PORT", "587"))
SMTP_USER = os.getenv("SMTP_USER", "coneimage123@gmail.com")
# SECURITY: the SMTP password must be supplied via the environment only. Never
# hard-code it here; any password that was ever committed to source control
# must be treated as leaked and rotated.
SMTP_PASS = os.getenv("SMTP_PASS", "")
SMTP_FROM = os.getenv("SMTP_FROM", "coneimage123@gmail.com")

# Default recipients for contact/feedback notifications.
CONTACT_RECIPIENTS = [
    e.strip()
    for e in os.getenv(
        "CONTACT_RECIPIENTS",
        "mahmed10.umbc@gmail.com,smxahsan@gmail.com",
    ).split(",")
    if e.strip()
]

# -------------------- App --------------------
app = FastAPI(title="ConeImage API", version="1.5.0")

app.add_middleware(
    CORSMiddleware,
    allow_origins=CORS_ORIGINS if CORS_ORIGINS else ["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
    max_age=86400,
)

# -------------------- DB --------------------
engine = create_engine(f"sqlite:///{DB_PATH}", future=True)
Base = declarative_base()


class Job(Base):
    """One completed segmentation run and the media files it produced."""

    __tablename__ = "jobs"
    id = Column(String, primary_key=True)
    created_at = Column(Float, index=True)
    input_path = Column(String)  # file name under MEDIA_ROOT (final job copy)
    input_name = Column(String)  # ORIGINAL uploaded filename (what we show in history)
    mask_path = Column(String)  # file name under MEDIA_ROOT
    overlay_path = Column(String)  # file name under MEDIA_ROOT
    model = Column(String)
    device = Column(String)
    width = Column(Integer)
    height = Column(Integer)
    user_id = Column(String, index=True, nullable=True)  # None for anonymous jobs


class Upload(Base):
    """
    Keeps metadata for uploaded files so we can recover the original filename
    in the two-step flow (POST /api/upload -> POST /api/segment with upload_id).
    """

    __tablename__ = "uploads"
    upload_id = Column(String, primary_key=True)
    created_at = Column(Float, index=True)
    stored_name = Column(String)  # "<upload_id>_input.jpg" under UPLOAD_DIR
    orig_name = Column(String)  # original filename the user selected
    width = Column(Integer)
    height = Column(Integer)
    user_id = Column(String, index=True, nullable=True)


class User(Base):
    """A registered account. Passwords are stored as "salt_hex:pbkdf2_hex"."""

    __tablename__ = "users"
    id = Column(String, primary_key=True)
    created_at = Column(Float, index=True)
    email = Column(String, unique=True, index=True)  # stored lower-cased
    password_hash = Column(String)
    name = Column(String, nullable=True)
    is_verified = Column(Integer, default=0)  # 0/1 flag
    verified_at = Column(Float, nullable=True)
    preferred_model = Column(String, default="green")


class SessionToken(Base):
    """An opaque bearer token issued at login/verification."""

    __tablename__ = "sessions"
    token = Column(String, primary_key=True)
    user_id = Column(String, index=True)
    created_at = Column(Float, index=True)
    expires_at = Column(Float, index=True)
    user_agent = Column(String, nullable=True)


class Contact(Base):
    """A contact-form submission."""

    __tablename__ = "contacts"
    id = Column(String, primary_key=True)
    created_at = Column(Float, index=True)
    name = Column(String)
    email = Column(String)
    subject = Column(String, nullable=True)
    message = Column(String)


class Feedback(Base):
    """A free-text feedback submission."""

    __tablename__ = "feedback"
    id = Column(String, primary_key=True)
    created_at = Column(Float, index=True)
    name = Column(String, nullable=True)
    email = Column(String, nullable=True)
    # NOTE: old deployments may still have a 'rating' column; we no longer use it.
    feedback = Column(String)  # free-text feedback OR other notes


class Survey(Base):
    """
    Stores star-based site survey results (no notes).
    Ratings use a 1..5 scale where 1=worse, 5=best.
    """

    __tablename__ = "surveys"
    id = Column(String, primary_key=True)
    created_at = Column(Float, index=True)
    role = Column(String, nullable=True)
    use_case = Column(String, nullable=True)
    model_works = Column(Integer, nullable=True)
    image_visualization = Column(Integer, nullable=True)
    prompt_input = Column(Integer, nullable=True)
    processing_speed = Column(Integer, nullable=True)
    accessibility = Column(Integer, nullable=True)
    overall_experience = Column(Integer, nullable=True)


class EmailCode(Base):
    """A one-time emailed code (stored only as a salted hash) for a given purpose."""

    __tablename__ = "email_codes"
    id = Column(String, primary_key=True)
    user_id = Column(String, index=True)
    purpose = Column(String, index=True)  # "verify_email" or "reset_password"
    code_hash = Column(String)
    code_salt = Column(String)
    created_at = Column(Float, index=True)
    expires_at = Column(Float, index=True)
    consumed_at = Column(Float, nullable=True)


Base.metadata.create_all(engine)


# -------------------- Lightweight SQLite migrations --------------------
def _add_column_if_missing(
    table: str, column: str, ddl_type: str, backfill_sql: Optional[str] = None
) -> bool:
    """Add `column` to SQLite `table` if it is absent. Safe to run on every startup.

    `create_all` never alters existing tables, so databases created by older
    versions need these columns added explicitly.

    Args:
        table: Table name. It is interpolated into SQL, so pass trusted literals only.
        column: Column name (trusted literal).
        ddl_type: Column type/constraint clause, e.g. "TEXT" or "INTEGER DEFAULT 0".
        backfill_sql: Optional statement run in the same transaction right after
            the column is added.

    Returns:
        True when the column was added, False when it already existed.
    """
    with engine.connect() as conn:
        cols = conn.execute(text(f"PRAGMA table_info({table})")).all()
        if column in {c[1] for c in cols}:  # c[1] is the column name
            return False
        conn.execute(text(f"ALTER TABLE {table} ADD COLUMN {column} {ddl_type}"))
        if backfill_sql:
            conn.execute(text(backfill_sql))
        conn.commit()
        return True


def _maybe_migrate_jobs_add_input_name():
    """Add 'input_name' column to jobs if missing (SQLite)."""
    _add_column_if_missing("jobs", "input_name", "TEXT")


def _maybe_migrate_contacts_add_subject():
    """Add 'subject' column to contacts if missing."""
    _add_column_if_missing("contacts", "subject", "TEXT")


def _maybe_migrate_jobs_add_user_id():
    """Add 'user_id' column to jobs if missing."""
    _add_column_if_missing("jobs", "user_id", "TEXT")


def _maybe_migrate_uploads_add_user_id():
    """Add 'user_id' column to uploads if missing."""
    _add_column_if_missing("uploads", "user_id", "TEXT")


def _maybe_migrate_users_add_verification():
    """Add the email-verification columns to users if missing."""
    _add_column_if_missing("users", "is_verified", "INTEGER DEFAULT 0")
    _add_column_if_missing("users", "verified_at", "REAL")


def _maybe_migrate_users_add_preferred_model():
    """Add 'preferred_model' to users and backfill existing rows with 'green'."""
    _add_column_if_missing(
        "users",
        "preferred_model",
        "TEXT",
        backfill_sql="UPDATE users SET preferred_model = 'green' WHERE preferred_model IS NULL",
    )


_maybe_migrate_jobs_add_input_name()
_maybe_migrate_jobs_add_user_id()
_maybe_migrate_uploads_add_user_id()
_maybe_migrate_users_add_verification()
_maybe_migrate_users_add_preferred_model()
_maybe_migrate_contacts_add_subject()


# -------------------- Auth helpers --------------------
def _hash_password(password: str) -> str:
    """Return "salt_hex:hash_hex" using PBKDF2-HMAC-SHA256 with a random 16-byte salt."""
    salt = secrets.token_bytes(16)
    hashed = hashlib.pbkdf2_hmac("sha256", password.encode("utf-8"), salt, 100_000)
    return f"{salt.hex()}:{hashed.hex()}"


def _verify_password(password: str, stored: str) -> bool:
    """Check `password` against a value produced by `_hash_password`.

    Returns False (rather than raising) for malformed stored values.
    """
    try:
        salt_hex, hash_hex = stored.split(":", 1)
        salt = bytes.fromhex(salt_hex)
        expected = bytes.fromhex(hash_hex)
    except ValueError:
        return False
    computed = hashlib.pbkdf2_hmac("sha256", password.encode("utf-8"), salt, 100_000)
    return hmac.compare_digest(computed, expected)


def _create_session(user: User, ses: Session, user_agent: Optional[str] = None) -> str:
    """Add a new session row for `user` and return its token. The caller must commit."""
    token = secrets.token_urlsafe(48)
    now = time.time()
    ses.add(
        SessionToken(
            token=token,
            user_id=user.id,
            created_at=now,
            expires_at=now + SESSION_TTL_SECONDS,
            user_agent=user_agent,
        )
    )
    return token


def _extract_bearer_token(request: Request) -> Optional[str]:
    """Return the token from an "Authorization: Bearer <token>" header, else None."""
    auth = request.headers.get("Authorization")
    if not auth:
        return None
    if isinstance(auth, str) and auth.lower().startswith("bearer "):
        return auth[7:].strip() or None
    return None


def _get_user_from_token(token: Optional[str], ses: Session, *, require: bool = False) -> Optional[User]:
    """Resolve a bearer token to its User.

    Returns None for a missing, unknown or expired token, or when the session's
    user no longer exists. With `require=True` those cases raise HTTP 401 instead.
    """
    if not token:
        if require:
            raise HTTPException(status_code=401, detail="Missing authorization")
        return None
    session_row = ses.get(SessionToken, token)
    if not session_row or (session_row.expires_at and session_row.expires_at < time.time()):
        if require:
            raise HTTPException(status_code=401, detail="Invalid or expired session")
        return None
    user = ses.get(User, session_row.user_id)
    if not user:
        if require:
            raise HTTPException(status_code=401, detail="Invalid session")
        return None
    return user


def _serialize_user(user: User) -> Dict[str, Any]:
    """Convert a User row into the dict shape expected by `UserOut`."""
    return {
        "id": user.id,
        "email": user.email,
        "name": user.name,
        "created_at": user.created_at,
        "is_verified": bool(user.is_verified),
        "verified_at": user.verified_at,
        "preferred_model": _normalize_model_key(getattr(user, "preferred_model", None)),
    }


def _generate_numeric_code(length: int = 6) -> str:
    """Return a cryptographically random string of `length` decimal digits."""
    alphabet = "0123456789"
    return "".join(secrets.choice(alphabet) for _ in range(length))


def _hash_code(code: str, salt: str) -> str:
    """Return the hex SHA-256 of "code:salt" (the stored form of an email code)."""
    return hashlib.sha256(f"{code}:{salt}".encode("utf-8")).hexdigest()


def _issue_email_code(ses: Session, user: User, purpose: str, ttl_seconds: int) -> str:
    """Create a new one-time code for (`user`, `purpose`) and return it in plaintext.

    Any still-unconsumed codes for the same user and purpose are marked consumed,
    so only the newest code is valid. Only a salted hash is stored. The caller
    must commit `ses`.
    """
    code = _generate_numeric_code()
    salt = secrets.token_hex(8)
    now = time.time()
    active = ses.scalars(
        select(EmailCode).where(
            EmailCode.user_id == user.id,
            EmailCode.purpose == purpose,
            EmailCode.consumed_at.is_(None),
        )
    ).all()
    for entry in active:
        entry.consumed_at = now  # the query only returns rows with consumed_at NULL
    ses.add(
        EmailCode(
            id=str(uuid.uuid4()),
            user_id=user.id,
            purpose=purpose,
            code_hash=_hash_code(code, salt),
            code_salt=salt,
            created_at=now,
            expires_at=now + ttl_seconds,
            consumed_at=None,
        )
    )
    return code


def _verify_email_code(ses: Session, user: User, purpose: str, code: str) -> Optional[EmailCode]:
    """Consume and return the matching unexpired code, or None if nothing matches.

    On success every other open code for the same user/purpose is also consumed.
    The caller must commit `ses`.

    NOTE(review): failed attempts are not counted, so a 6-digit code can be
    brute-forced within its TTL unless requests are rate-limited upstream.
    """
    now = time.time()
    candidates = ses.scalars(
        select(EmailCode).where(
            EmailCode.user_id == user.id,
            EmailCode.purpose == purpose,
            EmailCode.consumed_at.is_(None),
            EmailCode.expires_at > now,
        )
    ).all()
    for candidate in candidates:
        expected = _hash_code(code, candidate.code_salt)
        if hmac.compare_digest(expected, candidate.code_hash):
            candidate.consumed_at = now
            for other in candidates:
                if other.id != candidate.id:
                    other.consumed_at = other.consumed_at or now
            return candidate
    return None


def _send_verification_email(user: User, code: str) -> None:
    """Email `user` their account-verification code (best effort)."""
    subject = "ConeImage verification code"
    minutes = max(1, EMAIL_CODE_TTL_SECONDS // 60)
    body = (
        "Hello!\n\n"
        "Use the verification code below to activate your ConeImage account.\n\n"
        f"Code: {code}\n\n"
        f"This code will expire in {minutes} minute{'s' if minutes != 1 else ''}. If you did not sign up, you can ignore this email."
    )
    _send_mail([user.email], subject, body)


def _send_password_reset_email(user: User, code: str) -> None:
    """Email `user` their password-reset code (best effort)."""
    subject = "ConeImage password reset"
    minutes = max(1, PASSWORD_RESET_TTL_SECONDS // 60)
    body = (
        "We received a request to reset your ConeImage password.\n\n"
        f"Use this code to set a new password: {code}\n\n"
        f"This code will expire in {minutes} minute{'s' if minutes != 1 else ''}. If you did not request a reset, you can delete this email."
    )
    _send_mail([user.email], subject, body)


# -------------------- Model --------------------
DEFAULT_MODEL_KEY = "green"
DEFAULT_MODEL, SENTENCE_MODEL, PCA_PARAMS = load_model(device=DEVICE)
MODEL_REGISTRY: Dict[str, Any] = {DEFAULT_MODEL_KEY: DEFAULT_MODEL}


def _normalize_model_key(key: Optional[str]) -> str:
    """Map any value to a supported model key: "orange", otherwise the default ("green")."""
    return "orange" if key == "orange" else DEFAULT_MODEL_KEY


def _ensure_model_loaded(key: Optional[str]):
    """Return (model, normalized_key), loading and caching the variant on first use."""
    normalized = _normalize_model_key(key)
    model = MODEL_REGISTRY.get(normalized)
    if model is None:
        model = load_model_variant(normalized, device=DEVICE)
        MODEL_REGISTRY[normalized] = model
    return model, normalized


# -------------------- Static media --------------------
# Name MUST be "media" (used by request.url_for below)
app.mount("/media", StaticFiles(directory=MEDIA_ROOT), name="media")


# -------------------- Schemas --------------------
class JobOut(BaseModel):
    id: str
    created_at: float
    input_url: str
    mask_url: str
    overlay_url: str
    model: str
    device: str
    width: int
    height: int
    input_name: str
    legend: list[dict] | None = None


class UploadOut(BaseModel):
    upload_id: str
    created_at: float
    input_url: str
    width: int
    height: int


class ContactIn(BaseModel):
    name: str = Field(min_length=1, max_length=120)
    email: EmailStr
    subject: Optional[str] = Field(default=None, max_length=200)
    message: str = Field(min_length=1, max_length=5000)


class FeedbackIn(BaseModel):
    # Rating removed per request
    name: Optional[str] = Field(default=None, max_length=120)
    email: Optional[EmailStr] = None
    feedback: str = Field(min_length=1, max_length=10000)


class UserOut(BaseModel):
    id: str
    email: EmailStr
    name: Optional[str] = None
    created_at: float
    is_verified: bool
    verified_at: Optional[float] = None
    preferred_model: str


class AuthResponse(BaseModel):
    token: str
    user: UserOut


class RegisterIn(BaseModel):
    email: EmailStr
    password: str = Field(min_length=8, max_length=256)
    name: str = Field(min_length=1, max_length=120)


class RegisterOut(BaseModel):
    detail: str
    requires_verification: bool = True


class LoginIn(BaseModel):
    email: EmailStr
    password: str = Field(min_length=1, max_length=256)


class AuthMeOut(BaseModel):
    user: UserOut


class VerifyResendIn(BaseModel):
    email: EmailStr


class VerifyConfirmIn(BaseModel):
    email: EmailStr
    code: str = Field(min_length=6, max_length=6, pattern=r"^\d{6}$")


class GenericMessageOut(BaseModel):
    detail: str


class PasswordForgotIn(BaseModel):
    email: EmailStr


class PasswordResetIn(BaseModel):
    email: EmailStr
    code: str = Field(min_length=6, max_length=6, pattern=r"^\d{6}$")
    new_password: str = Field(min_length=8, max_length=256)


# -------------------- Helpers --------------------
def _send_mail(to: list[str], subject: str, body: str, reply_to: Optional[str] = None) -> bool:
    """Best-effort SMTP email. Returns True on success, False otherwise.

    Message construction happens inside the try: `EmailMessage` raises
    ValueError for header values containing line breaks, and that must not
    escape a best-effort helper. Failures are logged, never raised.
    """
    if not SMTP_HOST or not to:
        return False
    try:
        msg = EmailMessage()
        msg["Subject"] = subject
        msg["From"] = SMTP_FROM
        msg["To"] = ", ".join(to)
        if reply_to:
            msg["Reply-To"] = reply_to
        msg.set_content(body)
        with smtplib.SMTP(SMTP_HOST, SMTP_PORT, timeout=15) as s:
            s.starttls()
            if SMTP_USER and SMTP_PASS:
                s.login(SMTP_USER, SMTP_PASS)
            s.send_message(msg)
        return True
    except Exception:
        logger.exception("Failed to send email %r to %d recipient(s)", subject, len(to))
        return False


def _extract_int(v: Any) -> Optional[int]:
    """Return `v` as an int if it converts cleanly and lies in 1..5, else None."""
    try:
        iv = int(v)
        if 1 <= iv <= 5:
            return iv
    except Exception:
        pass
    return None


def _try_parse_site_survey(raw_feedback: str) -> Optional[Dict[str, Any]]:
    """
    If the 'feedback' field contains a JSON string with tag='SITE_SURVEY',
    parse and return a dict with normalized keys. Otherwise return None.
    """
    try:
        data = json.loads(raw_feedback)
        if not isinstance(data, dict):
            return None
        if str(data.get("tag")).upper() != "SITE_SURVEY":
            return None
        ratings = data.get("ratings") or {}
        return {
            "role": (data.get("role") or None),
            "use_case": (data.get("use_case") or None),
            "model_works": _extract_int(ratings.get("model_works")),
            "image_visualization": _extract_int(ratings.get("image_visualization")),
            "prompt_input": _extract_int(ratings.get("prompt_input")),
            "processing_speed": _extract_int(ratings.get("processing_speed")),
            "accessibility": _extract_int(ratings.get("accessibility")),
            "overall_experience": _extract_int(ratings.get("overall_experience")),
        }
    except Exception:
        return None


def _max_upload_label() -> str:
    """Human-readable MAX_BYTES in MB (e.g. "10MB"), used in 413 error messages."""
    return f"{MAX_BYTES / (1024 * 1024):g}MB"


async def _read_image_upload(file: UploadFile) -> Image.Image:
    """Read an uploaded image of at most MAX_BYTES and return it as an RGB PIL image.

    Raises:
        HTTPException(413): the payload is larger than MAX_BYTES.
        HTTPException(400): Pillow cannot verify/decode the payload.
    """
    # Read one byte past the limit so oversize payloads are detected without
    # pulling an arbitrarily large body into memory.
    raw = await file.read(MAX_BYTES + 1)
    if len(raw) > MAX_BYTES:
        raise HTTPException(status_code=413, detail=f"Image too large (max {_max_upload_label()})")
    try:
        # verify() invalidates the image object, so decode again for real use.
        with Image.open(io.BytesIO(raw)) as probe:
            probe.verify()
        return Image.open(io.BytesIO(raw)).convert("RGB")
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid image")


def _canonical_upload_id(upload_id: str) -> Optional[str]:
    """Return `upload_id` in canonical UUID form, or None if it is not a UUID.

    Upload ids are produced by `str(uuid.uuid4())`, so anything else (in
    particular values carrying "/" or "..") is rejected before it reaches a path.
    """
    try:
        return str(uuid.UUID(upload_id))
    except (ValueError, TypeError, AttributeError):
        return None


def _parse_prompt(prompt: Optional[str]) -> Optional[List[str]]:
    """Split a comma-separated prompt into class names with "background" first.

    Returns None when the prompt is missing or has no non-blank entries.
    """
    if not prompt:
        return None
    items = [p.strip() for p in prompt.split(",") if p.strip()]
    if not items:
        return None
    if items[0].lower() != "background":
        items = ["background"] + items
    return items


def _remove_job_files(job: Job) -> None:
    """Best-effort removal of a job's media files and legend sidecar from MEDIA_ROOT."""
    names = [job.input_path, job.mask_path, job.overlay_path, f"{job.id}_legend.json"]
    for name in names:
        if not name:
            continue
        with contextlib.suppress(OSError):
            os.remove(os.path.join(MEDIA_ROOT, name))


# -------------------- Middleware --------------------
@app.middleware("http")
async def add_cache_headers(request: Request, call_next):
    """Mark /media/ responses as long-lived immutable (file names are unique UUIDs)."""
    resp = await call_next(request)
    if request.url.path.startswith("/media/"):
        resp.headers["Cache-Control"] = "public, max-age=2592000, immutable"
    return resp


# -------------------- Routes --------------------
@app.get("/api/health")
def health():
    return {"status": "ok", "device": DEVICE, "model_key": DEFAULT_MODEL_KEY}


@app.post("/api/auth/register", response_model=RegisterOut, status_code=201)
def register_user(payload: RegisterIn):
    """Create an unverified account and email it a verification code."""
    email_norm = str(payload.email).strip().lower()
    name_norm = payload.name.strip()
    if not name_norm:
        raise HTTPException(status_code=422, detail="Name must not be empty")
    with Session(engine) as ses:
        existing = ses.scalar(select(User).where(User.email == email_norm))
        if existing:
            raise HTTPException(status_code=409, detail="Email already registered")
        user = User(
            id=str(uuid.uuid4()),
            created_at=time.time(),
            email=email_norm,
            password_hash=_hash_password(payload.password),
            name=name_norm,
            is_verified=0,
            verified_at=None,
            preferred_model=DEFAULT_MODEL_KEY,
        )
        ses.add(user)
        try:
            ses.flush()
        except IntegrityError:
            # A concurrent registration for the same email won the race on the
            # unique index; report it like the pre-check above instead of a 500.
            ses.rollback()
            raise HTTPException(status_code=409, detail="Email already registered")
        code = _issue_email_code(ses, user, "verify_email", EMAIL_CODE_TTL_SECONDS)
        ses.commit()
        _send_verification_email(user, code)
    return RegisterOut(
        detail="Account created. Check your email for a verification code to activate it.",
        requires_verification=True,
    )


@app.post("/api/auth/login", response_model=AuthResponse)
def login_user(payload: LoginIn, request: Request):
    """Exchange email/password for a session token (verified accounts only)."""
    email_norm = str(payload.email).strip().lower()
    with Session(engine) as ses:
        user = ses.scalar(select(User).where(User.email == email_norm))
        if not user or not _verify_password(payload.password, user.password_hash):
            raise HTTPException(status_code=401, detail="Invalid email or password")
        if not user.is_verified:
            raise HTTPException(
                status_code=403,
                detail="Email not verified. Use the code sent to your inbox to activate your account.",
            )
        token = _create_session(user, ses, request.headers.get("User-Agent"))
        ses.commit()
        return AuthResponse(token=token, user=UserOut(**_serialize_user(user)))


@app.post("/api/auth/verify/resend", response_model=GenericMessageOut)
def resend_verification(payload: VerifyResendIn):
    """Issue a new verification code.

    The same message is returned whether or not the account exists, so this
    endpoint cannot be used to probe for unverified accounts.
    NOTE(review): the "already verified" reply still reveals that a verified
    account exists; kept for UX.
    """
    generic = "If an account exists, a verification email has been sent."
    email_norm = str(payload.email).strip().lower()
    with Session(engine) as ses:
        user = ses.scalar(select(User).where(User.email == email_norm))
        if not user:
            return GenericMessageOut(detail=generic)
        if user.is_verified:
            return GenericMessageOut(detail="This email is already verified. You can sign in.")
        code = _issue_email_code(ses, user, "verify_email", EMAIL_CODE_TTL_SECONDS)
        ses.commit()
        _send_verification_email(user, code)
    return GenericMessageOut(detail=generic)


@app.post("/api/auth/verify/confirm", response_model=AuthResponse)
def confirm_verification(payload: VerifyConfirmIn, request: Request):
    """Mark the account verified when the code matches, and log the user in."""
    email_norm = str(payload.email).strip().lower()
    code = payload.code.strip()
    with Session(engine) as ses:
        user = ses.scalar(select(User).where(User.email == email_norm))
        if not user:
            raise HTTPException(status_code=400, detail="Invalid or expired verification code")
        match = _verify_email_code(ses, user, "verify_email", code)
        if not match:
            raise HTTPException(status_code=400, detail="Invalid or expired verification code")
        user.is_verified = 1
        user.verified_at = user.verified_at or time.time()
        token = _create_session(user, ses, request.headers.get("User-Agent"))
        ses.commit()
        return AuthResponse(token=token, user=UserOut(**_serialize_user(user)))


@app.post("/api/auth/password/forgot", response_model=GenericMessageOut)
def forgot_password(payload: PasswordForgotIn):
    """Email a password-reset code to a verified account.

    The response is identical whether or not a reset was issued, so the
    endpoint cannot be used for account discovery.
    """
    generic = "If an account exists, you'll receive a reset code shortly."
    email_norm = str(payload.email).strip().lower()
    with Session(engine) as ses:
        user = ses.scalar(select(User).where(User.email == email_norm))
        if not user or not user.is_verified:
            return GenericMessageOut(detail=generic)
        code = _issue_email_code(ses, user, "reset_password", PASSWORD_RESET_TTL_SECONDS)
        ses.commit()
        _send_password_reset_email(user, code)
    return GenericMessageOut(detail=generic)


@app.post("/api/auth/password/reset", response_model=GenericMessageOut)
def reset_password(payload: PasswordResetIn):
    """Set a new password when the reset code matches, and revoke all sessions."""
    email_norm = str(payload.email).strip().lower()
    code = payload.code.strip()
    with Session(engine) as ses:
        user = ses.scalar(select(User).where(User.email == email_norm))
        if not user:
            raise HTTPException(status_code=400, detail="Invalid or expired reset code")
        match = _verify_email_code(ses, user, "reset_password", code)
        if not match:
            raise HTTPException(status_code=400, detail="Invalid or expired reset code")
        user.password_hash = _hash_password(payload.new_password)
        if not user.is_verified:
            user.is_verified = 1
            user.verified_at = time.time()
        # Log out every existing session after a password change.
        ses.execute(text("DELETE FROM sessions WHERE user_id = :uid"), {"uid": user.id})
        ses.commit()
    return GenericMessageOut(detail="Password updated. You can now sign in with your new password.")


@app.get("/api/auth/me", response_model=AuthMeOut)
def whoami(request: Request):
    token = _extract_bearer_token(request)
    with Session(engine) as ses:
        user = _get_user_from_token(token, ses, require=True)
        return AuthMeOut(user=UserOut(**_serialize_user(user)))


@app.post("/api/auth/logout")
def logout_user(request: Request):
    """Delete the caller's session row."""
    token = _extract_bearer_token(request)
    if not token:
        raise HTTPException(status_code=401, detail="Missing authorization")
    with Session(engine) as ses:
        sess = ses.get(SessionToken, token)
        if not sess:
            raise HTTPException(status_code=401, detail="Invalid session")
        ses.delete(sess)
        ses.commit()
    return {"status": "ok"}


@app.post("/api/model/select")
def select_model(request: Request, model_key: str = Form(...)):
    """
    Update the preferred segmentation model for the authenticated user.
    Accepts 'green' or 'orange'. Anything else falls back to 'green'.
    """
    token = _extract_bearer_token(request)
    with Session(engine) as ses:
        user = _get_user_from_token(token, ses, require=True)
        normalized = _normalize_model_key(model_key)
        user.preferred_model = normalized
        ses.commit()
    _ensure_model_loaded(normalized)
    return {"status": "ok", "model_key": normalized, "device": DEVICE}


@app.get("/api/model/current")
def current_model(request: Request):
    """Report the caller's preferred model key (the default for anonymous callers)."""
    token = _extract_bearer_token(request)
    with Session(engine) as ses:
        user = _get_user_from_token(token, ses)
        key = _normalize_model_key(user.preferred_model if user else None)
    return {"model_key": key, "device": DEVICE}


@app.post("/api/upload", response_model=UploadOut)
async def upload(request: Request, file: UploadFile = File(...)):
    """Store a validated image under UPLOAD_DIR and record its original filename.

    The returned `upload_id` is meant for POST /api/segment. When a valid bearer
    token is supplied, the upload is tied to that account.
    """
    img = await _read_image_upload(file)
    w, h = img.size
    upload_id = str(uuid.uuid4())
    ts = time.time()
    stored_name = f"{upload_id}_input.jpg"
    in_path = os.path.join(UPLOAD_DIR, stored_name)
    img.save(in_path, quality=92, optimize=True)

    # persist upload metadata (original filename + owner)
    with Session(engine) as ses:
        user = _get_user_from_token(_extract_bearer_token(request), ses)
        ses.add(
            Upload(
                upload_id=upload_id,
                created_at=ts,
                stored_name=stored_name,
                orig_name=os.path.basename(file.filename) if file.filename else "unknown",
                width=w,
                height=h,
                user_id=user.id if user else None,
            )
        )
        ses.commit()

    input_url = request.url_for("media", path=f"uploads/{stored_name}")
    return UploadOut(
        upload_id=upload_id,
        created_at=ts,
        input_url=str(input_url),
        width=w,
        height=h,
    )


@app.post("/api/segment", response_model=JobOut)
async def segment(
    request: Request,
    file: Optional[UploadFile] = File(None),  # backward-compat
    upload_id: Optional[str] = Form(None),  # preferred path
    prompt: Optional[str] = Form(None),  # optional text input
):
    """
    Preferred: pass upload_id (image already uploaded via /api/upload).
    Backward compatibility: can still post a 'file'.
    If 'prompt' is provided, segmentation will be prompt-guided.

    NOTE(review): inference and image encoding run synchronously inside this
    async handler and block the event loop for their duration.
    """
    prompt_list = _parse_prompt(prompt)

    token = _extract_bearer_token(request)
    user_id: Optional[str] = None
    user_model_key = DEFAULT_MODEL_KEY
    if token:
        with Session(engine) as ses:
            user = _get_user_from_token(token, ses)
            if user:
                user_id = user.id
                user_model_key = _normalize_model_key(getattr(user, "preferred_model", None))

    # Load the image & get original filename
    if upload_id and not file:
        # Two-step path. Only a canonical UUID is accepted, so the id cannot
        # inject path components (e.g. "../<job_id>") into the path below.
        safe_id = _canonical_upload_id(upload_id)
        if safe_id is None:
            raise HTTPException(status_code=404, detail="Upload not found")
        in_name = f"{safe_id}_input.jpg"
        in_path = os.path.join(UPLOAD_DIR, in_name)
        if not os.path.exists(in_path):
            raise HTTPException(status_code=404, detail="Upload not found")

        # Ownership check and original filename come from the uploads table.
        with Session(engine) as ses:
            upl = ses.get(Upload, safe_id)
            if upl:
                if upl.user_id and upl.user_id != user_id:
                    raise HTTPException(status_code=403, detail="Upload does not belong to this account")
                orig_name = upl.orig_name or in_name
            else:
                orig_name = in_name  # fallback: file without a metadata row

        try:
            with Image.open(in_path) as src:
                img = src.convert("RGB")
        except Exception:
            raise HTTPException(status_code=400, detail="Invalid image")
    else:
        # one-step legacy path: handle direct file upload to /segment
        if not file:
            raise HTTPException(status_code=400, detail="No image or upload_id provided")
        img = await _read_image_upload(file)
        orig_name = os.path.basename(file.filename) if file.filename else "unknown"

    w, h = img.size

    # Inference
    model, user_model_key = _ensure_model_loaded(user_model_key)
    try:
        if prompt_list:
            mask = segment_image(model, img, prompt_list, SENTENCE_MODEL, PCA_PARAMS, device=DEVICE)
        else:
            mask = segment_image(model, img, device=DEVICE)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Inference failed: {e}")

    dataset = "ade" if user_model_key == "orange" else "coco"

    # Prompt-aware coloring & legend
    if prompt_list:
        color = colorize_mask(mask, classes=prompt_list)
        legend = build_legend_from_mask(mask, classes=prompt_list)
    else:
        color = colorize_mask(mask, dataset=dataset)
        legend = build_legend_from_mask(mask, dataset=dataset)

    over = overlay_mask(img, color, alpha=0.45)

    # Save artifacts
    job_id = str(uuid.uuid4())
    ts = time.time()
    in_name_final = f"{job_id}_input.jpg"
    mask_name = f"{job_id}_mask.png"
    over_name = f"{job_id}_overlay.jpg"
    legend_name = f"{job_id}_legend.json"

    img.save(os.path.join(MEDIA_ROOT, in_name_final), quality=92, optimize=True)
    color.save(os.path.join(MEDIA_ROOT, mask_name))
    over.save(os.path.join(MEDIA_ROOT, over_name), quality=92, optimize=True)
    try:
        # Serialize before opening the file so a non-JSON-able legend cannot
        # leave a half-written sidecar behind for /api/history to trip over.
        legend_json = json.dumps(legend)
        with open(os.path.join(MEDIA_ROOT, legend_name), "w", encoding="utf-8") as f:
            f.write(legend_json)
    except Exception:
        logger.exception("Could not write legend sidecar for job %s", job_id)

    # Persist job row
    with Session(engine) as ses:
        ses.add(
            Job(
                id=job_id,
                created_at=ts,
                input_path=in_name_final,
                input_name=orig_name or "unknown",  # shown in history
                mask_path=mask_name,
                overlay_path=over_name,
                model=SEGMENT_MODEL_NAME,
                device=DEVICE,
                width=w,
                height=h,
                user_id=user_id,
            )
        )
        ses.commit()

    input_url = request.url_for("media", path=in_name_final)
    mask_url = request.url_for("media", path=mask_name)
    overlay_url = request.url_for("media", path=over_name)

    return JobOut(
        id=job_id,
        created_at=ts,
        input_url=str(input_url),
        mask_url=str(mask_url),
        overlay_url=str(overlay_url),
        model=SEGMENT_MODEL_NAME,
        device=DEVICE,
        width=w,
        height=h,
        input_name=orig_name or "unknown",
        legend=legend,
    )


@app.get("/api/history", response_model=List[JobOut])
def history(request: Request, limit: int = 20, offset: int = 0):
    """List the caller's jobs, newest first.

    `limit` is clamped to 0..HISTORY_MAX_LIMIT and `offset` to >= 0.
    NOTE(review): anonymous callers share one pool of user_id-less jobs.
    """
    limit = max(0, min(limit, HISTORY_MAX_LIMIT))
    offset = max(0, offset)
    token = _extract_bearer_token(request)
    with Session(engine) as ses:
        user = _get_user_from_token(token, ses)
        params: Dict[str, Any] = {"limit": limit, "offset": offset}
        if user:
            query = text(
                """
                SELECT id, created_at, input_path, mask_path, overlay_path,
                       model, device, width, height, input_name
                FROM jobs
                WHERE user_id = :user_id
                ORDER BY created_at DESC
                LIMIT :limit OFFSET :offset
                """
            )
            params["user_id"] = user.id
        else:
            query = text(
                """
                SELECT id, created_at, input_path, mask_path, overlay_path,
                       model, device, width, height, input_name
                FROM jobs
                WHERE user_id IS NULL
                ORDER BY created_at DESC
                LIMIT :limit OFFSET :offset
                """
            )
        rows = ses.execute(query, params).all()

    out: List[JobOut] = []
    for r in rows:
        job_id = r[0]
        legend = None
        sidecar = os.path.join(MEDIA_ROOT, f"{job_id}_legend.json")
        if os.path.exists(sidecar):
            try:
                with open(sidecar, "r", encoding="utf-8") as f:
                    legend = json.load(f)
            except Exception:
                legend = None
        out.append(
            JobOut(
                id=job_id,
                created_at=r[1],
                input_url=str(request.url_for("media", path=r[2])),
                mask_url=str(request.url_for("media", path=r[3])),
                overlay_url=str(request.url_for("media", path=r[4])),
                model=r[5],
                device=r[6],
                width=r[7],
                height=r[8],
                input_name=r[9] or "unknown",
                legend=legend,  # attach legend if available
            )
        )
    return out


@app.delete("/api/history/{job_id}")
def delete_job(job_id: str, request: Request):
    """Delete one job and its files. Jobs owned by an account require that account."""
    with Session(engine) as ses:
        user = _get_user_from_token(_extract_bearer_token(request), ses)
        row = ses.get(Job, job_id)
        if not row:
            raise HTTPException(status_code=404, detail="Job not found")
        if row.user_id:
            if not user:
                raise HTTPException(status_code=401, detail="Authentication required")
            if user.id != row.user_id:
                raise HTTPException(status_code=403, detail="Forbidden")
        _remove_job_files(row)
        ses.delete(row)
        ses.commit()
    return {"status": "deleted", "id": job_id}


@app.delete("/api/history")
def clear_history(request: Request):
    """Delete all of the caller's jobs and files.

    NOTE(review): an anonymous caller clears every user_id-less job.
    """
    with Session(engine) as ses:
        user = _get_user_from_token(_extract_bearer_token(request), ses)
        if user:
            rows = ses.query(Job).filter(Job.user_id == user.id).all()
        else:
            rows = ses.query(Job).filter(Job.user_id.is_(None)).all()

        for row in rows:
            _remove_job_files(row)
            ses.delete(row)
        ses.commit()
    return {"status": "cleared"}


# -------------------- Contact, Feedback (+Survey autodetect) --------------------
@app.post("/api/contact")
def create_contact(payload: ContactIn):
    """Store a contact submission and forward it to CONTACT_RECIPIENTS (best effort)."""
    ts = time.time()
    row_id = str(uuid.uuid4())
    with Session(engine) as ses:
        ses.add(
            Contact(
                id=row_id,
                created_at=ts,
                name=payload.name,
                email=str(payload.email),
                subject=payload.subject,
                message=payload.message,
            )
        )
        ses.commit()

    # Flatten line breaks: email headers may not contain them, and EmailMessage
    # rejects such values (which previously meant the notification was dropped).
    subj = " ".join((payload.subject or "").splitlines()).strip()
    email_subject = f"ConeImage: {subj}" if subj else "ConeImage: No subject"

    # Email the submission; set Reply-To to submitter's email
    try:
        _send_mail(
            to=CONTACT_RECIPIENTS,
            subject=email_subject,
            body=(
                f"Name: {payload.name}\n"
                f"Email: {payload.email}\n"
                f"Subject: {subj or 'No subject'}\n"
                f"Message:\n{payload.message}\n\n"
                f"Submitted at: {time.strftime('%Y-%m-%d %H:%M:%S', time.gmtime(ts))} UTC"
            ),
            reply_to=str(payload.email),
        )
    except Exception:
        logger.exception("Contact notification failed for %s", row_id)

    return {"status": "ok", "id": row_id, "created_at": ts}


@app.post("/api/feedback")
def create_feedback(payload: FeedbackIn):
    """
    Accepts plain feedback (no rating) OR a survey JSON (tag='SITE_SURVEY') posted as the 'feedback' string.
    If a survey is detected, it is saved into the 'surveys' table; otherwise, it is saved into 'feedback'.
    For plain feedback, send an email that includes only the submitter's email.
    """
    ts = time.time()
    row_id = str(uuid.uuid4())

    # Detect and store site survey if present (no email required)
    survey = _try_parse_site_survey(payload.feedback)
    if survey is not None:
        with Session(engine) as ses:
            ses.add(
                Survey(
                    id=row_id,
                    created_at=ts,
                    role=survey["role"],
                    use_case=survey["use_case"],
                    model_works=survey["model_works"],
                    image_visualization=survey["image_visualization"],
                    prompt_input=survey["prompt_input"],
                    processing_speed=survey["processing_speed"],
                    accessibility=survey["accessibility"],
                    overall_experience=survey["overall_experience"],
                )
            )
            ses.commit()
        return {"status": "ok", "kind": "survey", "id": row_id, "created_at": ts}

    # Otherwise store as regular feedback (no rating field)
    with Session(engine) as ses:
        ses.add(
            Feedback(
                id=row_id,
                created_at=ts,
                name=payload.name,
                email=str(payload.email) if payload.email else None,
                feedback=payload.feedback,
            )
        )
        ses.commit()

    # Email only the submitter's email address to recipients (per request)
    try:
        _send_mail(
            to=CONTACT_RECIPIENTS,
            subject="ConeImage: New Feedback",
            body=(
                f"Name: {payload.name or '(anonymous)'}\n"
                f"Email: {payload.email or '(none provided)'}\n"
                f"Message:\n{payload.feedback}\n\n"
                f"Submitted at: {time.strftime('%Y-%m-%d %H:%M:%S', time.gmtime(ts))} UTC"
            ),
            reply_to=None,
        )
    except Exception:
        logger.exception("Feedback notification failed for %s", row_id)

    return {"status": "ok", "kind": "feedback", "id": row_id, "created_at": ts}