# dcrm-analysis-api / flask_app.py
# Hugging Face Space by sikeaditya, commit 15f8f6c ("Update flask_app.py", verified)
# flask_app.py
"""
Flask API for DCRM (Dynamic Contact Resistance Measurement) Analysis
Provides endpoints for uploading DCRM graph images and getting AI-powered analysis.
"""
from flask import Flask, request, jsonify
from flask_cors import CORS
import cv2
import numpy as np
import os
import json
import re
import tempfile
import base64
from werkzeug.utils import secure_filename
# Import DCRM modules
from dcrm.image_processing import process_uploaded_image
from dcrm.llm import ask_llm_for_breakage, analyze_health_with_llm
from dcrm.zone_analysis import ZoneAnalyzer
app = Flask(__name__)
CORS(app)  # every route accepts cross-origin requests (flask_cors defaults)

# --- Upload limits -----------------------------------------------------------
# Request bodies larger than 16 MiB are refused by Flask with HTTP 413.
app.config.update(MAX_CONTENT_LENGTH=16 * 1024 * 1024)
ALLOWED_EXTENSIONS = {"png", "jpg", "jpeg"}

# --- Fallbacks for optional processing parameters ---------------------------
DEFAULT_SAT_FACTOR = 3.0  # saturation boost factor
DEFAULT_GAP_SIZE = 1  # gap fill size
DEFAULT_NOISE_THRESHOLD = 100  # minimum object area
DEFAULT_TOTAL_DURATION = 400  # graph duration in ms
DEFAULT_CROP_OPTION = True  # auto-crop the graph area
DEFAULT_MODEL_NAME = "gemini-2.5-flash"  # model name handed to the LLM helpers
def allowed_file(filename):
    """Tell whether *filename* carries one of the ALLOWED_EXTENSIONS.

    Only the text after the last dot is inspected, case-insensitively; a name
    containing no dot at all is rejected.
    """
    if "." not in filename:
        return False
    extension = filename.rpartition(".")[2]
    return extension.lower() in ALLOWED_EXTENSIONS
def safe_parse_llm_json(llm_response):
    """Extract a JSON value from an LLM reply, tolerating markdown and prose.

    Candidate substrings are tried in this order, moving on to the next one
    whenever a candidate fails to decode:

    1. the object inside a fenced block opened with a json-tagged fence,
    2. the span from the first ``{`` to the last ``}`` of the reply,
    3. the whole reply.

    Args:
        llm_response: Raw model output. Anything that is not a ``str`` yields
            ``None``.

    Returns:
        The decoded JSON value, or ``None`` if no candidate decodes. Candidates
        1 and 2 can only produce objects (dicts); candidate 3 may produce any
        JSON value.
    """
    if not isinstance(llm_response, str):
        return None

    candidates = []
    fenced = re.search(r"```json\s*(\{.*?\})\s*```", llm_response, re.DOTALL)
    if fenced:
        candidates.append(fenced.group(1))
    loose = re.search(r"(\{.*\})", llm_response, re.DOTALL)
    if loose:
        candidates.append(loose.group(1))
    candidates.append(llm_response)

    for candidate in candidates:
        try:
            return json.loads(candidate)
        # json.JSONDecodeError subclasses ValueError; absurdly deep nesting
        # raises RecursionError. Either way, try the next candidate.
        except (ValueError, RecursionError):
            continue
    return None
