haroon103 committed on
Commit
b1fb42e
·
verified ·
1 Parent(s): 0c1ebbd

Upload app.py.py

Browse files
Files changed (1) hide show
  1. app.py.py +565 -0
app.py.py ADDED
@@ -0,0 +1,565 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ FYP4 SPAM DETECTION API
3
+ FastAPI application for email spam detection using DeBERTa and ViT models
4
+ """
5
+
6
+ import os
7
+ import re
8
+ import io
9
+ import torch
10
+ import torch.nn as nn
11
+ from fastapi import FastAPI, File, UploadFile, HTTPException
12
+ from fastapi.responses import JSONResponse
13
+ from fastapi.middleware.cors import CORSMiddleware
14
+ from pydantic import BaseModel
15
+ from typing import Optional
16
+ import PyPDF2
17
+ import pdfplumber
18
+ from PIL import Image
19
+ from transformers import (
20
+ DebertaV2Model,
21
+ DebertaV2Tokenizer,
22
+ ViTModel,
23
+ ViTImageProcessor
24
+ )
25
+
26
+ # ================================
27
+ # CONFIGURATION
28
+ # ================================
29
+
30
class Config:
    """Static settings shared by the model definitions and inference code."""

    # Pretrained backbone checkpoints passed to ``from_pretrained``.
    TEXT_MODEL = 'microsoft/deberta-v3-base'
    IMAGE_MODEL = 'google/vit-base-patch16-224-in21k'

    # Layer widths: backbone outputs, shared projection size, output classes.
    TEXT_HIDDEN_DIM = 768
    IMAGE_HIDDEN_DIM = 768
    FUSION_DIM = 512
    NUM_CLASSES = 2

    DROPOUT = 0.3

    # Input sizing: token budget for the tokenizer and image side length.
    MAX_TEXT_LENGTH = 256
    IMG_SIZE = 224

    # Use CUDA when torch reports it as available, otherwise the CPU.
    DEVICE = (
        torch.device('cuda')
        if torch.cuda.is_available()
        else torch.device('cpu')
    )


# Module-wide settings instance used by the rest of the file.
config = Config()
43
+
44
+ # ================================
45
+ # MODEL ARCHITECTURES
46
+ # ================================
47
+
48
class DeBERTaTextEncoder(nn.Module):
    """DeBERTa backbone followed by a projection into the fusion space.

    The hidden state at sequence position 0 is taken as the pooled text
    representation and mapped from ``config.TEXT_HIDDEN_DIM`` to
    ``config.FUSION_DIM`` features.
    """

    def __init__(self, dropout=0.3):
        super().__init__()
        self.deberta = DebertaV2Model.from_pretrained(config.TEXT_MODEL)
        projection_layers = [
            nn.Dropout(dropout),
            nn.Linear(config.TEXT_HIDDEN_DIM, config.FUSION_DIM),
            nn.LayerNorm(config.FUSION_DIM),
            nn.GELU(),
        ]
        # Attribute names are kept as-is: they determine the state_dict keys.
        self.projection = nn.Sequential(*projection_layers)

    def forward(self, input_ids, attention_mask):
        """Return the projected first-position embedding for each sequence."""
        hidden_states = self.deberta(
            input_ids=input_ids,
            attention_mask=attention_mask,
        ).last_hidden_state
        first_position = hidden_states[:, 0, :]
        return self.projection(first_position)
63
+
64
+
65
class ViTImageEncoder(nn.Module):
    """ViT backbone followed by a projection into the fusion space.

    The hidden state at position 0 is taken as the pooled image
    representation and mapped from ``config.IMAGE_HIDDEN_DIM`` to
    ``config.FUSION_DIM`` features.
    """

    def __init__(self, dropout=0.3):
        super().__init__()
        self.vit = ViTModel.from_pretrained(config.IMAGE_MODEL)
        projection_layers = [
            nn.Dropout(dropout),
            nn.Linear(config.IMAGE_HIDDEN_DIM, config.FUSION_DIM),
            nn.LayerNorm(config.FUSION_DIM),
            nn.GELU(),
        ]
        # Attribute names are kept as-is: they determine the state_dict keys.
        self.projection = nn.Sequential(*projection_layers)

    def forward(self, pixel_values):
        """Return the projected first-position embedding for each image."""
        hidden_states = self.vit(
            pixel_values=pixel_values,
            return_dict=True,
        ).last_hidden_state
        first_position = hidden_states[:, 0, :]
        return self.projection(first_position)
80
+
81
+
82
class CrossModalAttention(nn.Module):
    """Fuse a text vector with an image vector through one attention block.

    The text features act as the query and the image features as key and
    value; the result goes through a residual + LayerNorm step and a
    feed-forward sub-layer with its own residual + LayerNorm.
    """

    def __init__(self, dim=512, num_heads=8, dropout=0.1):
        super().__init__()
        self.cross_attn = nn.MultiheadAttention(
            dim, num_heads, dropout=dropout, batch_first=True
        )
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        expanded = dim * 4
        feed_forward = [
            nn.Linear(dim, expanded),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(expanded, dim),
            nn.Dropout(dropout),
        ]
        self.ffn = nn.Sequential(*feed_forward)

    def forward(self, text_features, image_features):
        """Return the fused features with the added sequence axis removed.

        Both inputs get a length-1 axis inserted at position 1 before the
        attention call (presumably they arrive as ``(batch, dim)`` tensors —
        confirm against the callers).
        """
        query = text_features.unsqueeze(1)
        key_value = image_features.unsqueeze(1)

        # Only the attended values are needed; the attention weights are dropped.
        attended = self.cross_attn(query, key_value, key_value)[0]

        hidden = self.norm1(query + attended)
        hidden = self.norm2(hidden + self.ffn(hidden))
        return hidden.squeeze(1)
104
+
105
+
106
class TextSpamClassifier(nn.Module):
    """Text-only classifier: DeBERTa encoder plus a three-layer MLP head."""

    def __init__(self, dropout=0.3):
        super().__init__()
        self.text_encoder = DeBERTaTextEncoder(dropout)
        # Head narrows FUSION_DIM -> 256 -> 128 -> NUM_CLASSES logits.
        head_layers = [
            nn.Linear(config.FUSION_DIM, 256),
            nn.LayerNorm(256),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(256, 128),
            nn.LayerNorm(128),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(128, config.NUM_CLASSES),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, input_ids, attention_mask):
        """Return class logits for the tokenised input."""
        encoded = self.text_encoder(input_ids, attention_mask)
        logits = self.classifier(encoded)
        return logits
125
+
126
+
127
class ImageSpamClassifier(nn.Module):
    """Image-only classifier: ViT encoder plus a three-layer MLP head."""

    def __init__(self, dropout=0.3):
        super().__init__()
        self.image_encoder = ViTImageEncoder(dropout)
        # Head narrows FUSION_DIM -> 256 -> 128 -> NUM_CLASSES logits.
        head_layers = [
            nn.Linear(config.FUSION_DIM, 256),
            nn.LayerNorm(256),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(256, 128),
            nn.LayerNorm(128),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(128, config.NUM_CLASSES),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, pixel_values):
        """Return class logits for the preprocessed image batch."""
        encoded = self.image_encoder(pixel_values)
        logits = self.classifier(encoded)
        return logits
146
+
147
+
148
class FusionSpamClassifier(nn.Module):
    """Classifier that accepts text, an image, or both.

    With both modalities the encoder outputs are combined by
    ``CrossModalAttention``; with a single modality that encoder's output is
    fed to the classification head directly.
    """

    def __init__(self, dropout=0.3):
        super().__init__()
        self.text_encoder = DeBERTaTextEncoder(dropout)
        self.image_encoder = ViTImageEncoder(dropout)
        self.cross_modal_fusion = CrossModalAttention(
            config.FUSION_DIM, num_heads=8, dropout=dropout
        )
        # Head narrows FUSION_DIM -> 256 -> 128 -> NUM_CLASSES logits.
        head_layers = [
            nn.Linear(config.FUSION_DIM, 256),
            nn.LayerNorm(256),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(256, 128),
            nn.LayerNorm(128),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(128, config.NUM_CLASSES),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, input_ids=None, attention_mask=None, pixel_values=None):
        """Return class logits for whichever modalities were supplied.

        Raises:
            ValueError: if neither ``input_ids`` nor ``pixel_values`` is given.
        """
        has_text = input_ids is not None
        has_image = pixel_values is not None

        if not (has_text or has_image):
            raise ValueError("At least one modality required")

        if has_text and has_image:
            features = self.cross_modal_fusion(
                self.text_encoder(input_ids, attention_mask),
                self.image_encoder(pixel_values),
            )
        elif has_text:
            features = self.text_encoder(input_ids, attention_mask)
        else:
            features = self.image_encoder(pixel_values)

        return self.classifier(features)
178
+
179
+
180
+ # ================================
181
+ # PDF EXTRACTION
182
+ # ================================
183
+
184
class PDFExtractor:
    """Pull email text and an embedded image out of a PDF supplied as bytes."""

    # Header fields recovered from the extracted text. Matching is done with
    # re.IGNORECASE, so one pattern per field covers every capitalisation.
    _HEADER_PATTERNS = {
        'subject': r'Subject:\s*(.+)',
        'sender': r'From:\s*(.+)',
    }

    # PDF colour-space names mapped to the PIL mode of the raw pixel data.
    _COLOR_MODES = {
        '/DeviceRGB': 'RGB',
        '/DeviceGray': 'L',
        '/DeviceCMYK': 'CMYK',
    }

    @staticmethod
    def _join_pages(pages):
        """Concatenate the text of every page that yields any, one per line."""
        page_texts = (page.extract_text() for page in pages)
        return "".join(f"{text}\n" for text in page_texts if text)

    @staticmethod
    def _fill_email_fields(email_data, full_text):
        """Populate *email_data* (subject, sender, body, full_text) in place."""
        email_data['full_text'] = full_text

        for field, pattern in PDFExtractor._HEADER_PATTERNS.items():
            match = re.search(pattern, full_text, re.IGNORECASE)
            if match:
                # Header values are capped at 100 characters.
                email_data[field] = match.group(1).strip()[:100]

        # The body is whatever follows the first blank line after a header;
        # without such a separator the whole text is used.
        body_match = re.search(
            r'(?:Subject|Date|From|To):.+?\n\n(.+)',
            full_text,
            re.DOTALL | re.IGNORECASE,
        )
        email_data['body'] = body_match.group(1).strip() if body_match else full_text
        return email_data

    @staticmethod
    def _decode_image(image_obj):
        """Turn one PDF image XObject into an RGB PIL image.

        Streams that hold a complete encoded image file (for example JPEG)
        are opened directly; anything else is treated as raw pixel data laid
        out according to the object's width, height and colour space.
        """
        data = image_obj.get_data()
        try:
            picture = Image.open(io.BytesIO(data))
            picture.load()  # force decoding now so errors surface here
        except Exception:
            size = (image_obj['/Width'], image_obj['/Height'])
            color_space = image_obj['/ColorSpace']
            # Colour spaces given as arrays (ICC-based, indexed, ...) are not
            # hashable, so only plain names are looked up; "P" stays the
            # default for everything unrecognised.
            mode = 'P'
            if isinstance(color_space, str):
                mode = PDFExtractor._COLOR_MODES.get(color_space, 'P')
            picture = Image.frombytes(mode, size, data)
        # Always hand back three channels so downstream image preprocessing
        # does not have to deal with palette, grayscale or CMYK input.
        return picture.convert('RGB')

    @staticmethod
    def extract_text_from_pdf(pdf_bytes):
        """Extract text from PDF bytes.

        pdfplumber is tried first and PyPDF2 is the fallback when pdfplumber
        raises. Header parsing is applied to the text from either library.

        Returns:
            dict with the keys ``subject``, ``sender``, ``body`` and
            ``full_text`` (empty strings where nothing was found).

        Raises:
            HTTPException: status 400 when both libraries fail.
        """
        email_data = {
            'subject': '',
            'sender': '',
            'body': '',
            'full_text': ''
        }

        try:
            with pdfplumber.open(io.BytesIO(pdf_bytes)) as pdf:
                full_text = PDFExtractor._join_pages(pdf.pages)
        except Exception:
            try:
                pdf_reader = PyPDF2.PdfReader(io.BytesIO(pdf_bytes))
                full_text = PDFExtractor._join_pages(pdf_reader.pages)
            except Exception as e:
                raise HTTPException(
                    status_code=400,
                    detail=f"Error extracting text from PDF: {str(e)}"
                ) from e

        return PDFExtractor._fill_email_fields(email_data, full_text)

    @staticmethod
    def extract_images_from_pdf(pdf_bytes):
        """Extract the first decodable image from PDF bytes.

        This is best-effort: pages or objects that cannot be read are skipped
        and ``None`` is returned when no image could be decoded at all.
        """
        try:
            pdf_reader = PyPDF2.PdfReader(io.BytesIO(pdf_bytes))

            for page in pdf_reader.pages:
                try:
                    resources = page['/Resources']
                    if '/XObject' not in resources:
                        continue
                    x_objects = resources['/XObject'].get_object()
                except Exception:
                    # A malformed page must not stop the search on later pages.
                    continue

                for name in x_objects:
                    try:
                        candidate = x_objects[name]
                        if candidate['/Subtype'] != '/Image':
                            continue
                        return PDFExtractor._decode_image(candidate)
                    except Exception:
                        continue
        except Exception:
            # Deliberate best-effort: an unreadable PDF simply has no image.
            pass
        return None
263
+
264
+
265
+ # ================================
266
+ # TEXT PREPROCESSING
267
+ # ================================
268
+
269
def preprocess_text(text):
    """Normalise raw text before it is handed to the tokenizer.

    The input is stringified and lower-cased, then URLs, email addresses and
    digit runs are replaced by the placeholders ``[URL]``, ``[EMAIL]`` and
    ``[NUM]`` (in that order), and finally all whitespace runs collapse to a
    single space with the ends trimmed.
    """
    substitutions = (
        (r'http\S+|www\.\S+', '[URL]'),
        (r'\S+@\S+', '[EMAIL]'),
        (r'\d+', '[NUM]'),
    )

    cleaned = str(text).lower()
    for pattern, placeholder in substitutions:
        cleaned = re.sub(pattern, placeholder, cleaned)

    collapsed = re.sub(r'\s+', ' ', cleaned)
    return collapsed.strip()
277
+
278
+
279
+ # ================================
280
+ # SPAM DETECTOR
281
+ # ================================
282
+
283
class SpamDetector:
    """Load whichever spam classifiers are available and run inference.

    Each of ``text_model``, ``image_model`` and ``fusion_model`` is ``None``
    unless a checkpoint path was given and exists on disk. Every ``predict_*``
    method returns ``None`` when its model is missing, otherwise a dict with
    ``prediction`` ('SPAM' or 'LEGITIMATE') and ``confidence``,
    ``spam_probability`` and ``ham_probability`` expressed as percentages.
    """

    def __init__(self, text_model_path=None, image_model_path=None, fusion_model_path=None):
        self.device = config.DEVICE
        self.tokenizer = DebertaV2Tokenizer.from_pretrained(config.TEXT_MODEL)
        self.image_processor = ViTImageProcessor.from_pretrained(config.IMAGE_MODEL)

        self.text_model = self._load_model(TextSpamClassifier, text_model_path, "text")
        self.image_model = self._load_model(ImageSpamClassifier, image_model_path, "image")
        self.fusion_model = self._load_model(FusionSpamClassifier, fusion_model_path, "fusion")

    def _load_model(self, model_cls, checkpoint_path, label):
        """Build *model_cls*, load its checkpoint and switch it to eval mode.

        Returns ``None`` when no path was given or the file does not exist.
        """
        if not checkpoint_path or not os.path.exists(checkpoint_path):
            return None

        print(f"Loading {label} model from {checkpoint_path}...")
        model = model_cls().to(self.device)
        # NOTE(security): torch.load may unpickle arbitrary objects from the
        # file; only point the *_MODEL_PATH settings at trusted checkpoints.
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.eval()
        print(f"{label.capitalize()} model loaded successfully!")
        return model

    def _encode_text(self, text):
        """Tokenise *text* and return ``(input_ids, attention_mask)`` on the device."""
        encoding = self.tokenizer(
            preprocess_text(text),
            add_special_tokens=True,
            max_length=config.MAX_TEXT_LENGTH,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )
        input_ids = encoding['input_ids'].to(self.device)
        attention_mask = encoding['attention_mask'].to(self.device)
        return input_ids, attention_mask

    def _encode_image(self, image):
        """Run the image processor on *image* and return pixel values on the device."""
        # Images extracted from PDFs can be palette or grayscale; convert so
        # the processor always receives three channels. Inputs without a
        # ``mode`` attribute are passed through unchanged.
        if getattr(image, "mode", "RGB") != "RGB":
            image = image.convert("RGB")
        inputs = self.image_processor(images=image, return_tensors='pt')
        return inputs['pixel_values'].to(self.device)

    @staticmethod
    def _summarize(logits):
        """Convert logits for a single sample into the public result dict."""
        probs = torch.softmax(logits, dim=1)
        predicted = int(torch.argmax(probs, dim=1).item())
        return {
            'prediction': 'SPAM' if predicted == 1 else 'LEGITIMATE',
            'confidence': float(probs[0, predicted].item() * 100),
            'spam_probability': float(probs[0, 1].item() * 100),
            'ham_probability': float(probs[0, 0].item() * 100)
        }

    def predict_text(self, text):
        """Classify *text* with the text-only model."""
        if self.text_model is None:
            return None

        input_ids, attention_mask = self._encode_text(text)
        with torch.no_grad():
            logits = self.text_model(input_ids, attention_mask)
        return self._summarize(logits)

    def predict_image(self, image):
        """Classify *image* with the image-only model.

        Any failure during preprocessing or inference is reported as
        ``{'error': <message>}`` instead of being raised.
        """
        if self.image_model is None or image is None:
            return None

        try:
            pixel_values = self._encode_image(image)
            with torch.no_grad():
                logits = self.image_model(pixel_values)
            return self._summarize(logits)
        except Exception as e:
            return {'error': str(e)}

    def predict_fusion(self, text, image=None):
        """Classify with the fusion model, using *image* when it is usable.

        If the image cannot be preprocessed the prediction falls back to the
        text alone; the failure is printed rather than silently discarded.
        """
        if self.fusion_model is None:
            return None

        input_ids, attention_mask = self._encode_text(text)

        pixel_values = None
        if image is not None:
            try:
                pixel_values = self._encode_image(image)
            except Exception as e:
                print(f"Image preprocessing failed, using text only: {e}")

        with torch.no_grad():
            logits = self.fusion_model(input_ids, attention_mask, pixel_values)
        return self._summarize(logits)
405
+
406
+
407
+ # ================================
408
+ # FASTAPI APPLICATION
409
+ # ================================
410
+
411
app = FastAPI(
    title="FYP4 Spam Detection API",
    description="Email spam detection using DeBERTa and ViT models",
    version="1.0.0"
)

# Add CORS middleware.
# Credentials are disabled because they cannot be combined with wildcard
# origins: the CORS rules require explicit origins whenever credentials are
# allowed. Nothing in this file reads cookies or authorization headers. To
# support credentialed requests, list the allowed origins explicitly and
# re-enable the flag.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Initialize detector (models will be loaded on startup)
detector = None
428
+
429
+
430
@app.on_event("startup")
async def startup_event():
    """Create the global ``SpamDetector`` from the configured checkpoints.

    Checkpoint locations come from the TEXT_MODEL_PATH, IMAGE_MODEL_PATH and
    FUSION_MODEL_PATH environment variables, each with a default under
    ``models/``. A model whose file does not exist is passed as ``None``.
    """
    global detector

    settings = (
        ('text', "TEXT_MODEL_PATH", "models/text_model.pth"),
        ('image', "IMAGE_MODEL_PATH", "models/image_model.pth"),
        ('fusion', "FUSION_MODEL_PATH", "models/fusion_model.pth"),
    )
    paths = {name: os.getenv(variable, default) for name, variable, default in settings}
    available = {name: os.path.exists(path) for name, path in paths.items()}

    print(
        f"Models availability: Text={available['text']}, "
        f"Image={available['image']}, Fusion={available['fusion']}"
    )

    # Only hand over paths that actually exist.
    usable = {name: (paths[name] if available[name] else None) for name in paths}
    detector = SpamDetector(
        text_model_path=usable['text'],
        image_model_path=usable['image'],
        fusion_model_path=usable['fusion'],
    )

    print("API ready!")
453
+
454
+
455
+ # Pydantic models for request/response
456
class TextRequest(BaseModel):
    """Request body for ``POST /predict/text``."""

    # Raw text to classify.
    text: str
458
+
459
+
460
class PredictionResponse(BaseModel):
    """Response body for a single-model prediction."""

    # 'SPAM' or 'LEGITIMATE'.
    prediction: str
    # Probability of the predicted class, as a percentage.
    confidence: float
    # Per-class probabilities, as percentages.
    spam_probability: float
    ham_probability: float
    # Name of the model that produced the result (e.g. 'text').
    model_used: str
466
+
467
+
468
class PDFPredictionResponse(BaseModel):
    """Response body for ``POST /predict/pdf``."""

    # Fields parsed from the PDF: subject, sender, body, full_text.
    email_data: dict
    # Per-model outputs; None when that model did not run.
    text_result: Optional[dict]
    image_result: Optional[dict]
    fusion_result: Optional[dict]
    # Verdict and confidence copied from the highest-priority result.
    final_prediction: str
    final_confidence: float
475
+
476
+
477
@app.get("/")
async def root():
    """Describe the API: name, version and the available endpoints."""
    endpoints = {
        "POST /predict/text": "Predict spam from text",
        "POST /predict/pdf": "Predict spam from PDF email",
        "GET /health": "Health check",
    }
    info = {
        "message": "FYP4 Spam Detection API",
        "version": "1.0.0",
        "endpoints": endpoints,
    }
    return info
489
+
490
+
491
@app.get("/health")
async def health_check():
    """Report the compute device and which models are currently loaded."""
    models_loaded = {}
    for name in ("text", "image", "fusion"):
        if detector:
            models_loaded[name] = getattr(detector, f"{name}_model") is not None
        else:
            # Detector not created yet (startup has not run).
            models_loaded[name] = False

    return {
        "status": "healthy",
        "device": str(config.DEVICE),
        "models_loaded": models_loaded,
    }
503
+
504
+
505
@app.post("/predict/text", response_model=PredictionResponse)
async def predict_text(request: TextRequest):
    """Classify the submitted text with the text-only model.

    Raises:
        HTTPException: status 503 when the detector or its text model is
            not available.
    """
    if not (detector and detector.text_model):
        raise HTTPException(status_code=503, detail="Text model not available")

    prediction = detector.predict_text(request.text)
    # Tag the result so the response states which model produced it.
    prediction['model_used'] = 'text'
    return prediction
515
+
516
+
517
@app.post("/predict/pdf", response_model=PDFPredictionResponse)
async def predict_pdf(file: UploadFile = File(...)):
    """Predict spam from an email uploaded as a PDF.

    Text and the first decodable image are extracted from the PDF and passed
    to every loaded model. The final verdict comes from the first usable
    result in priority order: fusion, then text, then image.

    Raises:
        HTTPException: 400 when the upload is not named ``*.pdf`` (or when
            text extraction fails), 503 when the detector is not ready or no
            model produced a usable prediction.
    """
    # The filename may be missing, and the extension check must not depend
    # on letter case ("REPORT.PDF" is still a PDF).
    filename = file.filename or ""
    if not filename.lower().endswith('.pdf'):
        raise HTTPException(status_code=400, detail="File must be a PDF")

    if not detector:
        raise HTTPException(status_code=503, detail="Models not loaded")

    # Read PDF
    pdf_bytes = await file.read()

    # Extract text and images
    email_data = PDFExtractor.extract_text_from_pdf(pdf_bytes)
    full_text = f"{email_data['subject']}\n\n{email_data['body']}"
    image = PDFExtractor.extract_images_from_pdf(pdf_bytes)

    # Get predictions
    results = {
        'email_data': email_data,
        'text_result': None,
        'image_result': None,
        'fusion_result': None
    }

    if detector.text_model is not None:
        results['text_result'] = detector.predict_text(full_text)

    if detector.image_model is not None and image is not None:
        results['image_result'] = detector.predict_image(image)

    if detector.fusion_model is not None:
        results['fusion_result'] = detector.predict_fusion(full_text, image)

    # Determine final prediction (prioritize: fusion > text > image).
    # A result may be an error report such as {'error': ...}; those carry no
    # 'prediction' key and must not be chosen as the final verdict.
    candidates = (
        results['fusion_result'],
        results['text_result'],
        results['image_result'],
    )
    final_result = next(
        (entry for entry in candidates if entry and 'prediction' in entry),
        None,
    )

    if final_result is None:
        raise HTTPException(status_code=503, detail="No models available for prediction")

    results['final_prediction'] = final_result['prediction']
    results['final_confidence'] = final_result['confidence']

    return results
561
+
562
+
563
if __name__ == "__main__":
    # uvicorn is imported here so it is only required when the file is run
    # directly as a script.
    import uvicorn

    server_options = {"host": "0.0.0.0", "port": 7860}
    uvicorn.run(app, **server_options)