# seg-app / api / main.py
# mahmed10's picture
# Upload 55 files
# 19d78dd verified
import hashlib
import hmac
import io
import json
import logging
import os
import secrets
import smtplib
import time
import uuid
from email.message import EmailMessage
from typing import List, Optional, Any, Dict

from fastapi import FastAPI, UploadFile, File, HTTPException, Request, Form, Body
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from PIL import Image
from pydantic import BaseModel, EmailStr, Field
from sqlalchemy import Column, String, Integer, Float, create_engine, text, select
from sqlalchemy.orm import declarative_base, Session

from model import load_model, segment_image, load_model_variant
from utils import colorize_mask, overlay_mask, build_legend_from_mask
# -------------------- Config --------------------
MEDIA_ROOT = os.getenv("MEDIA_ROOT", "/data/media")
UPLOAD_DIR = os.path.join(MEDIA_ROOT, "uploads")  # raw uploads go here
DB_PATH = os.getenv("DB_PATH", "/data/app.db")
DEVICE = os.getenv("DEVICE", "cpu")
# Comma-separated allow-list of browser origins; blank entries are dropped.
CORS_ORIGINS = [
    o.strip()
    for o in os.getenv(
        "CORS_ORIGINS",
        "https://coneimage.com,https://www.coneimage.com,http://localhost:3000,http://127.0.0.1:3000",
    ).split(",")
    if o.strip()
]
os.makedirs(MEDIA_ROOT, exist_ok=True)
os.makedirs(UPLOAD_DIR, exist_ok=True)
MAX_BYTES = int(os.getenv("MAX_UPLOAD_BYTES", str(10 * 1024 * 1024)))  # 10 MB
SESSION_TTL_SECONDS = int(os.getenv("SESSION_TTL_SECONDS", str(60 * 60 * 24 * 7)))  # 7 days
EMAIL_CODE_TTL_SECONDS = int(os.getenv("EMAIL_CODE_TTL_SECONDS", str(15 * 60)))  # 15 minutes
PASSWORD_RESET_TTL_SECONDS = int(os.getenv("PASSWORD_RESET_TTL_SECONDS", str(15 * 60)))  # 15 minutes
# SMTP / Email config.
# Credentials must be supplied through the environment; never commit them here.
# A Gmail app password used to be the hard-coded SMTP_PASS default, so it is
# exposed in version history and should be revoked/rotated.
# With SMTP_USER unset, _send_mail skips the SMTP login step.
SMTP_HOST = os.getenv("SMTP_HOST", "smtp.gmail.com")
SMTP_PORT = int(os.getenv("SMTP_PORT", "587"))
SMTP_USER = os.getenv("SMTP_USER", "")
SMTP_PASS = os.getenv("SMTP_PASS", "")
SMTP_FROM = os.getenv("SMTP_FROM", "coneimage123@gmail.com")
# Default recipients for contact/feedback notifications (comma-separated).
CONTACT_RECIPIENTS = [
    e.strip()
    for e in os.getenv(
        "CONTACT_RECIPIENTS",
        "mahmed10.umbc@gmail.com,smxahsan@gmail.com",
    ).split(",")
    if e.strip()
]
# -------------------- App --------------------
app = FastAPI(title="ConeImage API", version="1.5.0")
# Every origin is allowed only when CORS_ORIGINS parsed to an empty list.
# Credentials are off: auth travels in the Authorization header, not cookies.
app.add_middleware(
    CORSMiddleware,
    allow_origins=CORS_ORIGINS or ["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=False,
    max_age=86400,  # let browsers cache preflight responses for one day
)
# -------------------- DB --------------------
# One SQLite file at DB_PATH backs every table declared below.
engine = create_engine("sqlite:///" + DB_PATH, future=True)
Base = declarative_base()
class Job(Base):
    """One segmentation run. Artifact paths are bare filenames relative to MEDIA_ROOT."""

    __tablename__ = "jobs"
    id = Column(String, primary_key=True)  # uuid4 string; also the prefix of every artifact filename
    created_at = Column(Float, index=True)  # Unix timestamp from time.time()
    input_path = Column(String)  # file stored under MEDIA_ROOT (final job copy)
    input_name = Column(String)  # ORIGINAL uploaded filename (what we show in history)
    mask_path = Column(String)  # colorized mask PNG under MEDIA_ROOT
    overlay_path = Column(String)  # mask blended over the input, under MEDIA_ROOT
    model = Column(String)  # model label recorded for the run
    device = Column(String)  # DEVICE the inference ran on
    width = Column(Integer)  # input width in pixels
    height = Column(Integer)  # input height in pixels
    user_id = Column(String, index=True, nullable=True)  # owning User.id; NULL for anonymous runs
class Upload(Base):
    """
    Keeps metadata for uploaded files so we can recover the original filename
    in the two-step flow (POST /api/upload -> POST /api/segment with upload_id).
    """
    __tablename__ = "uploads"
    upload_id = Column(String, primary_key=True)  # uuid4 string issued by /api/upload
    created_at = Column(Float, index=True)  # Unix timestamp
    stored_name = Column(String)  # "<upload_id>_input.jpg" under UPLOAD_DIR
    orig_name = Column(String)  # original filename the user selected
    width = Column(Integer)  # decoded image width in pixels
    height = Column(Integer)  # decoded image height in pixels
    user_id = Column(String, index=True, nullable=True)  # uploader's User.id; NULL if anonymous
class User(Base):
    """Registered account. Login is refused until is_verified is set via the emailed code."""

    __tablename__ = "users"
    id = Column(String, primary_key=True)  # uuid4 string
    created_at = Column(Float, index=True)  # Unix timestamp
    email = Column(String, unique=True, index=True)  # stored stripped and lower-cased
    password_hash = Column(String)  # "<salt_hex>:<pbkdf2_sha256_hex>" from _hash_password
    name = Column(String, nullable=True)  # display name
    is_verified = Column(Integer, default=0)  # 0/1 flag (SQLite has no native bool)
    verified_at = Column(Float, nullable=True)  # Unix timestamp of first successful verification
    preferred_model = Column(String, default="green")  # "green" or "orange"; see _normalize_model_key
class SessionToken(Base):
    """Opaque bearer session issued on login/verification."""

    __tablename__ = "sessions"
    token = Column(String, primary_key=True)  # secrets.token_urlsafe(48); the bearer value itself
    user_id = Column(String, index=True)  # owning User.id
    created_at = Column(Float, index=True)  # Unix timestamp
    expires_at = Column(Float, index=True)  # created_at + SESSION_TTL_SECONDS
    user_agent = Column(String, nullable=True)  # User-Agent header at creation, informational
class Contact(Base):
    """A contact-form submission (also emailed to CONTACT_RECIPIENTS)."""

    __tablename__ = "contacts"
    id = Column(String, primary_key=True)  # uuid4 string
    created_at = Column(Float, index=True)  # Unix timestamp
    name = Column(String)
    email = Column(String)  # submitter's address; used as Reply-To on the notification
    subject = Column(String, nullable=True)  # optional; added to old databases by migration
    message = Column(String)
class Feedback(Base):
    """Free-text feedback that was not recognised as a site survey."""

    __tablename__ = "feedback"
    id = Column(String, primary_key=True)  # uuid4 string
    created_at = Column(Float, index=True)  # Unix timestamp
    name = Column(String, nullable=True)
    email = Column(String, nullable=True)
    # NOTE: old deployments may still have a 'rating' column; we no longer use it.
    feedback = Column(String)  # free-text feedback OR other notes
class Survey(Base):
    """
    Stores star-based site survey results (no notes).
    Ratings use a 1..5 scale where 1=worse, 5=best.
    """
    __tablename__ = "surveys"
    id = Column(String, primary_key=True)  # uuid4 string
    created_at = Column(Float, index=True)  # Unix timestamp
    role = Column(String, nullable=True)  # respondent's self-described role
    use_case = Column(String, nullable=True)  # respondent's intended use
    # Each rating is NULL when missing or outside 1..5.
    model_works = Column(Integer, nullable=True)
    image_visualization = Column(Integer, nullable=True)
    prompt_input = Column(Integer, nullable=True)
    processing_speed = Column(Integer, nullable=True)
    accessibility = Column(Integer, nullable=True)
    overall_experience = Column(Integer, nullable=True)
class EmailCode(Base):
    """One-time numeric code emailed for account verification or password reset."""

    __tablename__ = "email_codes"
    id = Column(String, primary_key=True)  # uuid4 string
    user_id = Column(String, index=True)  # User.id the code was issued to
    purpose = Column(String, index=True)  # "verify_email" or "reset_password"
    code_hash = Column(String)  # sha256 hex of f"{code}:{code_salt}" (see _hash_code)
    code_salt = Column(String)  # secrets.token_hex(8)
    created_at = Column(Float, index=True)  # Unix timestamp
    expires_at = Column(Float, index=True)  # created_at + the purpose's TTL
    consumed_at = Column(Float, nullable=True)  # set when used or superseded by a newer code
Base.metadata.create_all(engine)


def _add_column_if_missing(table: str, column: str, ddl_type: str) -> bool:
    """
    Add ``column`` to SQLite ``table`` unless it already exists.

    ``create_all`` only creates missing tables and never alters existing ones.
    Columns introduced after a deployment's first run are therefore added here.
    Safe to run on every startup.

    Args:
        table: table name; a trusted literal that is interpolated into SQL.
        column: column name; a trusted literal that is interpolated into SQL.
        ddl_type: type/constraint suffix, e.g. "TEXT" or "INTEGER DEFAULT 0";
            also a trusted literal.

    Returns:
        True if this call added the column, False if it was already present.

    Raises:
        ValueError: if ``table`` or ``column`` is not a plain identifier.
    """
    # Identifiers cannot be bound as SQL parameters, so reject anything that
    # is not a bare identifier before it is spliced into the statement.
    if not (table.isidentifier() and column.isidentifier()):
        raise ValueError(f"Unsafe SQL identifier: {table!r}.{column!r}")
    with engine.begin() as conn:  # commits on successful exit
        existing_cols = {row[1] for row in conn.execute(text(f"PRAGMA table_info({table})"))}  # row[1] is the name
        if column in existing_cols:
            return False
        conn.execute(text(f"ALTER TABLE {table} ADD COLUMN {column} {ddl_type}"))
    return True


def _maybe_migrate_jobs_add_input_name():
    """Add jobs.input_name (original upload filename shown in history) if missing."""
    _add_column_if_missing("jobs", "input_name", "TEXT")


def _maybe_migrate_contacts_add_subject():
    """Add contacts.subject if missing."""
    _add_column_if_missing("contacts", "subject", "TEXT")


def _maybe_migrate_jobs_add_user_id():
    """Add jobs.user_id (job owner) if missing."""
    _add_column_if_missing("jobs", "user_id", "TEXT")


def _maybe_migrate_uploads_add_user_id():
    """Add uploads.user_id (uploader) if missing."""
    _add_column_if_missing("uploads", "user_id", "TEXT")


def _maybe_migrate_users_add_verification():
    """Add users.is_verified and users.verified_at if missing."""
    # NOTE(review): rows that predate this column get is_verified = 0, and login
    # refuses unverified users, so those owners must verify before signing in.
    _add_column_if_missing("users", "is_verified", "INTEGER DEFAULT 0")
    _add_column_if_missing("users", "verified_at", "REAL")


def _maybe_migrate_users_add_preferred_model():
    """Add users.preferred_model and backfill existing rows with 'green'."""
    if _add_column_if_missing("users", "preferred_model", "TEXT"):
        with engine.begin() as conn:
            conn.execute(text("UPDATE users SET preferred_model = 'green' WHERE preferred_model IS NULL"))


_maybe_migrate_jobs_add_input_name()
_maybe_migrate_jobs_add_user_id()
_maybe_migrate_uploads_add_user_id()
_maybe_migrate_users_add_verification()
_maybe_migrate_users_add_preferred_model()
def _hash_password(password: str) -> str:
salt = secrets.token_bytes(16)
hashed = hashlib.pbkdf2_hmac("sha256", password.encode("utf-8"), salt, 100_000)
return f"{salt.hex()}:{hashed.hex()}"
def _verify_password(password: str, stored: str) -> bool:
try:
salt_hex, hash_hex = stored.split(":", 1)
except ValueError:
return False
try:
salt = bytes.fromhex(salt_hex)
expected = bytes.fromhex(hash_hex)
except ValueError:
return False
computed = hashlib.pbkdf2_hmac("sha256", password.encode("utf-8"), salt, 100_000)
return hmac.compare_digest(computed, expected)
def _create_session(user: User, ses: Session, user_agent: Optional[str] = None) -> str:
    """
    Stage a new SessionToken row for ``user`` and return its bearer token.

    The row expires SESSION_TTL_SECONDS from now. It is only added to ``ses``;
    the caller is responsible for committing.
    """
    issued_at = time.time()
    bearer = secrets.token_urlsafe(48)
    row = SessionToken(
        token=bearer,
        user_id=user.id,
        created_at=issued_at,
        expires_at=issued_at + SESSION_TTL_SECONDS,
        user_agent=user_agent,
    )
    ses.add(row)
    return bearer
def _extract_bearer_token(request: Request) -> Optional[str]:
    """
    Return the token from an ``Authorization: Bearer <token>`` header.

    The scheme match is case-insensitive and whitespace around the token is
    stripped. Returns None if the header is absent, uses another scheme, or
    carries an empty token.
    """
    header = request.headers.get("Authorization")
    if not header or not isinstance(header, str):
        return None
    if not header.lower().startswith("bearer "):
        return None
    candidate = header[7:].strip()  # 7 == len("bearer ")
    return candidate if candidate else None
def _get_user_from_token(token: Optional[str], ses: Session, *, require: bool = False) -> Optional[User]:
    """
    Resolve a bearer token to its User.

    Returns None when the token is missing, unknown, expired, or points to a
    deleted user. When ``require`` is True, each of those cases raises
    HTTPException(401) instead.
    """

    def _deny(detail: str) -> None:
        # Shared failure path: raise only when authentication is mandatory.
        if require:
            raise HTTPException(status_code=401, detail=detail)
        return None

    if not token:
        return _deny("Missing authorization")
    row = ses.get(SessionToken, token)
    if not row or (row.expires_at and row.expires_at < time.time()):
        return _deny("Invalid or expired session")
    owner = ses.get(User, row.user_id)
    return owner if owner else _deny("Invalid session")
def _serialize_user(user: User) -> Dict[str, Any]:
    """Build the public view of ``user`` used for UserOut. The password hash is never included."""
    view: Dict[str, Any] = dict(
        id=user.id,
        email=user.email,
        name=user.name,
        created_at=user.created_at,
    )
    view["is_verified"] = bool(user.is_verified)  # stored as 0/1 integer
    view["verified_at"] = user.verified_at
    view["preferred_model"] = _normalize_model_key(getattr(user, "preferred_model", None))
    return view
# Add contacts.subject on databases created before that column existed (no-op otherwise).
_maybe_migrate_contacts_add_subject()
def _generate_numeric_code(length: int = 6) -> str:
alphabet = "0123456789"
return "".join(secrets.choice(alphabet) for _ in range(length))
def _hash_code(code: str, salt: str) -> str:
return hashlib.sha256(f"{code}:{salt}".encode("utf-8")).hexdigest()
def _issue_email_code(ses: Session, user: User, purpose: str, ttl_seconds: int) -> str:
    """
    Create a fresh 6-digit code for ``user``/``purpose`` and return it in plain text.

    Any still-unconsumed codes for the same user and purpose are marked consumed,
    so only the newest code works. Only the salted hash is stored. Changes are
    staged on ``ses``; the caller commits.
    """
    issued_at = time.time()
    plain_code = _generate_numeric_code()
    salt = secrets.token_hex(8)

    outstanding_stmt = select(EmailCode).where(
        EmailCode.user_id == user.id,
        EmailCode.purpose == purpose,
        EmailCode.consumed_at.is_(None),
    )
    for previous in ses.scalars(outstanding_stmt).all():
        previous.consumed_at = previous.consumed_at or issued_at

    ses.add(
        EmailCode(
            id=str(uuid.uuid4()),
            user_id=user.id,
            purpose=purpose,
            code_hash=_hash_code(plain_code, salt),
            code_salt=salt,
            created_at=issued_at,
            expires_at=issued_at + ttl_seconds,
            consumed_at=None,
        )
    )
    return plain_code
def _verify_email_code(ses: Session, user: User, purpose: str, code: str) -> Optional[EmailCode]:
    """
    Consume the live code for ``user``/``purpose`` that matches ``code``.

    Only unconsumed, unexpired codes are considered. On a match, that code and
    every other live code for the same purpose are marked consumed, and the
    matched row is returned. Returns None when nothing matches. The caller commits.

    NOTE(review): there is no per-user attempt limit here; a 6-digit code can be
    brute-forced online within its TTL unless requests are rate-limited upstream.
    """
    now = time.time()
    live_stmt = select(EmailCode).where(
        EmailCode.user_id == user.id,
        EmailCode.purpose == purpose,
        EmailCode.consumed_at.is_(None),
        EmailCode.expires_at > now,
    )
    live_codes = ses.scalars(live_stmt).all()
    winner = next(
        (
            entry
            for entry in live_codes
            if hmac.compare_digest(_hash_code(code, entry.code_salt), entry.code_hash)
        ),
        None,
    )
    if winner is None:
        return None
    winner.consumed_at = now
    for leftover in live_codes:
        if leftover.id != winner.id:
            leftover.consumed_at = leftover.consumed_at or now
    return winner
def _send_verification_email(user: User, code: str) -> None:
    """Email ``code`` to ``user`` for account activation (best effort; the send result is ignored)."""
    minutes = max(1, EMAIL_CODE_TTL_SECONDS // 60)
    plural = "s" if minutes != 1 else ""
    paragraphs = [
        "Hello!",
        "Use the verification code below to activate your ConeImage account.",
        f"Code: {code}",
        f"This code will expire in {minutes} minute{plural}. If you did not sign up, you can ignore this email.",
    ]
    _send_mail([user.email], "ConeImage verification code", "\n\n".join(paragraphs))
def _send_password_reset_email(user: User, code: str) -> None:
    """Email a password-reset ``code`` to ``user`` (best effort; the send result is ignored)."""
    minutes = max(1, PASSWORD_RESET_TTL_SECONDS // 60)
    plural = "s" if minutes != 1 else ""
    paragraphs = [
        "We received a request to reset your ConeImage password.",
        f"Use this code to set a new password: {code}",
        f"This code will expire in {minutes} minute{plural}. If you did not request a reset, you can delete this email.",
    ]
    _send_mail([user.email], "ConeImage password reset", "\n\n".join(paragraphs))
# -------------------- Model --------------------
DEFAULT_MODEL_KEY = "green"
# The default variant is loaded eagerly at import time. load_model also returns
# the sentence model and PCA parameters that prompt-guided segmentation uses.
DEFAULT_MODEL, SENTENCE_MODEL, PCA_PARAMS = load_model(device=DEVICE)
# Loaded segmentation models keyed by normalized model key. Other variants are
# added lazily by _ensure_model_loaded.
MODEL_REGISTRY: Dict[str, Any] = {}
MODEL_REGISTRY[DEFAULT_MODEL_KEY] = DEFAULT_MODEL
def _normalize_model_key(key: Optional[str]) -> str:
return "orange" if key == "orange" else DEFAULT_MODEL_KEY
def _ensure_model_loaded(key: Optional[str]):
    """
    Return ``(model, normalized_key)``, loading the variant on first use.

    Loaded models are cached in MODEL_REGISTRY.
    NOTE(review): there is no lock. Two concurrent first requests for the same
    variant may both call load_model_variant; the last write wins.
    """
    canonical = _normalize_model_key(key)
    cached = MODEL_REGISTRY.get(canonical)
    if cached is None:
        cached = MODEL_REGISTRY[canonical] = load_model_variant(canonical, device=DEVICE)
    return cached, canonical
# -------------------- Static media --------------------
# Serve MEDIA_ROOT (job artifacts and raw uploads) at /media. The mount name MUST
# be "media" because handlers build URLs with request.url_for("media", path=...).
# NOTE(review): no authentication guards these files; anyone holding a URL can fetch it.
_media_files = StaticFiles(directory=MEDIA_ROOT)
app.mount("/media", _media_files, name="media")
# -------------------- Schemas --------------------
class JobOut(BaseModel):
    # One segmentation job, as returned by POST /api/segment and GET /api/history.
    id: str
    created_at: float  # Unix timestamp
    input_url: str  # absolute URLs built with request.url_for("media", ...)
    mask_url: str
    overlay_url: str
    model: str
    device: str
    width: int  # input image size in pixels
    height: int
    input_name: str  # original filename the user uploaded ("unknown" when absent)
    legend: list[dict] | None = None  # from build_legend_from_mask; None if unavailable
class UploadOut(BaseModel):
    # Response of POST /api/upload; pass upload_id to POST /api/segment next.
    upload_id: str
    created_at: float  # Unix timestamp
    input_url: str  # absolute URL of the stored JPEG copy
    width: int
    height: int
class ContactIn(BaseModel):
    # Body of POST /api/contact.
    name: str = Field(min_length=1, max_length=120)
    email: EmailStr
    subject: Optional[str] = Field(default=None, max_length=200)  # optional; "No subject" is used when blank
    message: str = Field(min_length=1, max_length=5000)
class FeedbackIn(BaseModel):
    # Body of POST /api/feedback. `feedback` may also carry a SITE_SURVEY JSON string.
    # Rating removed per request
    name: Optional[str] = Field(default=None, max_length=120)
    email: Optional[EmailStr] = None
    feedback: str = Field(min_length=1, max_length=10000)
class UserOut(BaseModel):
    # Public user view; populated from _serialize_user (no password hash).
    id: str
    email: EmailStr
    name: Optional[str] = None
    created_at: float  # Unix timestamp
    is_verified: bool
    verified_at: Optional[float] = None  # Unix timestamp, None until verified
    preferred_model: str  # "green" or "orange"
class AuthResponse(BaseModel):
    # Returned on login / verification: bearer token plus the signed-in user.
    token: str
    user: UserOut
class RegisterIn(BaseModel):
    # Body of POST /api/auth/register.
    email: EmailStr
    password: str = Field(min_length=8, max_length=256)
    name: str = Field(min_length=1, max_length=120)
class RegisterOut(BaseModel):
    # Registration always requires email verification before login.
    detail: str
    requires_verification: bool = True
class LoginIn(BaseModel):
    # Body of POST /api/auth/login.
    email: EmailStr
    password: str = Field(min_length=1, max_length=256)
class AuthMeOut(BaseModel):
    # Response of GET /api/auth/me.
    user: UserOut
class VerifyResendIn(BaseModel):
    # Body of POST /api/auth/verify/resend.
    email: EmailStr
class VerifyConfirmIn(BaseModel):
    # Body of POST /api/auth/verify/confirm; code is exactly six digits.
    email: EmailStr
    code: str = Field(min_length=6, max_length=6, pattern=r"^\d{6}$")
class GenericMessageOut(BaseModel):
    # Plain human-readable status message.
    detail: str
class PasswordForgotIn(BaseModel):
    # Body of POST /api/auth/password/forgot.
    email: EmailStr
class PasswordResetIn(BaseModel):
    # Body of POST /api/auth/password/reset; code is exactly six digits.
    email: EmailStr
    code: str = Field(min_length=6, max_length=6, pattern=r"^\d{6}$")
    new_password: str = Field(min_length=8, max_length=256)
# -------------------- Helpers --------------------
def _send_mail(to: list[str], subject: str, body: str, reply_to: Optional[str] = None) -> bool:
    """
    Send a plain-text email over SMTP with STARTTLS, on a best-effort basis.

    Args:
        to: recipient addresses; nothing is sent when empty.
        subject: Subject header. It may contain user input, so invalid values
            (e.g. containing CR/LF) are handled like any other send failure.
        body: plain-text body.
        reply_to: optional Reply-To address.

    Returns:
        True if the SMTP server accepted the message, False otherwise. Never
        raises; failures are logged.
    """
    if not SMTP_HOST or not to:
        return False
    try:
        # Header assignment is inside the try: EmailMessage rejects header values
        # containing line breaks with ValueError.
        msg = EmailMessage()
        msg["Subject"] = subject
        msg["From"] = SMTP_FROM
        msg["To"] = ", ".join(to)
        if reply_to:
            msg["Reply-To"] = reply_to
        msg.set_content(body)
        with smtplib.SMTP(SMTP_HOST, SMTP_PORT, timeout=15) as s:
            s.starttls()
            if SMTP_USER:
                s.login(SMTP_USER, SMTP_PASS)
            s.send_message(msg)
    except Exception:
        # Deliberately best-effort, but keep a record instead of failing silently.
        logging.getLogger(__name__).warning(
            "Failed to send email %r to %d recipient(s)", subject, len(to), exc_info=True
        )
        return False
    return True
def _extract_int(v: Any) -> Optional[int]:
try:
iv = int(v)
if 1 <= iv <= 5:
return iv
except Exception:
pass
return None
def _try_parse_site_survey(raw_feedback: str) -> Optional[Dict[str, Any]]:
"""
If the 'feedback' field contains a JSON string with tag='SITE_SURVEY',
parse and return a dict with normalized keys. Otherwise return None.
"""
try:
data = json.loads(raw_feedback)
if not isinstance(data, dict):
return None
if str(data.get("tag")).upper() != "SITE_SURVEY":
return None
ratings = data.get("ratings") or {}
return {
"role": (data.get("role") or None),
"use_case": (data.get("use_case") or None),
"model_works": _extract_int(ratings.get("model_works")),
"image_visualization": _extract_int(ratings.get("image_visualization")),
"prompt_input": _extract_int(ratings.get("prompt_input")),
"processing_speed": _extract_int(ratings.get("processing_speed")),
"accessibility": _extract_int(ratings.get("accessibility")),
"overall_experience": _extract_int(ratings.get("overall_experience")),
}
except Exception:
return None
# -------------------- Middleware --------------------
@app.middleware("http")
async def add_cache_headers(request: Request, call_next):
    """Mark everything served under /media/ as publicly cacheable for 30 days.

    Artifact filenames embed a fresh uuid4, so a given URL's content never changes.
    """
    response = await call_next(request)
    serves_media = request.url.path.startswith("/media/")
    if serves_media:
        response.headers["Cache-Control"] = "public, max-age=2592000, immutable"  # 2592000 s = 30 days
    return response
# -------------------- Routes --------------------
@app.get("/api/health")
def health():
    """Liveness probe reporting the inference device and default model key."""
    status = dict(status="ok", device=DEVICE, model_key=DEFAULT_MODEL_KEY)
    return status
@app.post("/api/auth/register", response_model=RegisterOut, status_code=201)
def register_user(payload: RegisterIn):
    """
    Create an unverified account and email it a verification code.

    Raises:
        HTTPException 422: the name is blank after stripping.
        HTTPException 409: the (lower-cased) email is already registered.
    """
    email_norm = str(payload.email).strip().lower()
    display_name = payload.name.strip()
    if not display_name:
        raise HTTPException(status_code=422, detail="Name must not be empty")
    with Session(engine) as ses:
        if ses.scalar(select(User).where(User.email == email_norm)):
            raise HTTPException(status_code=409, detail="Email already registered")
        account = User(
            id=str(uuid.uuid4()),
            created_at=time.time(),
            email=email_norm,
            password_hash=_hash_password(payload.password),
            name=display_name,
            is_verified=0,
            verified_at=None,
            preferred_model=DEFAULT_MODEL_KEY,
        )
        ses.add(account)
        ses.flush()  # persist the user row before issuing a code that references it
        verification_code = _issue_email_code(ses, account, "verify_email", EMAIL_CODE_TTL_SECONDS)
        ses.commit()
        # Still inside the session: attributes expired by commit are reloaded on access.
        _send_verification_email(account, verification_code)
    return RegisterOut(
        detail="Account created. Check your email for a verification code to activate it.",
        requires_verification=True,
    )
@app.post("/api/auth/login", response_model=AuthResponse)
def login_user(payload: LoginIn, request: Request):
    """
    Exchange email + password for a new bearer session.

    Raises:
        HTTPException 401: unknown email or wrong password (same message for both).
        HTTPException 403: the account has not verified its email yet.
    """
    email_norm = str(payload.email).strip().lower()
    with Session(engine) as ses:
        account = ses.scalar(select(User).where(User.email == email_norm))
        credentials_ok = bool(account) and _verify_password(payload.password, account.password_hash)
        if not credentials_ok:
            raise HTTPException(status_code=401, detail="Invalid email or password")
        if not account.is_verified:
            raise HTTPException(
                status_code=403,
                detail="Email not verified. Use the code sent to your inbox to activate your account.",
            )
        bearer = _create_session(account, ses, request.headers.get("User-Agent"))
        ses.commit()
        return AuthResponse(token=bearer, user=UserOut(**_serialize_user(account)))
@app.post("/api/auth/verify/resend", response_model=GenericMessageOut)
def resend_verification(payload: VerifyResendIn):
    """
    Issue a fresh email-verification code, invalidating any outstanding one.

    NOTE(review): the three distinct messages reveal whether an address is
    registered and whether it is verified.
    """
    email_norm = str(payload.email).strip().lower()
    with Session(engine) as ses:
        account = ses.scalar(select(User).where(User.email == email_norm))
        if not account:
            message = "If an account exists, a verification email has been sent."
        elif account.is_verified:
            message = "This email is already verified. You can sign in."
        else:
            fresh_code = _issue_email_code(ses, account, "verify_email", EMAIL_CODE_TTL_SECONDS)
            ses.commit()
            _send_verification_email(account, fresh_code)
            message = "Verification code sent. Check your email."
    return GenericMessageOut(detail=message)
@app.post("/api/auth/verify/confirm", response_model=AuthResponse)
def confirm_verification(payload: VerifyConfirmIn, request: Request):
    """
    Verify an account with its emailed code and sign it in.

    verified_at keeps its first value if the account was already verified.

    Raises:
        HTTPException 400: unknown email, or wrong/expired/used code (same message).
    """
    email_norm = str(payload.email).strip().lower()
    code = payload.code.strip()
    with Session(engine) as ses:
        account = ses.scalar(select(User).where(User.email == email_norm))
        if not account or not _verify_email_code(ses, account, "verify_email", code):
            raise HTTPException(status_code=400, detail="Invalid or expired verification code")
        account.is_verified = 1
        account.verified_at = account.verified_at or time.time()
        bearer = _create_session(account, ses, request.headers.get("User-Agent"))
        ses.commit()
        return AuthResponse(token=bearer, user=UserOut(**_serialize_user(account)))
@app.post("/api/auth/password/forgot", response_model=GenericMessageOut)
def forgot_password(payload: PasswordForgotIn):
    """
    Email a password-reset code to a verified account.

    The response body is identical whether or not the address belongs to a
    verified account, so it cannot be used for account discovery.
    NOTE(review): the email is sent synchronously only for real accounts, so
    response latency may still hint at existence.
    """
    # One message for every path; a distinct success message would reveal existence.
    response = GenericMessageOut(detail="If an account exists, you'll receive a reset code shortly.")
    email_norm = str(payload.email).strip().lower()
    with Session(engine) as ses:
        user = ses.scalar(select(User).where(User.email == email_norm))
        if not user or not user.is_verified:
            return response
        code = _issue_email_code(ses, user, "reset_password", PASSWORD_RESET_TTL_SECONDS)
        ses.commit()
        _send_password_reset_email(user, code)
    return response
@app.post("/api/auth/password/reset", response_model=GenericMessageOut)
def reset_password(payload: PasswordResetIn):
    """
    Set a new password using an emailed reset code.

    Side effects: marks the account verified if needed (the code proves mailbox
    access) and deletes every existing session for the user.

    Raises:
        HTTPException 400: unknown email, or wrong/expired/used code (same message).
    """
    email_norm = str(payload.email).strip().lower()
    code = payload.code.strip()
    with Session(engine) as ses:
        account = ses.scalar(select(User).where(User.email == email_norm))
        if not account or not _verify_email_code(ses, account, "reset_password", code):
            raise HTTPException(status_code=400, detail="Invalid or expired reset code")
        account.password_hash = _hash_password(payload.new_password)
        if not account.is_verified:
            account.is_verified = 1
            account.verified_at = time.time()
        # Sign out every device that held the old password.
        ses.execute(text("DELETE FROM sessions WHERE user_id = :uid"), {"uid": account.id})
        ses.commit()
    return GenericMessageOut(detail="Password updated. You can now sign in with your new password.")
@app.get("/api/auth/me", response_model=AuthMeOut)
def whoami(request: Request):
    """Return the user behind the bearer token; 401 if it is missing, invalid, or expired."""
    bearer = _extract_bearer_token(request)
    with Session(engine) as ses:
        current = _get_user_from_token(bearer, ses, require=True)
        return AuthMeOut(user=UserOut(**_serialize_user(current)))
@app.post("/api/auth/logout")
def logout_user(request: Request):
    """
    Delete the caller's session row.

    Raises:
        HTTPException 401: no bearer token, or the token matches no session.
    """
    bearer = _extract_bearer_token(request)
    if not bearer:
        raise HTTPException(status_code=401, detail="Missing authorization")
    with Session(engine) as ses:
        session_row = ses.get(SessionToken, bearer)
        if not session_row:
            raise HTTPException(status_code=401, detail="Invalid session")
        ses.delete(session_row)
        ses.commit()
    return {"status": "ok"}
@app.post("/api/model/select")
def select_model(request: Request, model_key: str = Form(...)):
    """
    Update the preferred segmentation model for the authenticated user.
    Accepts 'green' or 'orange'. Anything else falls back to 'green'.
    The chosen variant is loaded right away so the next segmentation does not pay the load cost.
    """
    chosen = _normalize_model_key(model_key)
    bearer = _extract_bearer_token(request)
    with Session(engine) as ses:
        account = _get_user_from_token(bearer, ses, require=True)
        account.preferred_model = chosen
        ses.commit()
    _ensure_model_loaded(chosen)
    return {"status": "ok", "model_key": chosen, "device": DEVICE}
@app.get("/api/model/current")
def current_model(request: Request):
    """Report the caller's preferred model; anonymous or invalid tokens get the default."""
    with Session(engine) as ses:
        viewer = _get_user_from_token(_extract_bearer_token(request), ses)
        preferred = viewer.preferred_model if viewer else None
    return {"model_key": _normalize_model_key(preferred), "device": DEVICE}
@app.post("/api/upload", response_model=UploadOut)
async def upload(request: Request, file: UploadFile = File(...)):
    """
    Store an image for the two-step flow (upload, then POST /api/segment with upload_id).

    The image is validated, converted to RGB, and saved under UPLOAD_DIR as
    "<upload_id>_input.jpg". Its original filename and the caller's user id
    (None when anonymous) are recorded in the uploads table.

    Raises:
        HTTPException 413: the file is larger than MAX_BYTES.
        HTTPException 400: the bytes are not a decodable image.
    """
    # Read at most MAX_BYTES + 1 bytes: enough to detect an oversized file
    # without loading an arbitrarily large body into memory.
    raw = await file.read(MAX_BYTES + 1)
    if len(raw) > MAX_BYTES:
        raise HTTPException(status_code=413, detail=f"Image too large (max {MAX_BYTES / (1024 * 1024):g}MB)")
    try:
        # verify() leaves the image object unusable, so decode a fresh copy afterwards.
        with Image.open(io.BytesIO(raw)) as probe:
            probe.verify()
        img = Image.open(io.BytesIO(raw)).convert("RGB")
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid image")
    w, h = img.size
    upload_id = str(uuid.uuid4())
    ts = time.time()
    stored_name = f"{upload_id}_input.jpg"
    in_path = os.path.join(UPLOAD_DIR, stored_name)
    img.save(in_path, quality=92, optimize=True)
    # persist upload metadata (original filename)
    with Session(engine) as ses:
        user = _get_user_from_token(_extract_bearer_token(request), ses)
        user_id = user.id if user else None
        ses.add(
            Upload(
                upload_id=upload_id,
                created_at=ts,
                stored_name=stored_name,
                # basename strips any client-supplied directory components
                orig_name=os.path.basename(file.filename) if file.filename else "unknown",
                width=w,
                height=h,
                user_id=user_id,
            )
        )
        ses.commit()
    input_url = request.url_for("media", path=f"uploads/{stored_name}")
    return UploadOut(
        upload_id=upload_id,
        created_at=ts,
        input_url=str(input_url),
        width=w,
        height=h,
    )
def _parse_prompt_classes(prompt: Optional[str]) -> Optional[List[str]]:
    """
    Split a comma-separated prompt into class names for prompt-guided segmentation.

    Blank entries are dropped. "background" is prepended unless it is already
    the first class (case-insensitive). Returns None if no class names remain.
    """
    if not prompt:
        return None
    classes = [p.strip() for p in prompt.split(",") if p.strip()]
    if not classes:
        return None
    if classes[0].lower() != "background":
        classes.insert(0, "background")
    return classes


@app.post("/api/segment", response_model=JobOut)
async def segment(
    request: Request,
    file: Optional[UploadFile] = File(None),  # backward-compat
    upload_id: Optional[str] = Form(None),  # preferred path
    prompt: Optional[str] = Form(None),  # optional text input
):
    """
    Preferred: pass upload_id (image already uploaded via /api/upload).
    Backward compatibility: can still post a 'file'.
    If 'prompt' is provided, segmentation will be prompt-guided.

    Raises:
        HTTPException: 400 (no input / invalid image), 403 (upload owned by
            another account), 404 (unknown or malformed upload_id),
            413 (posted file larger than MAX_BYTES), 500 (inference failed).
    """
    prompt_list = _parse_prompt_classes(prompt)

    # Identify the caller (optional) and their preferred model.
    token = _extract_bearer_token(request)
    user_id: Optional[str] = None
    user_model_key = DEFAULT_MODEL_KEY
    if token:
        with Session(engine) as ses:
            user = _get_user_from_token(token, ses)
            if user:
                user_id = user.id
                user_model_key = _normalize_model_key(getattr(user, "preferred_model", None))

    # Load the image & get original filename
    if upload_id and not file:
        # Two-step path: fetch the stored image and its original name.
        # upload_id is client-controlled and is spliced into a filesystem path,
        # so accept only the UUID format /api/upload issues. Otherwise a value
        # like "../<job_id>" resolves outside UPLOAD_DIR, and the ownership check
        # below is skipped because no uploads row matches.
        try:
            upload_id = str(uuid.UUID(upload_id))
        except ValueError:
            raise HTTPException(status_code=404, detail="Upload not found") from None
        in_name = f"{upload_id}_input.jpg"
        in_path = os.path.join(UPLOAD_DIR, in_name)
        if not os.path.exists(in_path):
            raise HTTPException(status_code=404, detail="Upload not found")
        try:
            # The context manager closes the file handle once the RGB copy is decoded.
            with Image.open(in_path) as stored:
                img = stored.convert("RGB")
        except Exception:
            raise HTTPException(status_code=400, detail="Invalid image")
        # get orig filename (and enforce ownership) from uploads table
        with Session(engine) as ses:
            upl = ses.get(Upload, upload_id)
            if upl:
                if upl.user_id and upl.user_id != user_id:
                    raise HTTPException(status_code=403, detail="Upload does not belong to this account")
                orig_name = upl.orig_name or in_name
            else:
                orig_name = in_name  # fallback: file exists but no metadata row
    else:
        # one-step legacy path: handle direct file upload to /segment
        if not file:
            raise HTTPException(status_code=400, detail="No image or upload_id provided")
        # Read at most MAX_BYTES + 1 bytes to detect oversize without buffering everything.
        raw = await file.read(MAX_BYTES + 1)
        if len(raw) > MAX_BYTES:
            raise HTTPException(status_code=413, detail=f"Image too large (max {MAX_BYTES / (1024 * 1024):g}MB)")
        try:
            # verify() leaves the image object unusable, so decode a fresh copy afterwards.
            with Image.open(io.BytesIO(raw)) as probe:
                probe.verify()
            img = Image.open(io.BytesIO(raw)).convert("RGB")
        except Exception:
            raise HTTPException(status_code=400, detail="Invalid image")
        orig_name = os.path.basename(file.filename) if file.filename else "unknown"
    w, h = img.size

    # Inference
    # NOTE(review): segment_image runs synchronously inside this async handler and
    # blocks the event loop for the duration of inference.
    model, user_model_key = _ensure_model_loaded(user_model_key)
    try:
        if prompt_list:
            mask = segment_image(model, img, prompt_list, SENTENCE_MODEL, PCA_PARAMS, device=DEVICE)
        else:
            mask = segment_image(model, img, device=DEVICE)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Inference failed: {e}")
    dataset = "ade" if user_model_key == "orange" else "coco"

    # Prompt-aware coloring & legend
    if prompt_list:
        color = colorize_mask(mask, classes=prompt_list)
        legend = build_legend_from_mask(mask, classes=prompt_list)
    else:
        color = colorize_mask(mask, dataset=dataset)
        legend = build_legend_from_mask(mask, dataset=dataset)
    over = overlay_mask(img, color, alpha=0.45)

    # Save artifacts (all named after the job id, directly under MEDIA_ROOT)
    job_id = str(uuid.uuid4())
    ts = time.time()
    in_name_final = f"{job_id}_input.jpg"
    mask_name = f"{job_id}_mask.png"
    over_name = f"{job_id}_overlay.jpg"
    legend_name = f"{job_id}_legend.json"
    in_path_final = os.path.join(MEDIA_ROOT, in_name_final)
    mask_path = os.path.join(MEDIA_ROOT, mask_name)
    over_path = os.path.join(MEDIA_ROOT, over_name)
    legend_path = os.path.join(MEDIA_ROOT, legend_name)
    img.save(in_path_final, quality=92, optimize=True)
    color.save(mask_path)
    over.save(over_path, quality=92, optimize=True)
    try:
        with open(legend_path, "w", encoding="utf-8") as f:
            json.dump(legend, f)
    except Exception:
        # Best effort: the legend is still returned inline below; history just won't
        # show it. Drop any partially written sidecar so history never reads a truncated file.
        try:
            os.remove(legend_path)
        except OSError:
            pass

    # Persist job row
    # NOTE(review): the model label is fixed even when the "orange" variant ran — confirm intended.
    with Session(engine) as ses:
        ses.add(
            Job(
                id=job_id,
                created_at=ts,
                input_path=in_name_final,
                input_name=orig_name or "unknown",  # <- show this in history
                mask_path=mask_name,
                overlay_path=over_name,
                model="Meruformerb4_entailonly",
                device=DEVICE,
                width=w,
                height=h,
                user_id=user_id,
            )
        )
        ses.commit()
    input_url = request.url_for("media", path=in_name_final)
    mask_url = request.url_for("media", path=mask_name)
    overlay_url = request.url_for("media", path=over_name)
    return JobOut(
        id=job_id,
        created_at=ts,
        input_url=str(input_url),
        mask_url=str(mask_url),
        overlay_url=str(overlay_url),
        model="Meruformerb4_entailonly",
        device=DEVICE,
        width=w,
        height=h,
        input_name=orig_name or "unknown",
        legend=legend,
    )
# Upper bound for one /api/history page, so a single request cannot dump the whole table.
_HISTORY_MAX_LIMIT = 100


def _read_legend_sidecar(job_id: str) -> Any:
    """Return the parsed "<job_id>_legend.json" from MEDIA_ROOT, or None if missing or unreadable."""
    sidecar = os.path.join(MEDIA_ROOT, f"{job_id}_legend.json")
    if not os.path.exists(sidecar):
        return None
    try:
        with open(sidecar, "r", encoding="utf-8") as f:
            return json.load(f)
    except (OSError, ValueError):  # ValueError covers JSONDecodeError and UnicodeDecodeError
        return None


@app.get("/api/history", response_model=List[JobOut])
def history(request: Request, limit: int = 20, offset: int = 0):
    """
    List segmentation jobs, newest first.

    Authenticated callers see only their own jobs. Anonymous callers see every
    job without an owner (user_id IS NULL).
    NOTE(review): that anonymous list is shared across all anonymous visitors.

    ``limit`` is clamped to 0.._HISTORY_MAX_LIMIT and ``offset`` to >= 0, since
    SQLite treats a negative LIMIT as "no limit". Each item carries its legend
    when the sidecar file exists and parses.
    """
    limit = max(0, min(limit, _HISTORY_MAX_LIMIT))
    offset = max(0, offset)
    token = _extract_bearer_token(request)
    with Session(engine) as ses:
        user = _get_user_from_token(token, ses)
        owner_clause = (Job.user_id == user.id) if user else Job.user_id.is_(None)
        stmt = select(Job).where(owner_clause).order_by(Job.created_at.desc()).limit(limit).offset(offset)
        jobs = ses.scalars(stmt).all()
        return [
            JobOut(
                id=job.id,
                created_at=job.created_at,
                input_url=str(request.url_for("media", path=job.input_path)),
                mask_url=str(request.url_for("media", path=job.mask_path)),
                overlay_url=str(request.url_for("media", path=job.overlay_path)),
                model=job.model,
                device=job.device,
                width=job.width,
                height=job.height,
                input_name=job.input_name or "unknown",
                legend=_read_legend_sidecar(job.id),
            )
            for job in jobs
        ]
@app.delete("/api/history/{job_id}")
def delete_job(job_id: str, request: Request):
    """
    Delete one job row and, best effort, its artifact files and legend sidecar.

    Owned jobs require their owner's token. Jobs without an owner can be deleted
    by any caller (NOTE(review): confirm that is intended).

    Raises:
        HTTPException 404: no such job.
        HTTPException 401: the job is owned and no valid token was sent.
        HTTPException 403: the job belongs to a different user.
    """
    with Session(engine) as ses:
        requester = _get_user_from_token(_extract_bearer_token(request), ses)
        job = ses.get(Job, job_id)
        if not job:
            raise HTTPException(status_code=404, detail="Job not found")
        if job.user_id and not requester:
            raise HTTPException(status_code=401, detail="Authentication required")
        if job.user_id and requester.id != job.user_id:
            raise HTTPException(status_code=403, detail="Forbidden")
        artifact_names = (job.input_path, job.mask_path, job.overlay_path, f"{job_id}_legend.json")
        for artifact in artifact_names:
            try:
                os.remove(os.path.join(MEDIA_ROOT, artifact))
            except Exception:
                pass  # missing/unremovable files must not block deleting the row
        ses.delete(job)
        ses.commit()
    return {"status": "deleted", "id": job_id}
@app.delete("/api/history")
def clear_history(request: Request):
    """
    Delete all of the caller's jobs and, best effort, their files.

    With a valid token, only that user's jobs are removed. Without one, every
    ownerless job is removed.
    NOTE(review): that wipes all anonymous visitors' history.
    """
    with Session(engine) as ses:
        requester = _get_user_from_token(_extract_bearer_token(request), ses)
        owner_filter = (Job.user_id == requester.id) if requester else Job.user_id.is_(None)
        for job in ses.query(Job).filter(owner_filter).all():
            artifact_names = (job.input_path, job.mask_path, job.overlay_path, f"{job.id}_legend.json")
            for artifact in artifact_names:
                try:
                    os.remove(os.path.join(MEDIA_ROOT, artifact))
                except Exception:
                    pass  # a missing file must not stop the cleanup
            ses.delete(job)
        ses.commit()
    return {"status": "cleared"}
# -------------------- Contact, Feedback (+Survey autodetect) --------------------
@app.post("/api/contact")
def create_contact(payload: ContactIn):
    """
    Store a contact-form submission, then notify CONTACT_RECIPIENTS by email.

    The notification is best effort, and its Reply-To is the submitter's address.
    """
    created = time.time()
    entry_id = str(uuid.uuid4())
    with Session(engine) as ses:
        ses.add(
            Contact(
                id=entry_id,
                created_at=created,
                name=payload.name,
                email=str(payload.email),
                subject=payload.subject,
                message=payload.message,
            )
        )
        ses.commit()

    subj = (payload.subject or "").strip()
    try:
        submitted_at = time.strftime('%Y-%m-%d %H:%M:%S', time.gmtime(created))
        body = "\n".join(
            [
                f"Name: {payload.name}",
                f"Email: {payload.email}",
                f"Subject: {subj or 'No subject'}",
                "Message:",
                payload.message,
                "",
                f"Submitted at: {submitted_at} UTC",
            ]
        )
        _send_mail(
            to=CONTACT_RECIPIENTS,
            subject=f"ConeImage: {subj or 'No subject'}",
            body=body,
            reply_to=str(payload.email),
        )
    except Exception:
        pass  # notification failures must not fail the already-stored submission
    return {"status": "ok", "id": entry_id, "created_at": created}
@app.post("/api/feedback")
def create_feedback(payload: FeedbackIn):
    """
    Accepts plain feedback (no rating) OR a survey JSON (tag='SITE_SURVEY')
    posted as the 'feedback' string. If a survey is detected, it is saved
    into the 'surveys' table; otherwise, it is saved into 'feedback'.
    For plain feedback, send an email that includes only the submitter's email.
    """
    created = time.time()
    entry_id = str(uuid.uuid4())

    survey = _try_parse_site_survey(payload.feedback)
    if survey is not None:
        # Surveys are stored only; no notification email is sent for them.
        survey_fields = (
            "role",
            "use_case",
            "model_works",
            "image_visualization",
            "prompt_input",
            "processing_speed",
            "accessibility",
            "overall_experience",
        )
        with Session(engine) as ses:
            ses.add(Survey(id=entry_id, created_at=created, **{field: survey[field] for field in survey_fields}))
            ses.commit()
        return {"status": "ok", "kind": "survey", "id": entry_id, "created_at": created}

    with Session(engine) as ses:
        ses.add(
            Feedback(
                id=entry_id,
                created_at=created,
                name=payload.name,
                email=str(payload.email) if payload.email else None,
                feedback=payload.feedback,
            )
        )
        ses.commit()

    try:
        submitted_at = time.strftime('%Y-%m-%d %H:%M:%S', time.gmtime(created))
        body = "\n".join(
            [
                f"Name: {payload.name or '(anonymous)'}",
                f"Email: {payload.email or '(none provided)'}",
                "Message:",
                payload.feedback,
                "",
                f"Submitted at: {submitted_at} UTC",
            ]
        )
        _send_mail(
            to=CONTACT_RECIPIENTS,
            subject="ConeImage: New Feedback",
            body=body,
            reply_to=None,
        )
    except Exception:
        pass  # notification failures must not fail the already-stored feedback
    return {"status": "ok", "kind": "feedback", "id": entry_id, "created_at": created}