def convert_numpy_types(obj):
    """Recursively convert *obj* into values the standard JSON encoder accepts.

    - dicts are rebuilt with converted values; lists and tuples become lists
      (JSON has no tuple type, so the encoded output is unchanged);
    - NumPy integer/floating scalars become ``int``/``float`` and arrays become
      (converted) nested lists;
    - any other object exposing ``.item()`` (e.g. ``np.bool_``) is unwrapped
      through it;
    - non-finite floats (NaN, +inf, -inf), NumPy or native, become ``None``.
      JSON cannot represent them, and serializing them would emit the invalid
      tokens ``NaN``/``Infinity`` that strict parsers (e.g. browsers'
      ``JSON.parse``) reject.

    Anything else is returned unchanged.
    """
    if isinstance(obj, dict):
        return {key: convert_numpy_types(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [convert_numpy_types(item) for item in obj]
    if isinstance(obj, np.ndarray):
        # tolist() yields native scalars; recurse only to sanitize NaN/inf.
        return convert_numpy_types(obj.tolist())
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, (float, np.floating)):
        value = float(obj)
        return value if np.isfinite(value) else None
    if hasattr(obj, "item"):  # remaining NumPy scalar types
        return obj.item()
    return obj
def image_to_base64(img_array):
    """Encode an image array as a base64 PNG string.

    Args:
        img_array: Image as a NumPy array, or ``None``. An array whose shape is
            (height, width, 3) is treated as RGB and swapped to BGR first,
            because OpenCV's encoder expects BGR channel order.
            NOTE(review): this assumes every 3-channel debug image produced by
            ``dcrm.image_processing`` is RGB; confirm.

    Returns:
        The base64 text, or ``None`` when ``img_array`` is ``None`` or OpenCV
        reports that PNG encoding failed (callers skip falsy results).
    """
    if img_array is None:
        return None
    if len(img_array.shape) == 3 and img_array.shape[2] == 3:
        img_bgr = cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR)
    else:
        img_bgr = img_array
    ok, buffer = cv2.imencode(".png", img_bgr)
    if not ok:
        # Previously the failure flag was ignored and a bogus buffer encoded.
        return None
    return base64.b64encode(buffer).decode("utf-8")
@app.route("/", methods=["GET"])
def index():
    """Describe the service: name, version, public endpoints and docs link."""
    endpoint_summary = {
        "GET /health": "Health check",
        "POST /analyze": "Full DCRM analysis with AI",
        "POST /extract-curves": "Extract curves only (no AI)",
    }
    # NOTE(review): the docs URL is still a placeholder repository path.
    return jsonify(
        service="DCRM Analysis API",
        version="1.0.0",
        endpoints=endpoint_summary,
        docs="https://github.com/YOUR_REPO/README.md",
    )
@app.route("/health", methods=["GET"])
def health_check():
    """Report a fixed "healthy" status; no dependencies are probed."""
    return jsonify(status="healthy", service="DCRM Analysis API")
def _analyze_request_params():
    """Return the parameter mapping for the current /analyze request.

    JSON requests yield the decoded body, or ``{}`` when the body is malformed
    or is not a JSON object; every other request yields ``request.form``.
    """
    if request.is_json:
        payload = request.get_json(silent=True)
        return payload if isinstance(payload, dict) else {}
    return request.form


def _analyze_read_image(params):
    """Read the uploaded image for /analyze.

    Accepts a multipart ``image`` file or, for JSON bodies, an ``image_base64``
    string given either as bare base64 or as a ``data:<mime>;base64,`` URL.

    Returns:
        ``(image_bytes, None)`` on success, otherwise ``(None, response)`` where
        ``response`` is a ready-to-return ``(json, 400)`` Flask value.
    """
    if "image" in request.files:
        file = request.files["image"]
        if file.filename == "":
            return None, (jsonify({"error": "No file selected"}), 400)
        if not allowed_file(file.filename):
            return None, (
                jsonify(
                    {
                        "error": f"Invalid file type. Allowed: {', '.join(ALLOWED_EXTENSIONS)}"
                    }
                ),
                400,
            )
        file_bytes = file.read()
    elif request.is_json and "image_base64" in params:
        encoded = params["image_base64"]
        # Strip a data-URL header; b64decode would otherwise silently fold the
        # header's letters into the payload and produce garbage bytes.
        if isinstance(encoded, str) and encoded.startswith("data:") and "," in encoded:
            encoded = encoded.split(",", 1)[1]
        try:
            file_bytes = base64.b64decode(encoded)
        except (TypeError, ValueError) as e:  # binascii.Error is a ValueError
            return None, (jsonify({"error": f"Invalid base64 image: {str(e)}"}), 400)
    else:
        return None, (
            jsonify(
                {
                    "error": "No image provided. Use 'image' file upload or 'image_base64' in JSON."
                }
            ),
            400,
        )
    if not file_bytes:
        return None, (jsonify({"error": "Uploaded image is empty"}), 400)
    return file_bytes, None


