Neel2601 committed on
Commit
68c1f1a
·
verified ·
1 Parent(s): 0eab43d

Upload api_complete.py

Browse files
Files changed (1) hide show
  1. api_complete.py +116 -222
api_complete.py CHANGED
@@ -494,223 +494,75 @@ async def load_all_models():
494
 
495
  # Analyze the structure to understand the architecture
496
  if isinstance(crop_model_data, dict):
497
- logger.info(f"Crop model keys: {list(crop_model_data.keys())[:10]}...") # Show first 10 keys
 
498
 
499
- # Show ALL classifier keys to understand the structure
500
- classifier_keys = [k for k in crop_model_data.keys() if 'classifier' in k]
501
- logger.info(f"πŸ“‹ Classifier keys in checkpoint: {classifier_keys}")
502
- logger.info(f"πŸ“‹ Total checkpoint keys: {len(crop_model_data.keys())}")
503
-
504
- # Create the EXACT architecture matching checkpoint
505
- class CustomEfficientNet(torch.nn.Module):
506
- def __init__(self, num_classes):
507
- super().__init__()
508
- # EXACT architecture matching your state dict keys
509
- self.features = torch.nn.Sequential(
510
- torch.nn.Conv2d(3, 32, 3, padding=1), # features.0
511
- torch.nn.BatchNorm2d(32), # features.1
512
- torch.nn.ReLU(inplace=True), # features.2
513
- torch.nn.MaxPool2d(2), # features.3
514
- torch.nn.Conv2d(32, 64, 3, padding=1), # features.4
515
- torch.nn.BatchNorm2d(64), # features.5
516
- torch.nn.ReLU(inplace=True), # features.6
517
- torch.nn.MaxPool2d(2), # features.7
518
- torch.nn.Conv2d(64, 128, 3, padding=1), # features.8
519
- torch.nn.BatchNorm2d(128), # features.9
520
- torch.nn.ReLU(inplace=True), # features.10
521
- torch.nn.MaxPool2d(2), # features.11
522
- torch.nn.Conv2d(128, 256, 3, padding=1), # features.12
523
- torch.nn.BatchNorm2d(256), # features.13
524
- torch.nn.ReLU(inplace=True), # features.14
525
- torch.nn.MaxPool2d(2), # features.15
526
- torch.nn.Conv2d(256, 512, 3, padding=1), # features.16
527
- torch.nn.BatchNorm2d(512), # features.17
528
- torch.nn.ReLU(inplace=True), # features.18
529
- )
530
-
531
- # Classifier matching checkpoint EXACTLY
532
- # Checkpoint has Linear layers at indices 3 and 6
533
- # ['classifier.3.weight', 'classifier.3.bias', 'classifier.6.weight', 'classifier.6.bias']
534
- self.classifier = torch.nn.Sequential(
535
- torch.nn.Dropout(0.5), # classifier.0 (no weights)
536
- torch.nn.ReLU(inplace=True), # classifier.1 (no weights)
537
- torch.nn.Dropout(0.5), # classifier.2 (no weights)
538
- torch.nn.Linear(25088, 1024), # classifier.3 - HAS WEIGHTS βœ…
539
- torch.nn.ReLU(inplace=True), # classifier.4 (no weights)
540
- torch.nn.Dropout(0.5), # classifier.5 (no weights)
541
- torch.nn.Linear(1024, num_classes) # classifier.6 - HAS WEIGHTS βœ…
542
- )
543
-
544
- def forward(self, x):
545
- x = self.features(x)
546
- x = torch.flatten(x, 1)
547
- x = self.classifier(x)
548
- return x
549
-
550
- # Create models with correct number of classes
551
- num_crop_classes = len(crop_classes['classes'])
552
- num_disease_classes = len(disease_classes['classes'])
553
-
554
- logger.info(f"🎯 Creating crop model with {num_crop_classes} classes")
555
- logger.info(f"🎯 Creating disease model with {num_disease_classes} classes")
556
-
557
- crop_model = CustomEfficientNet(num_crop_classes)
558
- disease_model = CustomEfficientNet(num_disease_classes)
559
-
560
- # Load state dicts with strict=False to handle any mismatches
561
- logger.info("πŸ“Š Loading crop model state_dict...")
562
- crop_result = crop_model.load_state_dict(crop_model_data, strict=False)
563
- crop_missing = list(crop_result.missing_keys) if hasattr(crop_result, 'missing_keys') else []
564
- crop_unexpected = list(crop_result.unexpected_keys) if hasattr(crop_result, 'unexpected_keys') else []
565
-
566
- if crop_unexpected:
567
- logger.info(f"πŸ“Š Crop model unexpected keys: {crop_unexpected[:5]}...")
568
-
569
- logger.info("πŸ“Š Loading disease model state_dict...")
570
- disease_result = disease_model.load_state_dict(disease_model_data, strict=False)
571
- disease_missing = list(disease_result.missing_keys) if hasattr(disease_result, 'missing_keys') else []
572
- disease_unexpected = list(disease_result.unexpected_keys) if hasattr(disease_result, 'unexpected_keys') else []
573
-
574
- # Filter out non-critical missing keys
575
- non_critical_patterns = ['num_batches_tracked', 'running_mean', 'running_var']
576
-
577
- def filter_critical_keys(missing_keys):
578
- if not missing_keys:
579
- return []
580
- return [k for k in missing_keys if not any(pattern in k for pattern in non_critical_patterns)]
581
-
582
- crop_critical = filter_critical_keys(crop_missing)
583
- disease_critical = filter_critical_keys(disease_missing)
584
-
585
- # Log missing keys
586
- if crop_missing:
587
- logger.info(f"πŸ“Š Crop model missing keys: {len(crop_missing)} (filtering non-critical...)")
588
- if disease_missing:
589
- logger.info(f"πŸ“Š Disease model missing keys: {len(disease_missing)} (filtering non-critical...)")
590
-
591
- # AUTO-FIX: Rebuild classifier if critical keys are missing
592
- def fix_missing_classifier(model, missing_keys, model_name, num_classes):
593
- """Automatically fix missing classifier keys by rebuilding the layer"""
594
- classifier_missing = [k for k in missing_keys if 'classifier' in k]
595
-
596
- if classifier_missing:
597
- logger.warning(f"⚠️ {model_name} missing classifier keys: {classifier_missing}")
598
- logger.info(f"πŸ”§ Auto-fixing: Rebuilding classifier for {num_classes} classes...")
599
-
600
- # Rebuild classifier matching checkpoint architecture EXACTLY
601
- # Checkpoint has Linear layers at indices 3 and 6
602
- model.classifier = torch.nn.Sequential(
603
- torch.nn.Dropout(0.5), # classifier.0 (no weights)
604
- torch.nn.ReLU(inplace=True), # classifier.1 (no weights)
605
- torch.nn.Dropout(0.5), # classifier.2 (no weights)
606
- torch.nn.Linear(25088, 1024), # classifier.3 - HAS WEIGHTS βœ…
607
- torch.nn.ReLU(inplace=True), # classifier.4 (no weights)
608
- torch.nn.Dropout(0.5), # classifier.5 (no weights)
609
- torch.nn.Linear(1024, num_classes) # classifier.6 - HAS WEIGHTS βœ…
610
- )
611
 
