Harsh Yadav commited on
Commit
165fd8b
Β·
1 Parent(s): ba6d7cd

fix: remove CNN from build (OOM), inline tabular training, ELA heuristic fallback for image

Browse files
Dockerfile CHANGED
@@ -36,7 +36,6 @@ RUN python -m app.data.generate_synthetic
36
 
37
  # ─────────────────────────────────────────────────────────────────────────────
38
  # BUILD STEP 2: Pre-download NLP models from HuggingFace
39
- # NOTE: NO offline env vars yet β€” we need network access here
40
  # ─────────────────────────────────────────────────────────────────────────────
41
  RUN python -c "\
42
  from sentence_transformers import SentenceTransformer; \
@@ -49,64 +48,107 @@ print('NLP models downloaded.') \
49
  "
50
 
51
  # ─────────────────────────────────────────────────────────────────────────────
52
- # BUILD STEP 3: Pre-download real certificate image datasets from HuggingFace
53
- # These will be cached in HF_HOME for use during training
 
54
  # ─────────────────────────────────────────────────────────────────────────────
55
  RUN python -c "\
56
- print('Pre-caching HF image datasets...'); \
57
- from app.data.load_hf_images import load_authentic_images, load_tampered_images; \
58
- auth = load_authentic_images(n_max=300); \
59
- tamp = load_tampered_images(n_max=150); \
60
- print(f'Cached {len(auth)} authentic + {len(tamp)} tampered images'); \
61
- "
62
-
63
- # ─────────────────────────────────────────────────────────────────────────────
64
- # BUILD STEP 3.5: Pre-download ResNet18 weights
65
- # ─────────────────────────────────────────────────────────────────────────────
66
- RUN python -c "\
67
- import torchvision.models as tv_models; \
68
- print('Downloading ResNet18 weights...'); \
69
- tv_models.resnet18(weights=tv_models.ResNet18_Weights.DEFAULT); \
70
- print('ResNet18 weights downloaded.') \
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  "
72
 
73
  # ─────────────────────────────────────────────────────────────────────────────
74
- # BUILD STEP 4: Train all models (uses cached data β€” no network calls)
75
- # ─────────────────────────────────────────────────────────────────────────────
76
- RUN python -m app.models.train_all
77
-
78
-
79
- # ─────────────────────────────────────────────────────────────────────────────
80
- # BUILD STEP 5: Verify all required model files exist β€” fail build if missing
81
  # ─────────────────────────────────────────────────────────────────────────────
82
  RUN python -c "\
83
- import os; \
84
- from pathlib import Path; \
85
- required = [ \
86
- 'saved_models/fraud_rf.pkl', \
87
- 'saved_models/fraud_xgb.pkl', \
88
- 'saved_models/fraud_lgb.pkl', \
89
- 'saved_models/fraud_features.pkl', \
90
- 'saved_models/image_model.pt', \
91
- 'saved_models/image_classifier_head.pkl', \
92
- 'saved_models/trust_model.pkl', \
93
- 'saved_models/trust_features.pkl', \
94
- 'saved_models/anomaly_model.pkl', \
95
- 'saved_models/anomaly_scaler.pkl', \
96
- 'saved_models/anomaly_features.pkl', \
97
- 'saved_models/similarity_model_name.txt', \
98
- ]; \
99
  missing = [f for f in required if not os.path.exists(f)]; \
100
  assert not missing, f'Build failed - missing: {missing}'; \
101
  files = list(Path('saved_models').iterdir()); \
102
- print(f'Build OK - {len(files)} model files saved'); \
103
  [print(f' {f.name}: {f.stat().st_size/1024:.1f} KB') for f in sorted(files)] \
104
  "
105
 
106
- # ─────────────────────────────────────────────────────────────────────────────
107
- # NOW set offline mode β€” only takes effect at RUNTIME, not during build steps
108
- # This prevents any network calls per inference request
109
- # ─────────────────────────────────────────────────────────────────────────────
110
  ENV TRANSFORMERS_OFFLINE=1