def _analyze_processing_params(params):
    """Parse the curve-extraction options, applying the module defaults.

    Raises:
        ValueError, TypeError: when a supplied value cannot be converted.
    """
    return {
        "sat_factor": float(params.get("sat_factor", DEFAULT_SAT_FACTOR)),
        "gap_size": int(params.get("gap_size", DEFAULT_GAP_SIZE)),
        "noise_threshold": int(params.get("noise_threshold", DEFAULT_NOISE_THRESHOLD)),
        "total_duration": int(params.get("total_duration", DEFAULT_TOTAL_DURATION)),
        # str(True) == "True", so JSON booleans and form strings both work.
        "crop_option": str(params.get("crop_option", DEFAULT_CROP_OPTION)).lower()
        == "true",
    }


def _analyze_crop_to_bounds(file_bytes, bounds):
    """Return JPEG bytes of image columns ``bounds[0]:bounds[1]``, or ``None``.

    Best effort: decoding/encoding failures are logged and ``None`` is
    returned so the caller continues with the full image.
    """
    if not bounds:
        return None
    try:
        start_x, end_x = bounds
        img = cv2.imdecode(np.frombuffer(file_bytes, np.uint8), cv2.IMREAD_COLOR)
        if img is None:
            return None
        is_success, buffer = cv2.imencode(".jpg", img[:, start_x:end_x])
        return buffer.tobytes() if is_success else None
    except Exception:
        app.logger.warning(
            "Cropping to bounds %r failed; using the full image", bounds, exc_info=True
        )
        return None


def _analyze_expert_image(df_result, image_bytes, api_key, model_name):
    """Run the image-based LLM health diagnosis.

    Returns:
        ``(expert, None)`` on success, where ``expert`` is a dict with the
        ``zone_analysis``, ``analysis_data``, ``executive_lead`` and ``issues``
        used in the response; or ``(None, reason)`` when the caller should fall
        back to CSV-based zone analysis.
    """
    numerical_context = {}
    if "Resistance" in df_result.columns:
        valid_res = df_result["Resistance"].dropna()
        if not valid_res.empty:
            numerical_context["min_resistance"] = float(valid_res.min())
            numerical_context["median_resistance"] = float(valid_res.median())

    llm_response = analyze_health_with_llm(
        image_bytes, api_key, model_name, numerical_context
    )
    if isinstance(llm_response, dict) and "error" in llm_response:
        return None, "Image-Based (Failed) - Fallback to CSV"

    analysis_data = safe_parse_llm_json(llm_response)
    # Only a non-empty JSON object can be read with .get() below.
    if not analysis_data or not isinstance(analysis_data, dict):
        return None, "Image-Based (Parse Error) - Fallback to CSV"

    # Prose preceding the JSON object is used as the executive summary.
    executive_lead = llm_response.split("{")[0].strip()
    if "```json" in executive_lead:
        executive_lead = executive_lead.replace("```json", "").strip()

    issues = analysis_data.get("detected_issues") or []  # tolerate null
    status = analysis_data.get("overall_condition", "Unknown")
    score = analysis_data.get("health_score")
    if score is None:
        # Derive a coarse score from the condition label when none was given.
        status_scores = {"Healthy": 100, "Warning": 60, "Critical": 20}
        score = status_scores.get(status, 0) if isinstance(status, str) else 0

    zone_analysis = {
        "overall_health": {
            "status": status,
            "overall_score": score,
            "recommendation": analysis_data.get("maintenance_recommendation"),
            "total_issues": len(issues),
            "critical_issues": [],
        }
    }
    expert = {
        "zone_analysis": zone_analysis,
        "analysis_data": analysis_data,
        "executive_lead": executive_lead,
        "issues": issues,
    }
    return expert, None