612
- logger.info(f"βœ… {model_name} classifier rebuilt with output: {num_classes} classes")
 
 
 
 
613
 
614
- # Try to reload state_dict after fixing
 
 
 
 
 
 
615
  try:
616
- model.load_state_dict(crop_model_data if 'crop' in model_name.lower() else disease_model_data, strict=False)
617
- logger.info(f"βœ… {model_name} state_dict reloaded after classifier fix")
618
- except Exception as reload_error:
619
- logger.warning(f"⚠️ Could not reload state_dict: {reload_error}")
620
-
621
- return model
 
 
 
 
622
 
623
- # Fix missing classifier keys automatically
624
- if crop_critical:
625
- crop_model = fix_missing_classifier(crop_model, crop_critical, "Crop model", num_crop_classes)
626
- # Re-check for remaining critical keys
627
- crop_critical_remaining = [k for k in crop_critical if 'classifier' not in k]
628
- if crop_critical_remaining:
629
- logger.warning(f"⚠️ Crop model still missing keys: {crop_critical_remaining}")
630
 
631
- if disease_critical:
632
- disease_model = fix_missing_classifier(disease_model, disease_critical, "Disease model", num_disease_classes)
633
- # Re-check for remaining critical keys
634
- disease_critical_remaining = [k for k in disease_critical if 'classifier' not in k]
635
- if disease_critical_remaining:
636
- logger.warning(f"⚠️ Disease model still missing keys: {disease_critical_remaining}")
637
 