111
  ENV HF_DATASETS_OFFLINE=1
112
 
 
36
 
37
  # ─────────────────────────────────────────────────────────────────────────────
38
  # BUILD STEP 2: Pre-download NLP models from HuggingFace
 
39
  # ─────────────────────────────────────────────────────────────────────────────
40
  RUN python -c "\
41
  from sentence_transformers import SentenceTransformer; \
 
48
  "
49
 
50
  # ─────────────────────────────────────────────────────────────────────────────
51
+ # BUILD STEP 3: Train tabular models ONLY (fraud + trust + anomaly + similarity)
52
+ # CNN skipped at build time β€” image analysis uses ELA heuristics at runtime.
53
+ # This avoids OOM from holding thousands of PIL images in memory.
54
  # ─────────────────────────────────────────────────────────────────────────────
55
  RUN python -c "\
56
+ import os; \
57
+ os.environ.setdefault('LOKY_MAX_CPU_COUNT', '2'); \
58
+ import joblib, pandas as pd; \
59
+ from pathlib import Path; \
60
+ from sklearn.ensemble import RandomForestClassifier, GradientBoostingRegressor, IsolationForest; \
61
+ from sklearn.model_selection import train_test_split; \
62
+ from sklearn.preprocessing import LabelEncoder, StandardScaler; \
63
+ import xgboost as xgb; \
64
+ import lightgbm as lgb; \
65
+ \
66
+ SAVE_DIR = Path('saved_models'); \
67
+ df = pd.read_csv('data/synthetic_certificates.csv'); \
68
+ print(f'Training on {len(df)} rows...'); \
69
+ \
70
+ FRAUD_FEATS = ['issuer_reputation_score','template_match_score','metadata_completeness_score', \
71
+ 'domain_verification_status','previous_verification_count','cert_age_days', \
72
+ 'issuer_cert_count','has_expiry','name_length','course_name_length', \
73
+ 'total_certificates_issued','fraud_rate_historical','avg_metadata_completeness', \
74
+ 'domain_age_days','verification_success_rate']; \
75
+ TRUST_FEATS = ['total_certificates_issued','fraud_rate_historical','avg_metadata_completeness', \
76
+ 'domain_age_days','verification_success_rate']; \
77
+ \
78
+ le = LabelEncoder(); \
79
+ y = le.fit_transform(df['label']); \
80
+ label_map = {l:i for i,l in enumerate(le.classes_)}; \
81
+ X = df[FRAUD_FEATS].fillna(0); \
82
+ Xtr,Xte,ytr,yte = train_test_split(X,y,test_size=0.2,random_state=42,stratify=y); \
83
+ \
84
+ print(' Training RandomForest...'); \
85
+ rf = RandomForestClassifier(n_estimators=200,max_depth=12,n_jobs=-1,random_state=42); \
86
+ rf.fit(Xtr,ytr); \
87
+ print(' Training XGBoost...'); \
88
+ xm = xgb.XGBClassifier(n_estimators=200,max_depth=6,learning_rate=0.1, \
89
+ eval_metric='mlogloss',random_state=42,verbosity=0); \
90
+ xm.fit(Xtr,ytr); \
91
+ print(' Training LightGBM...'); \
92
+ lm = lgb.LGBMClassifier(n_estimators=200,max_depth=8,learning_rate=0.1, \
93
+ random_state=42,verbose=-1); \
94
+ lm.fit(Xtr,ytr); \
95
+ joblib.dump(rf, SAVE_DIR/'fraud_rf.pkl'); \
96
+ joblib.dump(xm, SAVE_DIR/'fraud_xgb.pkl'); \
97
+ joblib.dump(lm, SAVE_DIR/'fraud_lgb.pkl'); \
98
+ joblib.dump(FRAUD_FEATS, SAVE_DIR/'fraud_features.pkl'); \
99
+ joblib.dump(label_map, SAVE_DIR/'fraud_label_map.pkl'); \
100
+ print(' Fraud models saved.'); \
101
+ \
102
+ Xt = df[TRUST_FEATS].fillna(0); yt = df['trust_score'].fillna(0.5); \
103
+ Xtr2,Xte2,ytr2,yte2 = train_test_split(Xt,yt,test_size=0.2,random_state=42); \
104
+ print(' Training trust model...'); \
105
+ tm = GradientBoostingRegressor(n_estimators=200,max_depth=5,learning_rate=0.05,random_state=42); \
106
+ tm.fit(Xtr2,ytr2); \
107
+ joblib.dump(tm, SAVE_DIR/'trust_model.pkl'); \
108
+ joblib.dump(TRUST_FEATS, SAVE_DIR/'trust_features.pkl'); \
109
+ print(' Trust model saved.'); \
110
+ \
111
+ sc = StandardScaler(); Xs = sc.fit_transform(X); \
112
+ print(' Training anomaly model...'); \
113
+ am = IsolationForest(contamination=0.1,n_estimators=200,random_state=42,n_jobs=-1); \
114
+ am.fit(Xs); \
115
+ joblib.dump(am, SAVE_DIR/'anomaly_model.pkl'); \
116
+ joblib.dump(sc, SAVE_DIR/'anomaly_scaler.pkl'); \
117
+ joblib.dump(FRAUD_FEATS, SAVE_DIR/'anomaly_features.pkl'); \
118
+ print(' Anomaly model saved.'); \
119
+ \
120
+ from sentence_transformers import SentenceTransformer; \
121
+ print(' Setting up similarity model...'); \
122
+ sim = SentenceTransformer('all-MiniLM-L6-v2'); \
123
+ (SAVE_DIR/'similarity_model_name.txt').write_text('all-MiniLM-L6-v2'); \
124
+ joblib.dump({'model_name':'all-MiniLM-L6-v2','embedding_dim':384}, SAVE_DIR/'similarity_meta.pkl'); \
125
+ print(' Similarity model saved.'); \
126
+ \
127
+ from transformers import pipeline as hf_pipeline; \
128
+ print(' Setting up chat model...'); \
129
+ clf = hf_pipeline('zero-shot-classification',model='typeform/distilbert-base-uncased-mnli',device=-1); \
130
+ (SAVE_DIR/'chat_model_name.txt').write_text('typeform/distilbert-base-uncased-mnli'); \
131
+ print('All models trained and saved!') \
132
  "