@app.route("/analyze", methods=["POST"])
def analyze_image():
    """
    Main endpoint for DCRM image analysis.

    Expects (multipart form fields, or keys of a JSON object body):
    - image: File upload (multipart/form-data), or
      image_base64: base64 string (bare or data URL) in a JSON body
    - api_key: Gemini API key (falls back to GEMINI_API_KEY / GOOGLE_API_KEY)
    - sat_factor: Saturation boost factor (optional, default: 3.0)
    - gap_size: Gap fill size (optional, default: 1)
    - noise_threshold: Minimum object area (optional, default: 100)
    - total_duration: Graph duration in ms (optional, default: 400)
    - crop_option: Auto-crop option (optional, default: true)
    - analysis_method: "image" or "csv" (optional, default: "image")
    - model_name: LLM model name (optional; empty/null uses the default)
    - include_debug_images: "true" to add base64 debug images (optional)

    Returns:
    JSON response with analysis results. When image analysis was requested
    but could not be used, ``analysis_type`` is "CSV-Based" and
    ``fallback_reason`` says why. Bad input yields HTTP 400; unexpected
    failures yield HTTP 500 (traceback only included in debug mode).
    """
    try:
        params = _analyze_request_params()

        # The request's key wins; otherwise use the server's environment.
        api_key = (
            params.get("api_key")
            or os.environ.get("GEMINI_API_KEY")
            or os.environ.get("GOOGLE_API_KEY")
        )
        if not api_key:
            return (
                jsonify(
                    {
                        "error": "API key is required. Provide 'api_key' in the request or set GEMINI_API_KEY environment variable."
                    }
                ),
                400,
            )

        file_bytes, error_response = _analyze_read_image(params)
        if error_response is not None:
            return error_response

        try:
            processing_params = _analyze_processing_params(params)
        except (TypeError, ValueError) as exc:
            return (
                jsonify(
                    {
                        "error": f"Invalid processing parameter: {exc}",
                        "stage": "parameters",
                    }
                ),
                400,
            )
        analysis_method = str(params.get("analysis_method", "image"))
        model_name = params.get("model_name") or DEFAULT_MODEL_NAME
        include_debug_images = (
            str(params.get("include_debug_images", "false")).lower() == "true"
        )

        # Step 1: Extract curves from image
        df_result, debug_images, bounds, error_msg, _ = process_uploaded_image(
            file_bytes,
            processing_params["sat_factor"],
            processing_params["gap_size"],
            processing_params["noise_threshold"],
            processing_params["crop_option"],
            processing_params["total_duration"],
        )
        if error_msg:
            return (
                jsonify(
                    {
                        "error": f"Curve extraction failed: {error_msg}",
                        "stage": "extraction",
                    }
                ),
                400,
            )

        # Step 2: Get LLM segmentation (on the cropped graph when available)
        cropped_bytes = _analyze_crop_to_bounds(file_bytes, bounds)
        df_result, result_json = ask_llm_for_breakage(
            df_result, api_key, model_name, image_bytes=cropped_bytes
        )
        if not result_json or "error" in result_json:
            return (
                jsonify(
                    {
                        "error": "AI segmentation failed",
                        "details": (
                            result_json.get("error") if result_json else "Unknown error"
                        ),
                        "stage": "segmentation",
                    }
                ),
                400,
            )

        # Step 3: Zone health analysis (expert image diagnosis, else CSV-based)
        expert, fallback_reason = None, None
        if analysis_method.lower() == "image":
            expert, fallback_reason = _analyze_expert_image(
                df_result, cropped_bytes or file_bytes, api_key, model_name
            )
        if expert is not None:
            zone_analysis = expert["zone_analysis"]
            analysis_type = "Expert Image Diagnostic"
        else:
            zone_analysis = ZoneAnalyzer(df_result, result_json).analyze_all_zones()
            analysis_type = "CSV-Based"

        response_data = {
            "success": True,
            "analysis_type": analysis_type,
            "segmentation": convert_numpy_types(result_json),
            "zone_analysis": convert_numpy_types(zone_analysis),
            "curve_data": {
                "columns": df_result.columns.tolist(),
                "data": df_result.to_dict(orient="records"),
                "num_points": len(df_result),
            },
            "processing_params": processing_params,
        }
        if fallback_reason:
            response_data["fallback_reason"] = fallback_reason
        if expert is not None:
            response_data["expert_analysis"] = {
                "executive_summary": expert["executive_lead"],
                "detailed_analysis": convert_numpy_types(expert["analysis_data"]),
                "issues": convert_numpy_types(expert["issues"]),
            }
        if include_debug_images and debug_images:
            response_data["debug_images"] = {}
            for name, img in debug_images.items():
                img_b64 = image_to_base64(img)
                if img_b64:
                    response_data["debug_images"][name] = img_b64
        return jsonify(convert_numpy_types(response_data))
    except Exception as e:
        import traceback

        app.logger.exception("Unhandled error while handling /analyze")
        body = {"error": f"Internal server error: {str(e)}"}
        # Tracebacks reveal server internals; only expose them in debug mode.
        if app.debug:
            body["traceback"] = traceback.format_exc()
        return jsonify(body), 500