638
- # Move models to device with optimal dtype
639
  vision_dtype = torch.float16 if vision_device == "cuda" else torch.float32
640
- logger.info(f"πŸ“Š Vision models dtype: {vision_dtype}")
641
-
642
  crop_model.to(vision_device, dtype=vision_dtype)
643
  disease_model.to(vision_device, dtype=vision_dtype)
644
 
645
- # Set to evaluation mode
646
  crop_model.eval()
647
  disease_model.eval()
648
-
649
- logger.info(f"βœ… Crop model moved to {vision_device} ({vision_dtype})")
650
- logger.info(f"βœ… Disease model moved to {vision_device} ({vision_dtype})")
651
-
652
- # Log final classifier shapes
653
- crop_classifier_shape = list(crop_model.classifier[-1].weight.shape)
654
- disease_classifier_shape = list(disease_model.classifier[-1].weight.shape)
655
- logger.info(f"πŸ“Š Crop classifier output shape: {crop_classifier_shape}")
656
- logger.info(f"πŸ“Š Disease classifier output shape: {disease_classifier_shape}")
657
-
658
  else:
659
- # If they're complete model objects, use them directly
660
  crop_model = crop_model_data
661
  disease_model = disease_model_data
662
- if hasattr(crop_model, 'eval'):
663
- crop_model.eval()
664
- if hasattr(disease_model, 'eval'):
665
- disease_model.eval()
 
 
666
 
667
- logger.info(f"βœ… Vision models loaded successfully on {vision_device}!")
668
- logger.info(f"🎯 Crop: {len(crop_classes.get('classes', []))} classes")
669
- logger.info(f"🎯 Disease: {len(disease_classes.get('classes', []))} classes")
670
-
671
  except Exception as e:
672
  logger.error(f"❌ Failed to load vision models: {e}")
673
-
674
- # Try CPU fallback if GPU failed
675
- if vision_device == "cuda":
676
- try:
677
- logger.warning("⚠️ GPU loading failed, trying CPU fallback...")
678
- vision_device = "cpu"
679
-
680
- # Retry loading on CPU
681
- crop_model_data = torch.load(crop_path / "best_model.pth", map_location='cpu')
682
- disease_model_data = torch.load(disease_path / "best_model.pth", map_location='cpu')
683
-
684
- crop_model = CustomEfficientNet(len(crop_classes['classes']))
685
- disease_model = CustomEfficientNet(len(disease_classes['classes']))
686
-
687
- crop_model.load_state_dict(crop_model_data, strict=False)
688
- disease_model.load_state_dict(disease_model_data, strict=False)
689
-
690
- crop_model.to(vision_device, dtype=torch.float32)
691
- disease_model.to(vision_device, dtype=torch.float32)
692
- crop_model.eval()
693
- disease_model.eval()
694
-
695
- logger.info(f"βœ… Vision models loaded on CPU fallback")
696
- except Exception as fallback_error:
697
- logger.error(f"❌ CPU fallback also failed: {fallback_error}")
698
- import traceback
699
- logger.error(f"Full error: {traceback.format_exc()}")
700
- # Create dummy models for testing
701
- crop_model = None
702
- disease_model = None
703
- crop_classes = {'classes': ['tomato', 'potato', 'wheat', 'rice']}
704
- disease_classes = {'classes': ['healthy', 'early_blight', 'late_blight', 'leaf_spot']}
705
- else:
706
- import traceback
707
- logger.error(f"Full error: {traceback.format_exc()}")
708
- # Create dummy models for testing
709
- crop_model = None
710
- disease_model = None
711
- crop_classes = {'classes': ['tomato', 'potato', 'wheat', 'rice']}
712
- disease_classes = {'classes': ['healthy', 'early_blight', 'late_blight', 'leaf_spot']}
713
-
714
  # Define image transforms - match training resolution