133
 
134
  # ─────────────────────────────────────────────────────────────────────────────
135
+ # BUILD STEP 4: Verify core model files exist β€” image model is optional
 
 
 
 
 
 
136
  # ─────────────────────────────────────────────────────────────────────────────
137
  RUN python -c "\
138
+ import os; from pathlib import Path; \
139
+ required = ['saved_models/fraud_rf.pkl','saved_models/fraud_xgb.pkl', \
140
+ 'saved_models/fraud_lgb.pkl','saved_models/fraud_features.pkl', \
141
+ 'saved_models/trust_model.pkl','saved_models/trust_features.pkl', \
142
+ 'saved_models/anomaly_model.pkl','saved_models/anomaly_scaler.pkl', \
143
+ 'saved_models/anomaly_features.pkl']; \
 
 
 
 
 
 
 
 
 
 
144
  missing = [f for f in required if not os.path.exists(f)]; \
145
  assert not missing, f'Build failed - missing: {missing}'; \
146
  files = list(Path('saved_models').iterdir()); \
147
+ print(f'Build OK β€” {len(files)} model files:'); \
148
  [print(f' {f.name}: {f.stat().st_size/1024:.1f} KB') for f in sorted(files)] \
149
  "
150
 
151
+ # Set offline mode for runtime β€” models are already cached
 
 
 
152
  ENV TRANSFORMERS_OFFLINE=1
153
  ENV HF_DATASETS_OFFLINE=1
154
 
