Spaces:
Running
Running
Upload api_complete.py
Browse files- api_complete.py +116 -222
api_complete.py
CHANGED
|
@@ -494,223 +494,75 @@ async def load_all_models():
|
|
| 494 |
|
| 495 |
# Analyze the structure to understand the architecture
|
| 496 |
if isinstance(crop_model_data, dict):
|
| 497 |
-
|
|
|
|
| 498 |
|
| 499 |
-
|
| 500 |
-
|
| 501 |
-
|
| 502 |
-
|
| 503 |
-
|
| 504 |
-
# Create the EXACT architecture matching checkpoint
|
| 505 |
-
class CustomEfficientNet(torch.nn.Module):
|
| 506 |
-
def __init__(self, num_classes):
|
| 507 |
-
super().__init__()
|
| 508 |
-
# EXACT architecture matching your state dict keys
|
| 509 |
-
self.features = torch.nn.Sequential(
|
| 510 |
-
torch.nn.Conv2d(3, 32, 3, padding=1), # features.0
|
| 511 |
-
torch.nn.BatchNorm2d(32), # features.1
|
| 512 |
-
torch.nn.ReLU(inplace=True), # features.2
|
| 513 |
-
torch.nn.MaxPool2d(2), # features.3
|
| 514 |
-
torch.nn.Conv2d(32, 64, 3, padding=1), # features.4
|
| 515 |
-
torch.nn.BatchNorm2d(64), # features.5
|
| 516 |
-
torch.nn.ReLU(inplace=True), # features.6
|
| 517 |
-
torch.nn.MaxPool2d(2), # features.7
|
| 518 |
-
torch.nn.Conv2d(64, 128, 3, padding=1), # features.8
|
| 519 |
-
torch.nn.BatchNorm2d(128), # features.9
|
| 520 |
-
torch.nn.ReLU(inplace=True), # features.10
|
| 521 |
-
torch.nn.MaxPool2d(2), # features.11
|
| 522 |
-
torch.nn.Conv2d(128, 256, 3, padding=1), # features.12
|
| 523 |
-
torch.nn.BatchNorm2d(256), # features.13
|
| 524 |
-
torch.nn.ReLU(inplace=True), # features.14
|
| 525 |
-
torch.nn.MaxPool2d(2), # features.15
|
| 526 |
-
torch.nn.Conv2d(256, 512, 3, padding=1), # features.16
|
| 527 |
-
torch.nn.BatchNorm2d(512), # features.17
|
| 528 |
-
torch.nn.ReLU(inplace=True), # features.18
|
| 529 |
-
)
|
| 530 |
-
|
| 531 |
-
# Classifier matching checkpoint EXACTLY
|
| 532 |
-
# Checkpoint has Linear layers at indices 3 and 6
|
| 533 |
-
# ['classifier.3.weight', 'classifier.3.bias', 'classifier.6.weight', 'classifier.6.bias']
|
| 534 |
-
self.classifier = torch.nn.Sequential(
|
| 535 |
-
torch.nn.Dropout(0.5), # classifier.0 (no weights)
|
| 536 |
-
torch.nn.ReLU(inplace=True), # classifier.1 (no weights)
|
| 537 |
-
torch.nn.Dropout(0.5), # classifier.2 (no weights)
|
| 538 |
-
torch.nn.Linear(25088, 1024), # classifier.3 - HAS WEIGHTS β
|
| 539 |
-
torch.nn.ReLU(inplace=True), # classifier.4 (no weights)
|
| 540 |
-
torch.nn.Dropout(0.5), # classifier.5 (no weights)
|
| 541 |
-
torch.nn.Linear(1024, num_classes) # classifier.6 - HAS WEIGHTS β
|
| 542 |
-
)
|
| 543 |
-
|
| 544 |
-
def forward(self, x):
|
| 545 |
-
x = self.features(x)
|
| 546 |
-
x = torch.flatten(x, 1)
|
| 547 |
-
x = self.classifier(x)
|
| 548 |
-
return x
|
| 549 |
-
|
| 550 |
-
# Create models with correct number of classes
|
| 551 |
-
num_crop_classes = len(crop_classes['classes'])
|
| 552 |
-
num_disease_classes = len(disease_classes['classes'])
|
| 553 |
-
|
| 554 |
-
logger.info(f"π― Creating crop model with {num_crop_classes} classes")
|
| 555 |
-
logger.info(f"π― Creating disease model with {num_disease_classes} classes")
|
| 556 |
-
|
| 557 |
-
crop_model = CustomEfficientNet(num_crop_classes)
|
| 558 |
-
disease_model = CustomEfficientNet(num_disease_classes)
|
| 559 |
-
|
| 560 |
-
# Load state dicts with strict=False to handle any mismatches
|
| 561 |
-
logger.info("π Loading crop model state_dict...")
|
| 562 |
-
crop_result = crop_model.load_state_dict(crop_model_data, strict=False)
|
| 563 |
-
crop_missing = list(crop_result.missing_keys) if hasattr(crop_result, 'missing_keys') else []
|
| 564 |
-
crop_unexpected = list(crop_result.unexpected_keys) if hasattr(crop_result, 'unexpected_keys') else []
|
| 565 |
-
|
| 566 |
-
if crop_unexpected:
|
| 567 |
-
logger.info(f"π Crop model unexpected keys: {crop_unexpected[:5]}...")
|
| 568 |
-
|
| 569 |
-
logger.info("π Loading disease model state_dict...")
|
| 570 |
-
disease_result = disease_model.load_state_dict(disease_model_data, strict=False)
|
| 571 |
-
disease_missing = list(disease_result.missing_keys) if hasattr(disease_result, 'missing_keys') else []
|
| 572 |
-
disease_unexpected = list(disease_result.unexpected_keys) if hasattr(disease_result, 'unexpected_keys') else []
|
| 573 |
-
|
| 574 |
-
# Filter out non-critical missing keys
|
| 575 |
-
non_critical_patterns = ['num_batches_tracked', 'running_mean', 'running_var']
|
| 576 |
-
|
| 577 |
-
def filter_critical_keys(missing_keys):
|
| 578 |
-
if not missing_keys:
|
| 579 |
-
return []
|
| 580 |
-
return [k for k in missing_keys if not any(pattern in k for pattern in non_critical_patterns)]
|
| 581 |
-
|
| 582 |
-
crop_critical = filter_critical_keys(crop_missing)
|
| 583 |
-
disease_critical = filter_critical_keys(disease_missing)
|
| 584 |
-
|
| 585 |
-
# Log missing keys
|
| 586 |
-
if crop_missing:
|
| 587 |
-
logger.info(f"π Crop model missing keys: {len(crop_missing)} (filtering non-critical...)")
|
| 588 |
-
if disease_missing:
|
| 589 |
-
logger.info(f"π Disease model missing keys: {len(disease_missing)} (filtering non-critical...)")
|
| 590 |
-
|
| 591 |
-
# AUTO-FIX: Rebuild classifier if critical keys are missing
|
| 592 |
-
def fix_missing_classifier(model, missing_keys, model_name, num_classes):
|
| 593 |
-
"""Automatically fix missing classifier keys by rebuilding the layer"""
|
| 594 |
-
classifier_missing = [k for k in missing_keys if 'classifier' in k]
|
| 595 |
-
|
| 596 |
-
if classifier_missing:
|
| 597 |
-
logger.warning(f"β οΈ {model_name} missing classifier keys: {classifier_missing}")
|
| 598 |
-
logger.info(f"π§ Auto-fixing: Rebuilding classifier for {num_classes} classes...")
|
| 599 |
-
|
| 600 |
-
# Rebuild classifier matching checkpoint architecture EXACTLY
|
| 601 |
-
# Checkpoint has Linear layers at indices 3 and 6
|
| 602 |
-
model.classifier = torch.nn.Sequential(
|
| 603 |
-
torch.nn.Dropout(0.5), # classifier.0 (no weights)
|
| 604 |
-
torch.nn.ReLU(inplace=True), # classifier.1 (no weights)
|
| 605 |
-
torch.nn.Dropout(0.5), # classifier.2 (no weights)
|
| 606 |
-
torch.nn.Linear(25088, 1024), # classifier.3 - HAS WEIGHTS β
|
| 607 |
-
torch.nn.ReLU(inplace=True), # classifier.4 (no weights)
|
| 608 |
-
torch.nn.Dropout(0.5), # classifier.5 (no weights)
|
| 609 |
-
torch.nn.Linear(1024, num_classes) # classifier.6 - HAS WEIGHTS β
|
| 610 |
-
)
|
| 611 |
|
| 612 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 613 |
|
| 614 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 615 |
try:
|
| 616 |
-
model.load_state_dict(
|
| 617 |
-
logger.info(f"β
{model_name}
|
| 618 |
-
|
| 619 |
-
|
| 620 |
-
|
| 621 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 622 |
|
| 623 |
-
|
| 624 |
-
|
| 625 |
-
crop_model = fix_missing_classifier(crop_model, crop_critical, "Crop model", num_crop_classes)
|
| 626 |
-
# Re-check for remaining critical keys
|
| 627 |
-
crop_critical_remaining = [k for k in crop_critical if 'classifier' not in k]
|
| 628 |
-
if crop_critical_remaining:
|
| 629 |
-
logger.warning(f"β οΈ Crop model still missing keys: {crop_critical_remaining}")
|
| 630 |
|
| 631 |
-
|
| 632 |
-
|
| 633 |
-
# Re-check for remaining critical keys
|
| 634 |
-
disease_critical_remaining = [k for k in disease_critical if 'classifier' not in k]
|
| 635 |
-
if disease_critical_remaining:
|
| 636 |
-
logger.warning(f"β οΈ Disease model still missing keys: {disease_critical_remaining}")
|
| 637 |
|
| 638 |
-
# Move
|
| 639 |
vision_dtype = torch.float16 if vision_device == "cuda" else torch.float32
|
| 640 |
-
logger.info(f"π Vision models dtype: {vision_dtype}")
|
| 641 |
-
|
| 642 |
crop_model.to(vision_device, dtype=vision_dtype)
|
| 643 |
disease_model.to(vision_device, dtype=vision_dtype)
|
| 644 |
|
| 645 |
-
# Set to evaluation mode
|
| 646 |
crop_model.eval()
|
| 647 |
disease_model.eval()
|
| 648 |
-
|
| 649 |
-
logger.info(f"β
Crop model moved to {vision_device} ({vision_dtype})")
|
| 650 |
-
logger.info(f"β
Disease model moved to {vision_device} ({vision_dtype})")
|
| 651 |
-
|
| 652 |
-
# Log final classifier shapes
|
| 653 |
-
crop_classifier_shape = list(crop_model.classifier[-1].weight.shape)
|
| 654 |
-
disease_classifier_shape = list(disease_model.classifier[-1].weight.shape)
|
| 655 |
-
logger.info(f"π Crop classifier output shape: {crop_classifier_shape}")
|
| 656 |
-
logger.info(f"π Disease classifier output shape: {disease_classifier_shape}")
|
| 657 |
-
|
| 658 |
else:
|
| 659 |
-
# If they
|
| 660 |
crop_model = crop_model_data
|
| 661 |
disease_model = disease_model_data
|
| 662 |
-
|
| 663 |
-
|
| 664 |
-
|
| 665 |
-
|
|
|
|
|
|
|
| 666 |
|
| 667 |
-
logger.info(f"β
Vision models loaded successfully on {vision_device}!")
|
| 668 |
-
logger.info(f"π― Crop: {len(crop_classes.get('classes', []))} classes")
|
| 669 |
-
logger.info(f"π― Disease: {len(disease_classes.get('classes', []))} classes")
|
| 670 |
-
|
| 671 |
except Exception as e:
|
| 672 |
logger.error(f"β Failed to load vision models: {e}")
|
| 673 |
-
|
| 674 |
-
|
| 675 |
-
|
| 676 |
-
|
| 677 |
-
|
| 678 |
-
|
| 679 |
-
|
| 680 |
-
|
| 681 |
-
crop_model_data = torch.load(crop_path / "best_model.pth", map_location='cpu')
|
| 682 |
-
disease_model_data = torch.load(disease_path / "best_model.pth", map_location='cpu')
|
| 683 |
-
|
| 684 |
-
crop_model = CustomEfficientNet(len(crop_classes['classes']))
|
| 685 |
-
disease_model = CustomEfficientNet(len(disease_classes['classes']))
|
| 686 |
-
|
| 687 |
-
crop_model.load_state_dict(crop_model_data, strict=False)
|
| 688 |
-
disease_model.load_state_dict(disease_model_data, strict=False)
|
| 689 |
-
|
| 690 |
-
crop_model.to(vision_device, dtype=torch.float32)
|
| 691 |
-
disease_model.to(vision_device, dtype=torch.float32)
|
| 692 |
-
crop_model.eval()
|
| 693 |
-
disease_model.eval()
|
| 694 |
-
|
| 695 |
-
logger.info(f"β
Vision models loaded on CPU fallback")
|
| 696 |
-
except Exception as fallback_error:
|
| 697 |
-
logger.error(f"β CPU fallback also failed: {fallback_error}")
|
| 698 |
-
import traceback
|
| 699 |
-
logger.error(f"Full error: {traceback.format_exc()}")
|
| 700 |
-
# Create dummy models for testing
|
| 701 |
-
crop_model = None
|
| 702 |
-
disease_model = None
|
| 703 |
-
crop_classes = {'classes': ['tomato', 'potato', 'wheat', 'rice']}
|
| 704 |
-
disease_classes = {'classes': ['healthy', 'early_blight', 'late_blight', 'leaf_spot']}
|
| 705 |
-
else:
|
| 706 |
-
import traceback
|
| 707 |
-
logger.error(f"Full error: {traceback.format_exc()}")
|
| 708 |
-
# Create dummy models for testing
|
| 709 |
-
crop_model = None
|
| 710 |
-
disease_model = None
|
| 711 |
-
crop_classes = {'classes': ['tomato', 'potato', 'wheat', 'rice']}
|
| 712 |
-
disease_classes = {'classes': ['healthy', 'early_blight', 'late_blight', 'leaf_spot']}
|
| 713 |
-
|
| 714 |
# Define image transforms - match training resolution
|
| 715 |
# 100352 / 512 = 196, so sqrt(196) = 14, meaning 14x14 feature map
|
| 716 |
# This suggests input should be smaller to get 25088 features
|
|
@@ -752,7 +604,8 @@ async def load_all_models():
|
|
| 752 |
|
| 753 |
# Load model with joblib
|
| 754 |
market_model = joblib.load(model_path)
|
| 755 |
-
|
|
|
|
| 756 |
|
| 757 |
# Load encoders with joblib
|
| 758 |
encoders = joblib.load(encoders_path)
|
|
@@ -1355,7 +1208,7 @@ async def agricultural_chat(
|
|
| 1355 |
raise HTTPException(status_code=500, detail=str(e))
|
| 1356 |
|
| 1357 |
@app.post("/image-diagnosis")
|
| 1358 |
-
async def image_diagnosis(image_file: UploadFile = File(...)):
|
| 1359 |
"""Diagnose crop diseases from images"""
|
| 1360 |
try:
|
| 1361 |
if 'vision' not in models or models['vision'] is None:
|
|
@@ -1412,12 +1265,27 @@ async def image_diagnosis(image_file: UploadFile = File(...)):
|
|
| 1412 |
|
| 1413 |
with torch.no_grad():
|
| 1414 |
outputs = disease_model(image_tensor)
|
| 1415 |
-
|
| 1416 |
-
|
| 1417 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1418 |
|
| 1419 |
-
disease_name = disease_classes['classes'][predicted_class]
|
| 1420 |
logger.info(f"Disease prediction: {disease_name} (confidence: {confidence:.3f})")
|
|
|
|
| 1421 |
except Exception as model_error:
|
| 1422 |
logger.error(f"Model inference error: {model_error}")
|
| 1423 |
# Fallback when model inference fails
|
|
@@ -1431,32 +1299,58 @@ async def image_diagnosis(image_file: UploadFile = File(...)):
|
|
| 1431 |
confidence = random.uniform(0.7, 0.95)
|
| 1432 |
|
| 1433 |
# Generate treatment recommendations
|
| 1434 |
-
|
| 1435 |
-
|
| 1436 |
-
|
| 1437 |
-
|
| 1438 |
-
|
| 1439 |
-
|
| 1440 |
-
|
| 1441 |
-
|
| 1442 |
-
|
| 1443 |
-
|
| 1444 |
-
|
| 1445 |
-
|
| 1446 |
-
]
|
| 1447 |
-
}
|
| 1448 |
|
| 1449 |
-
|
|
|
|
| 1450 |
"Consult agricultural expert",
|
| 1451 |
"Apply appropriate fungicide",
|
| 1452 |
"Maintain proper plant hygiene"
|
| 1453 |
-
]
|
| 1454 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1455 |
return {
|
| 1456 |
"disease": disease_name,
|
| 1457 |
"confidence": round(confidence * 100, 2),
|
| 1458 |
"treatment": treatment,
|
| 1459 |
-
"cause":
|
|
|
|
|
|
|
| 1460 |
}
|
| 1461 |
|
| 1462 |
except Exception as e:
|
|
@@ -1691,5 +1585,5 @@ async def text_to_speech(
|
|
| 1691 |
|
| 1692 |
if __name__ == "__main__":
|
| 1693 |
import uvicorn
|
| 1694 |
-
uvicorn.run(app, host="0.0.0.0", port=
|
| 1695 |
|
|
|
|
| 494 |
|
| 495 |
# Analyze the structure to understand the architecture
|
| 496 |
if isinstance(crop_model_data, dict):
|
| 497 |
+
# Use standard EfficientNet-B0 from torchvision to match training
|
| 498 |
+
from torchvision import models as tv_models
|
| 499 |
|
| 500 |
+
def load_efficientnet(state_dict, num_classes, model_name):
|
| 501 |
+
try:
|
| 502 |
+
logger.info(f"ποΈ Building EfficientNet-B0 for {model_name}...")
|
| 503 |
+
# 1. Init standard model
|
| 504 |
+
model = tv_models.efficientnet_b0(weights=None)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 505 |
|
| 506 |
+
# 2. Modify classifier to match num_classes (1280 -> num_classes)
|
| 507 |
+
# EfficientNet-B0 classifier is: Sequential(Dropout, Linear(1280, 1000))
|
| 508 |
+
# We need to change the Linear layer at index 1
|
| 509 |
+
in_features = model.classifier[1].in_features
|
| 510 |
+
model.classifier[1] = torch.nn.Linear(in_features, num_classes)
|
| 511 |
|
| 512 |
+
# 3. Load weights
|
| 513 |
+
msg = model.load_state_dict(state_dict, strict=True)
|
| 514 |
+
logger.info(f"β
{model_name} loaded successfully (Strict=True)")
|
| 515 |
+
return model
|
| 516 |
+
except Exception as e:
|
| 517 |
+
logger.warning(f"β οΈ Strict loading failed for {model_name}: {e}")
|
| 518 |
+
logger.info("π Retrying with strict=False...")
|
| 519 |
try:
|
| 520 |
+
model.load_state_dict(state_dict, strict=False)
|
| 521 |
+
logger.info(f"β
{model_name} loaded (Strict=False)")
|
| 522 |
+
return model
|
| 523 |
+
except Exception as e2:
|
| 524 |
+
logger.error(f"β Failed to load {model_name}: {e2}")
|
| 525 |
+
raise e2
|
| 526 |
+
|
| 527 |
+
# Create and load models
|
| 528 |
+
num_crop_classes = len(crop_classes['classes'])
|
| 529 |
+
num_disease_classes = len(disease_classes['classes'])
|
| 530 |
|
| 531 |
+
logger.info(f"π― Loading Crop Model ({num_crop_classes} classes)...")
|
| 532 |
+
crop_model = load_efficientnet(crop_model_data, num_crop_classes, "Crop Model")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 533 |
|
| 534 |
+
logger.info(f"π― Loading Disease Model ({num_disease_classes} classes)...")
|
| 535 |
+
disease_model = load_efficientnet(disease_model_data, num_disease_classes, "Disease Model")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 536 |
|
| 537 |
+
# Move to device
|
| 538 |
vision_dtype = torch.float16 if vision_device == "cuda" else torch.float32
|
|
|
|
|
|
|
| 539 |
crop_model.to(vision_device, dtype=vision_dtype)
|
| 540 |
disease_model.to(vision_device, dtype=vision_dtype)
|
| 541 |
|
|
|
|
| 542 |
crop_model.eval()
|
| 543 |
disease_model.eval()
|
| 544 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 545 |
else:
|
| 546 |
+
# If they are already model objects (legacy support)
|
| 547 |
crop_model = crop_model_data
|
| 548 |
disease_model = disease_model_data
|
| 549 |
+
crop_model.to(vision_device)
|
| 550 |
+
disease_model.to(vision_device)
|
| 551 |
+
crop_model.eval()
|
| 552 |
+
disease_model.eval()
|
| 553 |
+
|
| 554 |
+
logger.info(f"β
Vision models loaded successfully on {vision_device}!")
|
| 555 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 556 |
except Exception as e:
|
| 557 |
logger.error(f"β Failed to load vision models: {e}")
|
| 558 |
+
import traceback
|
| 559 |
+
logger.error(f"Full error: {traceback.format_exc()}")
|
| 560 |
+
# Create dummy models for testing
|
| 561 |
+
crop_model = None
|
| 562 |
+
disease_model = None
|
| 563 |
+
crop_classes = {'classes': ['tomato', 'potato', 'wheat', 'rice']}
|
| 564 |
+
disease_classes = {'classes': ['healthy', 'early_blight', 'late_blight', 'leaf_spot']}
|
| 565 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 566 |
# Define image transforms - match training resolution
|
| 567 |
# 100352 / 512 = 196, so sqrt(196) = 14, meaning 14x14 feature map
|
| 568 |
# This suggests input should be smaller to get 25088 features
|
|
|
|
| 604 |
|
| 605 |
# Load model with joblib
|
| 606 |
market_model = joblib.load(model_path)
|
| 607 |
+
from pathlib import Path
|
| 608 |
+
logger.info(f"Loaded {Path(model_path).name} with joblib")
|
| 609 |
|
| 610 |
# Load encoders with joblib
|
| 611 |
encoders = joblib.load(encoders_path)
|
|
|
|
| 1208 |
raise HTTPException(status_code=500, detail=str(e))
|
| 1209 |
|
| 1210 |
@app.post("/image-diagnosis")
|
| 1211 |
+
async def image_diagnosis(image_file: UploadFile = File(...), language: str = Form("en")):
|
| 1212 |
"""Diagnose crop diseases from images"""
|
| 1213 |
try:
|
| 1214 |
if 'vision' not in models or models['vision'] is None:
|
|
|
|
| 1265 |
|
| 1266 |
with torch.no_grad():
|
| 1267 |
outputs = disease_model(image_tensor)
|
| 1268 |
+
# Get Top 3
|
| 1269 |
+
probs = torch.nn.functional.softmax(outputs, dim=1)
|
| 1270 |
+
top3_prob, top3_idx = torch.topk(probs, min(3, len(disease_classes['classes']))) # Ensure we don't ask for more than available classes
|
| 1271 |
+
|
| 1272 |
+
# Primary prediction (Top 1)
|
| 1273 |
+
top1_idx = top3_idx[0][0].item()
|
| 1274 |
+
confidence = top3_prob[0][0].item()
|
| 1275 |
+
disease_name = disease_classes['classes'][top1_idx]
|
| 1276 |
+
|
| 1277 |
+
# Top 3 List
|
| 1278 |
+
predictions = []
|
| 1279 |
+
for i in range(top3_idx.shape[1]):
|
| 1280 |
+
idx = top3_idx[0][i].item()
|
| 1281 |
+
prob = top3_prob[0][i].item()
|
| 1282 |
+
predictions.append({
|
| 1283 |
+
"disease": disease_classes['classes'][idx],
|
| 1284 |
+
"confidence": round(prob * 100, 2)
|
| 1285 |
+
})
|
| 1286 |
|
|
|
|
| 1287 |
logger.info(f"Disease prediction: {disease_name} (confidence: {confidence:.3f})")
|
| 1288 |
+
logger.info(f"Top 3 predictions: {predictions}")
|
| 1289 |
except Exception as model_error:
|
| 1290 |
logger.error(f"Model inference error: {model_error}")
|
| 1291 |
# Fallback when model inference fails
|
|
|
|
| 1299 |
confidence = random.uniform(0.7, 0.95)
|
| 1300 |
|
| 1301 |
# Generate treatment recommendations
|
| 1302 |
+
try:
|
| 1303 |
+
import json
|
| 1304 |
+
info_path = Path(__file__).parent / "disease_info.json"
|
| 1305 |
+
if info_path.exists():
|
| 1306 |
+
with open(info_path, 'r') as f:
|
| 1307 |
+
disease_info = json.load(f)
|
| 1308 |
+
else:
|
| 1309 |
+
disease_info = {}
|
| 1310 |
+
except Exception:
|
| 1311 |
+
disease_info = {}
|
| 1312 |
+
|
| 1313 |
+
info = disease_info.get(disease_name, {})
|
|
|
|
|
|
|
| 1314 |
|
| 1315 |
+
# Default fallback if disease not in JSON
|
| 1316 |
+
default_treatment = [
|
| 1317 |
"Consult agricultural expert",
|
| 1318 |
"Apply appropriate fungicide",
|
| 1319 |
"Maintain proper plant hygiene"
|
| 1320 |
+
]
|
| 1321 |
|
| 1322 |
+
treatment = info.get("treatment", default_treatment)
|
| 1323 |
+
cause = info.get("cause", "Fungal or bacterial infection caused by environmental conditions")
|
| 1324 |
+
prevention = info.get("prevention", "Use resistant varieties and practice crop rotation")
|
| 1325 |
+
|
| 1326 |
+
# --- TRANSLATION LOGIC ---
|
| 1327 |
+
if language != "en":
|
| 1328 |
+
try:
|
| 1329 |
+
from deep_translator import GoogleTranslator
|
| 1330 |
+
translator = GoogleTranslator(source='auto', target=language)
|
| 1331 |
+
|
| 1332 |
+
# Translate Cause
|
| 1333 |
+
cause = translator.translate(cause)
|
| 1334 |
+
|
| 1335 |
+
# Translate Prevention
|
| 1336 |
+
prevention = translator.translate(prevention)
|
| 1337 |
+
|
| 1338 |
+
# Translate Treatment List
|
| 1339 |
+
translated_treatments = []
|
| 1340 |
+
for t in treatment:
|
| 1341 |
+
translated_treatments.append(translator.translate(t))
|
| 1342 |
+
treatment = translated_treatments
|
| 1343 |
+
|
| 1344 |
+
except Exception as trans_e:
|
| 1345 |
+
logger.error(f"Translation failed: {trans_e}")
|
| 1346 |
+
|
| 1347 |
return {
|
| 1348 |
"disease": disease_name,
|
| 1349 |
"confidence": round(confidence * 100, 2),
|
| 1350 |
"treatment": treatment,
|
| 1351 |
+
"cause": cause,
|
| 1352 |
+
"prevention": prevention,
|
| 1353 |
+
"top_3_predictions": predictions
|
| 1354 |
}
|
| 1355 |
|
| 1356 |
except Exception as e:
|
|
|
|
| 1585 |
|
| 1586 |
if __name__ == "__main__":
|
| 1587 |
import uvicorn
|
| 1588 |
+
uvicorn.run(app, host="0.0.0.0", port=8000)
|
| 1589 |
|