715
  # 100352 / 512 = 196, so sqrt(196) = 14, meaning 14x14 feature map
716
  # This suggests input should be smaller to get 25088 features
@@ -752,7 +604,8 @@ async def load_all_models():
752
 
753
  # Load model with joblib
754
  market_model = joblib.load(model_path)
755
- logger.info(f"Loaded {model_path.name} with joblib")
 
756
 
757
  # Load encoders with joblib
758
  encoders = joblib.load(encoders_path)
@@ -1355,7 +1208,7 @@ async def agricultural_chat(
1355
  raise HTTPException(status_code=500, detail=str(e))
1356
 
1357
  @app.post("/image-diagnosis")
1358
- async def image_diagnosis(image_file: UploadFile = File(...)):
1359
  """Diagnose crop diseases from images"""
1360
  try:
1361
  if 'vision' not in models or models['vision'] is None:
@@ -1412,12 +1265,27 @@ async def image_diagnosis(image_file: UploadFile = File(...)):
1412
 
1413
  with torch.no_grad():
1414
  outputs = disease_model(image_tensor)
1415
- probabilities = torch.nn.functional.softmax(outputs, dim=1)
1416
- predicted_class = torch.argmax(probabilities, dim=1).item()
1417
- confidence = probabilities[0][predicted_class].item()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1418
 
1419
- disease_name = disease_classes['classes'][predicted_class]
1420
  logger.info(f"Disease prediction: {disease_name} (confidence: {confidence:.3f})")
 
1421
  except Exception as model_error:
1422
  logger.error(f"Model inference error: {model_error}")
1423
  # Fallback when model inference fails
@@ -1431,32 +1299,58 @@ async def image_diagnosis(image_file: UploadFile = File(...)):
1431
  confidence = random.uniform(0.7, 0.95)
1432
 
1433
  # Generate treatment recommendations
1434
- treatments = {
1435
- "Early_Blight": [
1436
- "Remove infected leaves immediately",
1437
- "Apply Mancozeb fungicide (2.5g/L water)",
1438
- "Spray Trichoderma solution weekly",
1439
- "Improve air circulation"
1440
- ],
1441
- "Late_Blight": [
1442
- "Apply Copper oxychloride spray",
1443
- "Remove infected plants",
1444
- "Avoid overhead watering",
1445
- "Use resistant varieties"
1446
- ]
1447
- }
1448
 
1449
- treatment = treatments.get(disease_name, [
 
1450
  "Consult agricultural expert",
1451
  "Apply appropriate fungicide",
1452
  "Maintain proper plant hygiene"
1453
- ])
1454
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1455
  return {
1456
  "disease": disease_name,
1457
  "confidence": round(confidence * 100, 2),
1458
  "treatment": treatment,
1459
- "cause": "Fungal infection caused by environmental conditions"
 
 
1460
  }
1461
 
1462
  except Exception as e:
@@ -1691,5 +1585,5 @@ async def text_to_speech(
1691
 
1692
  if __name__ == "__main__":
1693
  import uvicorn
1694
- uvicorn.run(app, host="0.0.0.0", port=7860)
1695
 
 
494
 
495
  # Analyze the structure to understand the architecture
496
  if isinstance(crop_model_data, dict):
497
+ # Use standard EfficientNet-B0 from torchvision to match training
498
+ from torchvision import models as tv_models
499
 
500
+ def load_efficientnet(state_dict, num_classes, model_name):
501
+ try:
502
+ logger.info(f"πŸ—οΈ Building EfficientNet-B0 for {model_name}...")
503
+ # 1. Init standard model
504
+ model = tv_models.efficientnet_b0(weights=None)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
505
 
506
+ # 2. Modify classifier to match num_classes (1280 -> num_classes)
507
+ # EfficientNet-B0 classifier is: Sequential(Dropout, Linear(1280, 1000))
508
+ # We need to change the Linear layer at index 1
509
+ in_features = model.classifier[1].in_features
510
+ model.classifier[1] = torch.nn.Linear(in_features, num_classes)
511
 
512
+ # 3. Load weights
513
+ msg = model.load_state_dict(state_dict, strict=True)
514
+ logger.info(f"βœ… {model_name} loaded successfully (Strict=True)")
515
+ return model
516
+ except Exception as e:
517
+ logger.warning(f"⚠️ Strict loading failed for {model_name}: {e}")
518
+ logger.info("πŸ”„ Retrying with strict=False...")
519
  try:
520
+ model.load_state_dict(state_dict, strict=False)
521
+ logger.info(f"βœ… {model_name} loaded (Strict=False)")
522
+ return model
523
+ except Exception as e2:
524
+ logger.error(f"❌ Failed to load {model_name}: {e2}")
525
+ raise e2
526
+
527
+ # Create and load models
528
+ num_crop_classes = len(crop_classes['classes'])
529
+ num_disease_classes = len(disease_classes['classes'])
530
 
531
+ logger.info(f"🎯 Loading Crop Model ({num_crop_classes} classes)...")
532
+ crop_model = load_efficientnet(crop_model_data, num_crop_classes, "Crop Model")
 
 
 
 
 
533
 
534
+ logger.info(f"🎯 Loading Disease Model ({num_disease_classes} classes)...")
535
+ disease_model = load_efficientnet(disease_model_data, num_disease_classes, "Disease Model")
 
 
 
 
536
 
537
+ # Move to device
538
  vision_dtype = torch.float16 if vision_device == "cuda" else torch.float32
 
 
539
  crop_model.to(vision_device, dtype=vision_dtype)
540
  disease_model.to(vision_device, dtype=vision_dtype)
541
 
 
542
  crop_model.eval()
543
  disease_model.eval()
544
+
 
 
 
 
 
 
 
 
 
545
  else:
546
+ # If they are already model objects (legacy support)
547
  crop_model = crop_model_data
548
  disease_model = disease_model_data
549
+ crop_model.to(vision_device)
550
+ disease_model.to(vision_device)
551
+ crop_model.eval()
552
+ disease_model.eval()
553
+
554
+ logger.info(f"βœ… Vision models loaded successfully on {vision_device}!")
555
 
 
 
 
 
556
  except Exception as e:
557
  logger.error(f"❌ Failed to load vision models: {e}")
558
+ import traceback
559
+ logger.error(f"Full error: {traceback.format_exc()}")
560
+ # Create dummy models for testing
561
+ crop_model = None
562
+ disease_model = None
563
+ crop_classes = {'classes': ['tomato', 'potato', 'wheat', 'rice']}
564
+ disease_classes = {'classes': ['healthy', 'early_blight', 'late_blight', 'leaf_spot']}
565
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
566
  # Define image transforms - match training resolution
567
  # 100352 / 512 = 196, so sqrt(196) = 14, meaning 14x14 feature map
568
  # This suggests input should be smaller to get 25088 features
 
604
 
605
  # Load model with joblib
606
  market_model = joblib.load(model_path)
607
+ from pathlib import Path
608
+ logger.info(f"Loaded {Path(model_path).name} with joblib")
609
 
610
  # Load encoders with joblib
611
  encoders = joblib.load(encoders_path)
 
1208
  raise HTTPException(status_code=500, detail=str(e))
1209
 
1210
  @app.post("/image-diagnosis")
1211
+ async def image_diagnosis(image_file: UploadFile = File(...), language: str = Form("en")):
1212
  """Diagnose crop diseases from images"""
1213
  try:
1214
  if 'vision' not in models or models['vision'] is None:
 
1265
 
1266
  with torch.no_grad():
1267
  outputs = disease_model(image_tensor)
1268
+ # Get Top 3
1269
+ probs = torch.nn.functional.softmax(outputs, dim=1)
1270
+ top3_prob, top3_idx = torch.topk(probs, min(3, len(disease_classes['classes']))) # Ensure we don't ask for more than available classes
1271
+
1272
+ # Primary prediction (Top 1)
1273
+ top1_idx = top3_idx[0][0].item()
1274
+ confidence = top3_prob[0][0].item()
1275
+ disease_name = disease_classes['classes'][top1_idx]
1276
+
1277
+ # Top 3 List
1278
+ predictions = []
1279
+ for i in range(top3_idx.shape[1]):
1280
+ idx = top3_idx[0][i].item()
1281
+ prob = top3_prob[0][i].item()
1282
+ predictions.append({
1283
+ "disease": disease_classes['classes'][idx],
1284
+ "confidence": round(prob * 100, 2)
1285
+ })
1286
 
 
1287
  logger.info(f"Disease prediction: {disease_name} (confidence: {confidence:.3f})")
1288
+ logger.info(f"Top 3 predictions: {predictions}")
1289
  except Exception as model_error:
1290
  logger.error(f"Model inference error: {model_error}")
1291
  # Fallback when model inference fails
 
1299
  confidence = random.uniform(0.7, 0.95)
1300
 
1301
  # Generate treatment recommendations
1302
+ try:
1303
+ import json
1304
+ info_path = Path(__file__).parent / "disease_info.json"
1305
+ if info_path.exists():
1306
+ with open(info_path, 'r') as f:
1307
+ disease_info = json.load(f)
1308
+ else:
1309
+ disease_info = {}
1310
+ except Exception:
1311
+ disease_info = {}
1312
+
1313
+ info = disease_info.get(disease_name, {})
 
 
1314
 
1315
+ # Default fallback if disease not in JSON
1316
+ default_treatment = [
1317
  "Consult agricultural expert",
1318
  "Apply appropriate fungicide",
1319
  "Maintain proper plant hygiene"
1320
+ ]
1321
 
1322
+ treatment = info.get("treatment", default_treatment)
1323
+ cause = info.get("cause", "Fungal or bacterial infection caused by environmental conditions")
1324
+ prevention = info.get("prevention", "Use resistant varieties and practice crop rotation")
1325
+
1326
+ # --- TRANSLATION LOGIC ---
1327
+ if language != "en":
1328
+ try:
1329
+ from deep_translator import GoogleTranslator
1330
+ translator = GoogleTranslator(source='auto', target=language)
1331
+
1332
+ # Translate Cause
1333
+ cause = translator.translate(cause)
1334
+
1335
+ # Translate Prevention
1336
+ prevention = translator.translate(prevention)
1337
+
1338
+ # Translate Treatment List
1339
+ translated_treatments = []
1340
+ for t in treatment:
1341
+ translated_treatments.append(translator.translate(t))
1342
+ treatment = translated_treatments
1343
+
1344
+ except Exception as trans_e:
1345
+ logger.error(f"Translation failed: {trans_e}")
1346
+
1347
  return {
1348
  "disease": disease_name,
1349
  "confidence": round(confidence * 100, 2),
1350
  "treatment": treatment,
1351
+ "cause": cause,
1352
+ "prevention": prevention,
1353
+ "top_3_predictions": predictions
1354
  }
1355
 
1356
  except Exception as e:
 
1585
 
1586
  if __name__ == "__main__":
1587
  import uvicorn
1588
+ uvicorn.run(app, host="0.0.0.0", port=8000)
1589