# Sensei13k's picture
# Upload 8 files
# e4f8fe4 verified
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
import cv2
import numpy as np
import os
from ultralytics import YOLO
from PIL import Image
import tempfile
import tensorflow as tf
import easyocr
from HV_PD import draw_obb
from analog_test import crop_region, calculate_needle_corner_ratio_approximation, get_center_point
# FastAPI application object; the route decorators further down register
# their endpoints on it.
app = FastAPI()
# Load the YOLO models once, at import time, so requests reuse them.
# analog_box and analog_reading are used by the DC_TEST pipeline,
# remaining_test_model by the partial-discharge / high-voltage pipeline.
try:
    analog_box = YOLO("Models/analog_box_v2.pt")
    analog_reading = YOLO("Models/analog_reading_v2.pt")
    remaining_test_model = YOLO("Models/HV_PD_model.pt")
except Exception as e:
    # Report the failure on stdout, then re-raise so the import (and therefore
    # application start-up) fails instead of serving without models.
    print(f"Error loading models: {str(e)}")
    raise
# EasyOCR reader configured for English.
# NOTE(review): nothing in the visible part of this file references `reader`;
# confirm whether another module relies on it before removing.
reader = easyocr.Reader(['en'])
from res_temp_N2N import order_points
from Lenet_res import YoloLeNetOCR
# Weight files for the CONDUCTOR_RESISTANCE_TEST pipeline (paths are relative
# to the process working directory).
# Detector that locates the resistance / temperature panels in the photo.
res_temp_box_model = 'Models/res_temp_box_v3.pt'
# Temperature panel: glyph detector (YOLO) and glyph classifier (Keras .h5).
temp_yolo_model = 'Models/temp_detect_v4.pt'
temp_cnn_model = 'Models/Lenet_temp_v4.h5'
# Resistance panel: glyph detector (YOLO) and glyph classifier (Keras .h5).
res_yolo_model = 'Models/res_detect_v4.pt'
res_cnn_model = 'Models/lenet_res_v4.h5'
# NOTE(review): no code in the visible part of this file reads this
# module-level list — confirm whether it is still needed.
class_names = ['res', 'temp']
# TwoStageOCR class from res_temp_N2N.py
class TwoStageOCR:
    """Read the characters shown on a display panel in two stages.

    A YOLO detector (``digit_detector``) localises individual glyphs on a
    cropped panel and a Keras CNN (``cnn``) classifies each glyph crop.
    A third YOLO model (``box_detector``) is loaded alongside them so callers
    can locate the panels themselves.

    The heavy third-party imports are kept function-local, as in the original
    code, so the class stays self-contained.
    """

    # Position in this tuple == class index produced by the CNN.
    DIGIT_LABELS = ('0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'C', 'dot')

    def __init__(
        self,
        box_model_path: str,
        yolo_model_path: str,
        cnn_model_path: str,
        image_size=(28, 28),
        conf_threshold=0.25
    ):
        """Load the three models.

        Args:
            box_model_path: YOLO weights for the panel detector.
            yolo_model_path: YOLO weights for the glyph detector.
            cnn_model_path: Keras model file for the glyph classifier.
            image_size: Size each glyph crop is resized to before
                classification.
            conf_threshold: Glyph detections with a lower confidence are
                discarded.
        """
        from ultralytics import YOLO
        import tensorflow as tf

        self.box_detector = YOLO(box_model_path)
        self.digit_detector = YOLO(yolo_model_path)
        # compile=False: the model is only used for inference.
        self.cnn = tf.keras.models.load_model(cnn_model_path, compile=False)
        self.inv_map = dict(enumerate(self.DIGIT_LABELS))
        self.image_size = image_size
        self.conf_threshold = conf_threshold

    def preprocess_crop(self, crop):
        """Convert a BGR glyph crop into a ``(1, *image_size, 1)`` float32 batch.

        The crop is converted to grayscale, resized to ``image_size`` and
        scaled to the ``[0, 1]`` range.
        """
        import cv2

        gray = cv2.cvtColor(crop, cv2.COLOR_BGR2GRAY)
        # NOTE(review): cv2.resize takes (width, height) while the reshape
        # below treats image_size as (rows, cols); the two only agree for
        # square sizes such as the default (28, 28).
        resized = cv2.resize(gray, self.image_size)
        normed = resized.astype('float32') / 255.0
        return normed.reshape(1, *self.image_size, 1)

    def ocr_panel(self, panel):
        """Return the glyph labels found on ``panel``, read left to right.

        Args:
            panel: Image array of one cropped display panel.

        Returns:
            The concatenated labels (members of ``DIGIT_LABELS``), or ``""``
            when no usable glyph is detected.
        """
        import numpy as np

        res = self.digit_detector.predict(source=panel, verbose=False)[0]
        boxes = res.boxes.xyxy.cpu().numpy()
        confs = res.boxes.conf.cpu().numpy()
        boxes = boxes[confs >= self.conf_threshold]
        if boxes.size == 0:
            return ""

        # Reading order: sort by the left edge of each box.
        boxes = boxes[np.argsort(boxes[:, 0])]
        height, width = panel.shape[:2]

        inputs = []
        for x1, y1, x2, y2 in boxes:
            # Clamp to the panel so negative coordinates cannot wrap around
            # as Python slice indices.
            left, top = max(int(x1), 0), max(int(y1), 0)
            right, bottom = min(int(x2), width), min(int(y2), height)
            if right <= left or bottom <= top:
                # A box that collapses to zero area after integer truncation
                # yields an empty crop, which cv2.cvtColor rejects.
                continue
            inputs.append(self.preprocess_crop(panel[top:bottom, left:right]))
        if not inputs:
            return ""

        # One batched forward pass instead of one predict() call per glyph.
        probs = self.cnn.predict(np.concatenate(inputs, axis=0), verbose=0)
        return ''.join(self.inv_map[int(idx)] for idx in np.argmax(probs, axis=1))
# Build both OCR pipelines once at import time so every request reuses the
# already-loaded models.
# Temperature pipeline: also owns the panel detector (box_model_path) that
# process_res_temp uses to locate panels in the uploaded image.
temp_ocr = TwoStageOCR(
    box_model_path=res_temp_box_model,
    yolo_model_path=temp_yolo_model,
    cnn_model_path=temp_cnn_model,
    image_size=(28,28),
    conf_threshold=0.3
)
# Resistance pipeline, implemented in Lenet_res.py; configured with a stricter
# detection threshold (0.5) than the temperature pipeline (0.3).
res_ocr = YoloLeNetOCR(
    yolo_model_path=res_yolo_model,
    lenet_model_path=res_cnn_model,
    image_size=(28,28),
    conf_threshold=0.5
)
def _format_temperature(raw):
    """Format raw temperature OCR labels as a one-decimal string such as ``'25.3°C'``."""
    # The decimal point is inserted at a fixed position below, so both the
    # unit label and any detected 'dot' label are dropped first; otherwise a
    # detected dot would surface literally (e.g. '25dot.3°C').
    digits = raw.replace('dot', '').replace('C', '')
    if len(digits) > 1:
        return digits[:-1] + '.' + digits[-1] + '°C'
    return digits + '°C'


def _warp_panel(image, poly):
    """Rectify one oriented-box polygon into an upright crop, or None if it has no area."""
    from res_temp_N2N import order_points

    rect = order_points(poly.reshape(4, 2).astype(np.float32))
    tl, tr, br, bl = rect
    out_w = int(max(np.linalg.norm(br - bl), np.linalg.norm(tr - tl)))
    out_h = int(max(np.linalg.norm(tr - br), np.linalg.norm(tl - bl)))
    if out_w < 1 or out_h < 1:
        # warpPerspective cannot produce a zero-sized output image.
        return None
    dst = np.array([
        [0, 0],
        [out_w - 1, 0],
        [out_w - 1, out_h - 1],
        [0, out_h - 1]
    ], dtype="float32")
    transform = cv2.getPerspectiveTransform(rect, dst)
    return cv2.warpPerspective(image, transform, (out_w, out_h))


def _read_resistance(crop):
    """Run the resistance OCR on a panel crop via a temporary PNG file."""
    # res_ocr.ocr_image takes a file path, so the crop is written to disk.
    # The handle is closed before writing, and the file is removed even when
    # the OCR call raises.
    with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp:
        tmp_path = tmp.name
    try:
        cv2.imwrite(tmp_path, crop)
        raw = res_ocr.ocr_image(tmp_path)
    finally:
        os.remove(tmp_path)
    return raw.replace("dot", ".")


def process_res_temp(file_bytes):
    """Run the CONDUCTOR_RESISTANCE_TEST pipeline on an uploaded image.

    The panel detector owned by ``temp_ocr`` locates oriented boxes; each box
    with confidence >= 0.3 whose class is ``'temp'`` or ``'res'`` is rectified
    and read by the matching OCR pipeline.

    Args:
        file_bytes: Encoded image bytes as uploaded by the client.

    Returns:
        ``{"ocs": <mean panel confidence>, "extractions": [...]}`` where each
        extraction has ``keyName``, ``keyValue``, ``actualValue`` and
        ``confidenceScore``.

    Raises:
        HTTPException: Status 400 when the image cannot be decoded, when no
            panel is found, or when any processing step fails.
    """
    try:
        image_cv = cv2.imdecode(np.frombuffer(file_bytes, np.uint8), cv2.IMREAD_COLOR)
        if image_cv is None:
            raise HTTPException(status_code=400, detail="Invalid image data for CONDUCTOR_RESISTANCE_TEST")

        # Oriented-bounding-box detection of the panels.
        detection = temp_ocr.box_detector.predict(source=image_cv, verbose=False)[0]
        obb = getattr(detection, 'obb', None)
        if obb is None:
            raise HTTPException(status_code=400, detail="No panels detected in image")

        polys = obb.xyxyxyxy.cpu().numpy()
        confs = obb.conf.cpu().numpy()
        class_ids = obb.cls.cpu().numpy().astype(int)
        label_by_id = temp_ocr.box_detector.model.names

        results = []
        confidence_scores = []
        for poly, conf, cls_id in zip(polys, confs, class_ids):
            if conf < 0.3:
                continue
            label = label_by_id[int(cls_id)]
            if label not in ('temp', 'res'):
                continue
            crop = _warp_panel(image_cv, poly)
            if crop is None:
                continue
            if label == 'temp':
                value = _format_temperature(temp_ocr.ocr_panel(crop))
            else:
                value = _read_resistance(crop)
            results.append({
                "keyName": label,
                "keyValue": value,
                "actualValue": value,
                "confidenceScore": round(float(conf), 2)
            })
            confidence_scores.append(float(conf))

        if not results:
            raise HTTPException(status_code=400, detail="No resistance or temperature panels detected.")

        overall_confidence = round(sum(confidence_scores) / len(confidence_scores), 2)
        return {
            "ocs": overall_confidence,
            "extractions": results
        }
    except HTTPException:
        # Already carries the intended status and detail; wrapping it again
        # would bury the message inside a second, generic one.
        raise
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Error processing CONDUCTOR_RESISTANCE_TEST: {str(e)}") from e
def process_remaining_test(file_bytes, expected_classes):
    """Extract the text of selected detection classes from an uploaded image.

    ``remaining_test_model`` detects oriented boxes, ``draw_obb`` extracts the
    text inside them, and only classes listed in ``expected_classes`` are
    kept. If a class is detected more than once, the last detection wins.

    Args:
        file_bytes: Encoded image bytes as uploaded by the client.
        expected_classes: Class names to report for this test type.

    Returns:
        ``{"ocs": <mean confidence>, "extractions": [...]}`` where each
        extraction has ``keyName``, ``keyValue``, ``actualValue`` and
        ``confidenceScore``.

    Raises:
        HTTPException: Status 400 when the image cannot be decoded, when none
            of the expected classes yields text, or when processing fails.
    """
    try:
        image_cv = cv2.imdecode(np.frombuffer(file_bytes, np.uint8), cv2.IMREAD_COLOR)
        if image_cv is None:
            raise HTTPException(status_code=400, detail="Invalid image data for processing")

        extracted_data = {}
        confidence_scores = {}
        for r in remaining_test_model(image_cv):
            if r.obb is None:
                continue
            confidences = r.obb.conf.cpu().numpy() if hasattr(r.obb, 'conf') else None
            # draw_obb works on a copy so the decoded image stays untouched.
            # NOTE(review): the loop below assumes extracted_texts is indexed
            # in the same order as r.obb.cls — confirm against HV_PD.draw_obb.
            _, extracted_texts = draw_obb(image_cv.copy(), r.obb)
            for i, class_id in enumerate(r.obb.cls.cpu().numpy()):
                class_name = r.names[int(class_id)]
                if class_name not in expected_classes:
                    continue
                if i >= len(extracted_texts) or not extracted_texts[i]:
                    continue
                extracted_data[class_name] = extracted_texts[i]
                if confidences is not None and i < len(confidences):
                    confidence_scores[class_name] = float(confidences[i])
                else:
                    # Fallback score when the detector exposes no confidence.
                    confidence_scores[class_name] = 0.75

        if not extracted_data:
            raise HTTPException(status_code=400, detail=f"No data extracted for the expected classes: {expected_classes}")

        # Both dicts are filled together, so they share the same keys.
        overall_confidence = round(sum(confidence_scores.values()) / len(confidence_scores), 2)
        kv_list = [
            {
                "keyName": name,
                "keyValue": text,
                "actualValue": text,
                "confidenceScore": round(confidence_scores[name], 2)
            }
            for name, text in extracted_data.items()
        ]
        return {"ocs": overall_confidence, "extractions": kv_list}
    except HTTPException:
        # Keep the specific status and detail instead of re-wrapping them in
        # the generic message below.
        raise
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Error processing test data: {str(e)}") from e
def process_dc_test(file_bytes):
    """Run the DC_TEST pipeline: read an analog meter from an uploaded image.

    The ``analog_box`` model locates the meter, which is cropped with
    ``crop_region``; the ``analog_reading`` model then detects the needle and
    the scale numbers on the crop, and
    ``calculate_needle_corner_ratio_approximation`` turns them into a reading.

    Args:
        file_bytes: Encoded image bytes as uploaded by the client.

    Returns:
        ``{"ocs": <confidence>, "extractions": [<MeterReading entry>]}``. The
        confidence is a weighted mean in which the needle counts twice as
        much as each scale number.

    Raises:
        HTTPException: Status 400 when the image cannot be decoded, when no
            meter, needle or scale number is found, or when processing fails.
    """
    # Imported once per call rather than inside the detection loops.
    from analog_test import (
        calculate_needle_corner_ratio_approximation,
        crop_region,
        get_center_point,
    )

    try:
        # Decode file bytes into a CV image (BGR).
        image_cv = cv2.imdecode(np.frombuffer(file_bytes, np.uint8), cv2.IMREAD_COLOR)
        if image_cv is None:
            raise HTTPException(status_code=400, detail="Invalid image data for DC_TEST")

        # Stage 1: crop the first meter region that can be extracted.
        cropped_meter = None
        for r in analog_box(image_cv):
            if getattr(r, "obb", None) is None:
                continue
            cropped_meter = crop_region(image_cv, r.obb)
            if cropped_meter is not None:
                break
        if cropped_meter is None:
            raise HTTPException(status_code=400, detail="No analog meter detected in image")

        # Stage 2: find the needle and the scale numbers on the crop.
        needle_corners = None
        needle_confidence = 0
        number_positions = []
        number_confidences = []
        for r in analog_reading(cropped_meter):
            if getattr(r, "obb", None) is None:
                continue
            boxes = r.obb.xyxyxyxy.cpu().numpy()
            classes = r.obb.cls.cpu().numpy()
            confidences = r.obb.conf.cpu().numpy()
            for box, class_id, conf in zip(boxes, classes, confidences):
                class_name = r.names[int(class_id)]
                if class_name.lower() == "needle":
                    # With several needle detections the last one wins.
                    needle_corners = box.reshape(4, 2)
                    needle_confidence = float(conf)
                elif class_name.isdigit() or class_name.lower() == "numbers":
                    # NOTE(review): every scale number is stored with value 0;
                    # confirm the approximation only needs the positions.
                    number_positions.append((0, get_center_point(box)))
                    number_confidences.append(float(conf))

        if needle_corners is None or not number_positions:
            raise HTTPException(status_code=400, detail="Could not detect needle or number positions in meter")

        reading, _ = calculate_needle_corner_ratio_approximation(needle_corners, number_positions)
        overall_confidence = (2 * needle_confidence + sum(number_confidences)) / (2 + len(number_confidences))
        overall_confidence = round(overall_confidence, 2)
        reading = round(float(reading), 2)
        extractions = [{
            "keyName": "MeterReading",
            "keyValue": str(reading),
            "actualValue": str(reading),
            "confidenceScore": overall_confidence,
        }]
        return {
            "ocs": overall_confidence,
            "extractions": extractions
        }
    except HTTPException:
        # Keep the specific status and detail instead of re-wrapping them in
        # the generic message below.
        raise
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Error processing DC_TEST: {str(e)}") from e
@app.post("/detect/")
async def detect(file: UploadFile = File(...), test_type: str = Form(...)):
    """Read the uploaded image and run the pipeline selected by ``test_type``.

    Accepted values are ``CONDUCTOR_RESISTANCE_TEST``, ``DC_TEST``,
    ``PARTIAL_DISCHARGE_TEST`` and ``HIGH_VOLTAGE_TEST``; any other value is
    rejected with HTTP 400. The upload is read before ``test_type`` is checked.
    """
    payload = await file.read()

    # One zero-argument runner per supported test type.
    runners = {
        "CONDUCTOR_RESISTANCE_TEST": lambda: process_res_temp(payload),
        "DC_TEST": lambda: process_dc_test(payload),
        "PARTIAL_DISCHARGE_TEST": lambda: process_remaining_test(
            payload, expected_classes=["q(IEC) value", "qCValue"]
        ),
        "HIGH_VOLTAGE_TEST": lambda: process_remaining_test(
            payload, expected_classes=["kV", "TimeLeft", "UVolt"]
        ),
    }

    runner = runners.get(test_type)
    if runner is None:
        raise HTTPException(status_code=400, detail="Invalid test_type. Choose 'CONDUCTOR_RESISTANCE_TEST', 'DC_TEST', 'PARTIAL_DISCHARGE_TEST', or 'HIGH_VOLTAGE_TEST'")
    return runner()
@app.get("/")
def health_check():
    """Liveness endpoint: return a fixed status and the service version string."""
    info = {}
    info["status"] = "healthy"
    info["version"] = "v4.1"
    return info