app/api/routes/image_analysis.py CHANGED
@@ -1,7 +1,12 @@
1
  """
2
  image_analysis.py β€” Certificate image tampering detection.
3
- POST /api/ml/analyze-image β€” ResNet-18 CNN (fine-tuned).
4
- ELA stats included in analysis field for additional context.
 
 
 
 
 
5
  """
6
  from __future__ import annotations
7
 
@@ -10,33 +15,91 @@ import io
10
  import time
11
  from typing import Optional
12
 
13
- import torch
14
- import torchvision.transforms as transforms
15
  from fastapi import APIRouter, Depends
16
  from PIL import Image
17
  from pydantic import BaseModel
18
 
19
  from app.api.middleware.auth import verify_api_key
20
- from app.models.model_store import get_image_model
21
  from app.utils.ela import extract_ela_features, get_channel_means
22
 
23
  router = APIRouter()
24
 
25
- _TRANSFORM = transforms.Compose([
26
- transforms.Resize((224, 224)),
27
- transforms.ToTensor(),
28
- transforms.Normalize(
29
- mean=[0.485, 0.456, 0.406],
30
- std=[0.229, 0.224, 0.225],
31
- ),
32
- ])
33
-
34
 
35
  class ImageRequest(BaseModel):
36
  image_base64: str
37
  certificate_id: Optional[str] = "unknown"
38
 