@app.route("/extract-curves", methods=["POST"])
def extract_curves():
    """
    Lightweight endpoint that only extracts curves without LLM analysis.
    Useful for quick data extraction without AI processing.

    Expects:
    - image: File upload (multipart/form-data), or
      image_base64: base64 string (bare or data URL) in a JSON body
    - sat_factor, gap_size, noise_threshold, total_duration, crop_option,
      include_debug_images (optional)

    Returns:
    JSON with extracted curve data, the detected bounds and the parameters
    used. Bad images or parameters yield HTTP 400; unexpected failures yield
    HTTP 500 (traceback only included in debug mode).
    """
    try:
        # A malformed or non-object JSON body is treated as an empty one
        # instead of escaping as an unhandled error.
        if request.is_json:
            params = request.get_json(silent=True)
            if not isinstance(params, dict):
                params = {}
        else:
            params = request.form

        # Get image data
        if "image" in request.files:
            file = request.files["image"]
            if file.filename == "":
                return jsonify({"error": "No file selected"}), 400
            if not allowed_file(file.filename):
                return (
                    jsonify(
                        {
                            "error": f"Invalid file type. Allowed: {', '.join(ALLOWED_EXTENSIONS)}"
                        }
                    ),
                    400,
                )
            file_bytes = file.read()
        elif request.is_json and "image_base64" in params:
            encoded = params["image_base64"]
            # Strip a data-URL header; b64decode would otherwise fold its
            # letters into the payload and produce garbage bytes.
            if isinstance(encoded, str) and encoded.startswith("data:") and "," in encoded:
                encoded = encoded.split(",", 1)[1]
            try:
                file_bytes = base64.b64decode(encoded)
            except (TypeError, ValueError) as e:  # binascii.Error is a ValueError
                return jsonify({"error": f"Invalid base64 image: {str(e)}"}), 400
        else:
            return jsonify({"error": "No image provided"}), 400
        if not file_bytes:
            return jsonify({"error": "Uploaded image is empty"}), 400

        # Get processing parameters; unconvertible values are a client error.
        try:
            sat_factor = float(params.get("sat_factor", DEFAULT_SAT_FACTOR))
            gap_size = int(params.get("gap_size", DEFAULT_GAP_SIZE))
            noise_threshold = int(
                params.get("noise_threshold", DEFAULT_NOISE_THRESHOLD)
            )
            total_duration = int(params.get("total_duration", DEFAULT_TOTAL_DURATION))
        except (TypeError, ValueError) as exc:
            return jsonify({"error": f"Invalid processing parameter: {exc}"}), 400
        # str(True) == "True", so JSON booleans and form strings both work.
        crop_option = (
            str(params.get("crop_option", DEFAULT_CROP_OPTION)).lower() == "true"
        )
        include_debug_images = (
            str(params.get("include_debug_images", "false")).lower() == "true"
        )

        # Extract curves
        df_result, debug_images, bounds, error_msg, _ = process_uploaded_image(
            file_bytes,
            sat_factor,
            gap_size,
            noise_threshold,
            crop_option,
            total_duration,
        )
        if error_msg:
            return jsonify({"error": f"Curve extraction failed: {error_msg}"}), 400

        response_data = {
            "success": True,
            "curve_data": {
                "columns": df_result.columns.tolist(),
                "data": df_result.to_dict(orient="records"),
                "num_points": len(df_result),
            },
            # As a list, any NumPy scalars inside are converted below (JSON
            # encodes tuples and lists identically).
            "bounds": list(bounds) if bounds is not None else None,
            "processing_params": {
                "sat_factor": sat_factor,
                "gap_size": gap_size,
                "noise_threshold": noise_threshold,
                "total_duration": total_duration,
                "crop_option": crop_option,
            },
        }
        if include_debug_images and debug_images:
            response_data["debug_images"] = {}
            for name, img in debug_images.items():
                img_b64 = image_to_base64(img)
                if img_b64:
                    response_data["debug_images"][name] = img_b64
        return jsonify(convert_numpy_types(response_data))
    except Exception as e:
        import traceback

        app.logger.exception("Unhandled error while handling /extract-curves")
        body = {"error": f"Internal server error: {str(e)}"}
        # Tracebacks reveal server internals; only expose them in debug mode.
        if app.debug:
            body["traceback"] = traceback.format_exc()
        return jsonify(body), 500
@app.errorhandler(413)
def too_large(e):
    """Return a JSON 413 error that quotes the configured upload limit.

    The limit is read from ``MAX_CONTENT_LENGTH`` so the message cannot drift
    from the configuration; without a limit a generic message is used.
    """
    limit = app.config.get("MAX_CONTENT_LENGTH")
    if limit:
        # 16 * 1024 * 1024 renders as "16", matching the previous text.
        message = f"File too large. Maximum size is {limit / (1024 * 1024):g}MB."
    else:
        message = "File too large."
    return jsonify({"error": message}), 413
@app.errorhandler(404)
def not_found(e):
    """Answer requests for unknown URLs with a JSON error body."""
    status_code = 404
    return jsonify(error="Endpoint not found"), status_code
@app.errorhandler(500)
def internal_error(e):
    """Render HTTP 500 errors as a generic JSON message."""
    status_code = 500
    return jsonify(error="Internal server error"), status_code
def startup_banner():
    """Return the box-drawn endpoint summary printed when the server starts.

    The box is built programmatically so every row has the same width; the
    previous hand-written banner contained mis-encoded box characters.
    """
    inner_width = 62
    rows = [
        "DCRM Analysis API - Flask Server",
        None,  # horizontal separator
        "Endpoints:",
        "  GET  /health          - Health check",
        "  POST /analyze         - Full DCRM analysis with AI",
        "  POST /extract-curves  - Extract curves only (no AI)",
    ]
    lines = ["╔" + "═" * inner_width + "╗"]
    for row in rows:
        if row is None:
            lines.append("╠" + "═" * inner_width + "╣")
        else:
            lines.append("║ " + row.ljust(inner_width - 1) + "║")
    lines.append("╚" + "═" * inner_width + "╝")
    return "\n".join(lines)


if __name__ == "__main__":
    # Get port from environment or use default (7860 for Hugging Face Spaces)
    port = int(os.environ.get("PORT", 7860))
    debug = os.environ.get("FLASK_DEBUG", "false").lower() == "true"
    print(f"\n{startup_banner()}\n")
    app.run(host="0.0.0.0", port=port, debug=debug)