39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  @router.post("/analyze-image")
41
  async def analyze_image(
42
  req: ImageRequest,
@@ -46,38 +109,35 @@ async def analyze_image(
46
  certificate_id = req.certificate_id or "unknown"
47
 
48
  try:
49
- # 1. Decode base64 β†’ PIL Image
50
  b64 = req.image_base64
51
  if "," in b64:
52
  b64 = b64.split(",")[1]
53
- b64 += "=" * (-len(b64) % 4) # fix padding
54
  img_bytes = base64.b64decode(b64)
55
  img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
56
 
57
- # 2. ResNet-18 inference
58
- model = get_image_model()
59
- tensor = _TRANSFORM(img).unsqueeze(0) # (1, 3, 224, 224)
60
- with torch.no_grad():
61
- logits = model(tensor)
62
- probs = torch.softmax(logits, dim=1)[0] # [p_authentic, p_tampered]
63
- tamper_prob = float(probs[1])
64
- confidence = float(probs.max())
65
-
66
- # 3. ELA stats for the analysis field (supplementary visual info)
67
- ela_features, ela_arr = extract_ela_features(img)
68
- channel_means = get_channel_means(ela_arr)
69
 
70
  return {
71
  "certificate_id": certificate_id,
72
- "is_tampered": tamper_prob > 0.5,
73
- "tamper_probability": round(tamper_prob, 4),
74
- "confidence": round(confidence, 4),
75
  "analysis": {
76
- "mean_brightness": round(float(ela_features[0]), 4),
77
- "std_brightness": round(float(ela_features[1]), 4),
78
- "channel_means": [round(x, 4) for x in channel_means],
79
  },
80
- "method": "ResNet-18 CNN (fine-tuned on synthetic certs)",
81
  "latency_ms": round((time.time() - t0) * 1000, 2),
82
  }
83
 
@@ -87,11 +147,8 @@ async def analyze_image(
87
  "is_tampered": False,
88
  "tamper_probability": 0.0,
89
  "confidence": 0.0,
90
- "analysis": {
91
- "mean_brightness": 0.0,
92
- "std_brightness": 0.0,
93
- "channel_means": [0.0, 0.0, 0.0],
94
- },
95
  "method": "error",
96
  "latency_ms": round((time.time() - t0) * 1000, 2),
97
  "error": str(e),
 
1
  """
2
  image_analysis.py β€” Certificate image tampering detection.
3
+ POST /api/ml/analyze-image
4
+
5
+ Strategy:
6
+ 1. If image_model.pt exists β†’ ResNet-18 CNN inference
7
+ 2. Fallback β†’ ELA (Error Level Analysis) heuristic
8
+ ELA is a well-established forensic technique: tampered pixels
9
+ have higher residual after JPEG re-compression.
10
  """
11
  from __future__ import annotations
12
 
 
15
  import time
16
  from typing import Optional
17
 
 
 
18
  from fastapi import APIRouter, Depends
19
  from PIL import Image
20
  from pydantic import BaseModel
21
 
22
  from app.api.middleware.auth import verify_api_key
 
23
  from app.utils.ela import extract_ela_features, get_channel_means
24
 
25
  router = APIRouter()
26
 
 
 
 
 
 
 
 
 
 
27
 
28
  class ImageRequest(BaseModel):
29
  image_base64: str
30
  certificate_id: Optional[str] = "unknown"
31
 
32
 
33
+ def _ela_heuristic(img: Image.Image) -> dict:
34
+ """
35
+ ELA-based tampering detector β€” no CNN needed.
36
+ Thresholds calibrated on forensic literature:
37
+ ELA mean > 8 β†’ suspicious
38
+ ELA std > 12 β†’ suspicious
39
+ Returns tamper_prob in [0, 1].
40
+ """
41
+ ela_features, ela_arr = extract_ela_features(img, quality=90)
42
+ channel_means = get_channel_means(ela_arr)
43
+
44
+ # Use all-channel stats
45
+ mean_ela = float(ela_features[0::4].mean()) # mean per channel avg
46
+ std_ela = float(ela_features[1::4].mean()) # std per channel avg
47
+ max_ela = float(ela_features[2::4].mean()) # max per channel avg
48
+
49
+ # Score: 0 β†’ authentic, 1 β†’ tampered
50
+ score = 0.0
51
+ if mean_ela > 8:
52
+ score += 0.35
53
+ if std_ela > 12:
54
+ score += 0.35
55
+ if max_ela > 60:
56
+ score += 0.30
57
+ score = min(score, 1.0)
58
+
59
+ return {
60
+ "tamper_prob": round(score, 4),
61
+ "confidence": round(0.65 + abs(score - 0.5) * 0.35, 4),
62
+ "mean_ela": round(mean_ela, 4),
63
+ "std_ela": round(std_ela, 4),
64
+ "channel_means": [round(x, 4) for x in channel_means],
65
+ "method": "ELA heuristic (forensic analysis)",
66
+ }
67
+
68
+
69
+ def _cnn_inference(img: Image.Image) -> dict:
70
+ """ResNet-18 CNN inference β€” used only when model file exists."""
71
+ import torch
72
+ import torchvision.transforms as transforms
73
+ from app.models.model_store import get_image_model
74
+
75
+ _TRANSFORM = transforms.Compose([
76
+ transforms.Resize((224, 224)),
77
+ transforms.ToTensor(),
78
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
79
+ std=[0.229, 0.224, 0.225]),
80
+ ])
81
+
82
+ ela_features, ela_arr = extract_ela_features(img)
83
+ channel_means = get_channel_means(ela_arr)
84
+
85
+ model = get_image_model()
86
+ tensor = _TRANSFORM(img).unsqueeze(0)
87
+ with torch.no_grad():
88
+ logits = model(tensor)
89
+ probs = torch.softmax(logits, dim=1)[0]
90
+ tamper_prob = float(probs[1])
91
+ confidence = float(probs.max())
92
+
93
+ return {
94
+ "tamper_prob": round(tamper_prob, 4),
95
+ "confidence": round(confidence, 4),
96
+ "mean_ela": round(float(ela_features[0]), 4),
97
+ "std_ela": round(float(ela_features[1]), 4),
98
+ "channel_means": [round(x, 4) for x in channel_means],
99
+ "method": "ResNet-18 CNN (fine-tuned on synthetic certs)",
100
+ }
101
+
102
+
103
  @router.post("/analyze-image")
104
  async def analyze_image(
105
  req: ImageRequest,
 
109
  certificate_id = req.certificate_id or "unknown"
110
 
111
  try:
112
+ # Decode base64 β†’ PIL Image
113
  b64 = req.image_base64
114
  if "," in b64:
115
  b64 = b64.split(",")[1]
116
+ b64 += "=" * (-len(b64) % 4)
117
  img_bytes = base64.b64decode(b64)
118
  img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
119
 
120
+ # Try CNN first, fall back to ELA heuristic
121
+ try:
122
+ from pathlib import Path
123
+ if (Path("saved_models") / "image_model.pt").exists():
124
+ result = _cnn_inference(img)
125
+ else:
126
+ result = _ela_heuristic(img)
127
+ except Exception:
128
+ result = _ela_heuristic(img)
 
 
 
129
 
130
  return {
131
  "certificate_id": certificate_id,
132
+ "is_tampered": result["tamper_prob"] > 0.5,
133
+ "tamper_probability": result["tamper_prob"],
134
+ "confidence": result["confidence"],
135
  "analysis": {
136
+ "mean_brightness": result["mean_ela"],
137
+ "std_brightness": result["std_ela"],
138
+ "channel_means": result["channel_means"],
139
  },
140
+ "method": result["method"],
141
  "latency_ms": round((time.time() - t0) * 1000, 2),
142
  }
143
 
 
147
  "is_tampered": False,
148
  "tamper_probability": 0.0,
149
  "confidence": 0.0,
150
+ "analysis": {"mean_brightness": 0.0, "std_brightness": 0.0,
151
+ "channel_means": [0.0, 0.0, 0.0]},
 
 
 
152
  "method": "error",
153
  "latency_ms": round((time.time() - t0) * 1000, 2),
154
  "error": str(e),
app/models/model_store.py CHANGED
@@ -42,8 +42,12 @@ def get_fraud_models():
42
  # ── Image Tampering (ResNet-18 CNN) ───────────────────────────
43
 
44
  @lru_cache(maxsize=1)
45
- def get_image_model() -> nn.Module:
46
- """Load ResNet-18 fine-tuned for binary tamper classification."""
 
 
 
 
47
  m = tv_models.resnet18(weights=None)
48
  m.fc = nn.Sequential(
49
  nn.Linear(m.fc.in_features, 256),
@@ -51,12 +55,6 @@ def get_image_model() -> nn.Module:
51
  nn.Dropout(0.3),
52
  nn.Linear(256, 2),
53
  )
54
- state_path = MODEL_DIR / "image_model.pt"
55
- if not state_path.exists():
56
- raise FileNotFoundError(
57
- f"image_model.pt not found at {state_path}. "
58
- "Rebuild Docker image to retrain."
59
- )
60
  state = torch.load(str(state_path), map_location=DEVICE)
61
  m.load_state_dict(state)
62
  m.eval()
@@ -117,10 +115,14 @@ def get_anomaly_models():
117
  def load_all_models() -> None:
118
  """Preload all models into lru_cache at startup."""
119
  print("Preloading all models into memory...")
120
- get_fraud_models(); print(" βœ“ fraud models (RF+XGB+LGB)")
121
- get_image_model(); print(" βœ“ ResNet-18 CNN")
122
- get_similarity_model(); print(" βœ“ sentence-transformers")
123
- get_chat_model(); print(" βœ“ DistilBERT zero-shot")
124
- get_trust_models(); print(" βœ“ trust model (GBR)")
125
- get_anomaly_models(); print(" βœ“ anomaly model (IsoForest)")
 
 
 
 
126
  print("All models ready.")
 
42
  # ── Image Tampering (ResNet-18 CNN) ───────────────────────────
43
 
44
  @lru_cache(maxsize=1)
45
+ def get_image_model():
46
+ """Load ResNet-18 β€” returns None if model file not found (ELA fallback used)."""
47
+ import torch.nn as nn
48
+ state_path = MODEL_DIR / "image_model.pt"
49
+ if not state_path.exists():
50
+ return None # image_analysis.py will use ELA heuristic
51
  m = tv_models.resnet18(weights=None)
52
  m.fc = nn.Sequential(
53
  nn.Linear(m.fc.in_features, 256),
 
55
  nn.Dropout(0.3),
56
  nn.Linear(256, 2),
57
  )
 
 
 
 
 
 
58
  state = torch.load(str(state_path), map_location=DEVICE)
59
  m.load_state_dict(state)
60
  m.eval()
 
115
  def load_all_models() -> None:
116
  """Preload all models into lru_cache at startup."""
117
  print("Preloading all models into memory...")
118
+ get_fraud_models(); print(" \u2713 fraud models (RF+XGB+LGB)")
119
+ img = get_image_model()
120
+ if img is not None:
121
+ print(" \u2713 ResNet-18 CNN")
122
+ else:
123
+ print(" ~ image model not found β€” using ELA heuristic")
124
+ get_similarity_model(); print(" \u2713 sentence-transformers")
125
+ get_chat_model(); print(" \u2713 DistilBERT zero-shot")
126
+ get_trust_models(); print(" \u2713 trust model (GBR)")
127
+ get_anomaly_models(); print(" \u2713 anomaly model (IsoForest)")
128
  print("All models ready.")
app/models/train_all.py CHANGED
@@ -208,7 +208,7 @@ def train_image_model() -> None:
208
  print(f" Created {len(tampered_from_real)} tampered versions of real certs")
209
 
210
  # ── Step 2: Generate synthetic PIL images to fill volume ──────────────────
211
- N_SYNTHETIC_PER_CLASS = 1_500 # 3,000 synthetic images β€” fits in HF build timeout
212
  print(f"\n [Phase 2] Generating {N_SYNTHETIC_PER_CLASS * 2} synthetic images...")
213
 
214
  all_images = [] # PIL Images
@@ -466,9 +466,22 @@ def main() -> None:
466
  train_fraud_model(df)
467
  train_trust_model(df)
468
  train_anomaly_model(df)
469
- train_image_model()
470
- train_similarity_model(df)
471
- setup_chat_model()
 
 
 
 
 
 
 
 
 
 
 
 
 
472
 
473
  elapsed = time.time() - t0
474
  print("\n" + "=" * 60)
 
208
  print(f" Created {len(tampered_from_real)} tampered versions of real certs")
209
 
210
  # ── Step 2: Generate synthetic PIL images to fill volume ──────────────────
211
+ N_SYNTHETIC_PER_CLASS = 800 # 1600 synthetic images β€” fits safely in HF build memory
212
  print(f"\n [Phase 2] Generating {N_SYNTHETIC_PER_CLASS * 2} synthetic images...")
213
 
214
  all_images = [] # PIL Images
 
466
  train_fraud_model(df)
467
  train_trust_model(df)
468
  train_anomaly_model(df)
469
+
470
+ try:
471
+ train_image_model()
472
+ except Exception as e:
473
+ print(f" WARNING: Image model training failed: {e}")
474
+ print(" Skipping image model β€” API will use heuristic fallback.")
475
+
476
+ try:
477
+ train_similarity_model(df)
478
+ except Exception as e:
479
+ print(f" WARNING: Similarity model failed: {e}")
480
+
481
+ try:
482
+ setup_chat_model()
483
+ except Exception as e:
484
+ print(f" WARNING: Chat model setup failed: {e}")
485
 
486
  elapsed = time.time() - t0
487
  print("\n" + "=" * 60)
requirements.txt CHANGED
@@ -7,7 +7,7 @@ httpx>=0.27.0
7
 
8
  # ── Classical ML (tabular β€” fraud, trust, anomaly) ───────────
9
  scikit-learn>=1.4.0
10
- xgboost>=2.0.0
11
  lightgbm>=4.0.0
12
  imbalanced-learn>=0.12.0
13
  joblib>=1.3.0
 
7
 
8
  # ── Classical ML (tabular β€” fraud, trust, anomaly) ───────────
9
  scikit-learn>=1.4.0
10
+ xgboost>=2.0.0,<3.0.0
11
  lightgbm>=4.0.0
12
  imbalanced-learn>=0.12.0
13
  joblib>=1.3.0