Spaces:
Sleeping
Sleeping
import base64
import html
import io
import logging
import re
from urllib.parse import quote
from xml.sax.saxutils import escape as xml_escape

import numpy as np
import uvicorn
from fastapi import FastAPI, File, UploadFile, Form
from fastapi.responses import HTMLResponse, StreamingResponse
from PIL import Image, ImageDraw
from reportlab.lib.enums import TA_CENTER
from reportlab.lib.pagesizes import letter
from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle
from reportlab.lib.units import inch
from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer, Image as ReportLabImage
from transformers import pipeline
# The FastAPI (ASGI) application object; it is what `uvicorn.run` serves when
# this file is executed as a script.
app = FastAPI()
| # Load models | |
def load_models():
    """Create the three Hugging Face pipelines used for fracture analysis.

    Returns:
        dict keyed by display name:
        - "KnochenAuge": an "object-detection" pipeline (its hits carry boxes),
        - "KnochenWächter": an "image-classification" pipeline,
        - "RöntgenMeister": an "image-classification" pipeline.
    """
    # (display name, pipeline task, model repository id) — built in this order.
    model_specs = (
        ("KnochenAuge", "object-detection", "D3STRON/bone-fracture-detr"),
        ("KnochenWächter", "image-classification", "Heem2/bone-fracture-detection-using-xray"),
        (
            "RöntgenMeister",
            "image-classification",
            "nandodeomkar/autotrain-fracture-detection-using-google-vit-base-patch-16-54382127388",
        ),
    )
    return {name: pipeline(task, model=repo_id) for name, task, repo_id in model_specs}
# Built once at import time so every request shares the same pipeline objects.
# NOTE(review): constructing the pipelines presumably downloads/loads model
# weights, making module import slow (and network-dependent on first run).
models = load_models()
def translate_label(label):
    """Map a raw model label to a human-readable English label.

    The lookup is case-insensitive. Labels without a known mapping are
    returned unchanged.

    Args:
        label: label string as emitted by one of the pipelines.

    Returns:
        The display label.
    """
    # Keys must be lower-case because the input is lower-cased before lookup;
    # previously "F1"/"NF" were stored upper-case and could never match.
    translations = {
        "fracture": "Fracture",
        "no fracture": "No Fracture",
        "normal": "Normal",
        "abnormal": "Abnormal",
        "f1": "Fracture",  # assumption (from original author): F1 means fracture — confirm
        "nf": "No Fracture",  # assumption (from original author): NF means no fracture — confirm
    }
    return translations.get(label.lower(), label)
def create_heatmap_overlay(image, box, score):
    """Return a transparent RGBA layer, sized like *image*, highlighting *box*.

    The colour encodes confidence: red when score > 0.8, orange when
    score > 0.6, yellow otherwise. The rectangle is painted with a translucent
    fill (alpha 100) and then a fully opaque 2 px outline.

    Args:
        image: PIL image whose ``size`` the overlay copies.
        box: mapping with ``xmin``, ``ymin``, ``xmax`` and ``ymax`` keys.
        score: detection confidence used to pick the colour.
    """
    if score > 0.8:
        base_rgb = (255, 0, 0)
    elif score > 0.6:
        base_rgb = (255, 165, 0)
    else:
        base_rgb = (255, 255, 0)

    corners = [box['xmin'], box['ymin'], box['xmax'], box['ymax']]
    layer = Image.new('RGBA', image.size, (0, 0, 0, 0))
    painter = ImageDraw.Draw(layer)
    painter.rectangle(corners, fill=base_rgb + (100,))
    painter.rectangle(corners, outline=base_rgb + (255,), width=2)
    return layer
def draw_boxes(image, predictions):
    """Render detection boxes and captions onto a copy of *image*.

    Each prediction gets a confidence-coloured overlay (from
    ``create_heatmap_overlay``) plus a white caption on a dark backing
    rectangle placed just above the box.

    Args:
        image: source PIL image; it is copied, not modified.
        predictions: iterable of dicts with ``box`` (``xmin``/``ymin``/``xmax``/
            ``ymax``), ``score`` and ``label`` keys.

    Returns:
        A new RGBA image with every prediction drawn.
    """
    label_offset = 20  # caption is placed this many pixels above the box
    result_image = image.copy().convert('RGBA')
    for pred in predictions:
        box = pred['box']
        score = pred['score']
        overlay = create_heatmap_overlay(image, box, score)
        # alpha_composite returns a new image, so the Draw handle has to be
        # re-created after each composite.
        result_image = Image.alpha_composite(result_image, overlay)
        draw = ImageDraw.Draw(result_image)
        # NOTE(review): this "temperature" is derived only from the confidence
        # score (36.5 + 2.5 * score); it is not measured from the image and may
        # mislead readers of a medical report — consider removing it.
        temp = 36.5 + (score * 2.5)
        label = f"{translate_label(pred['label'])} ({score:.1%} • {temp:.1f}°C)"
        # Clamp to the top edge so captions of boxes near y=0 stay visible.
        text_xy = (box['xmin'], max(0, box['ymin'] - label_offset))
        try:
            text_bbox = draw.textbbox(text_xy, label)
        except AttributeError:
            # Pillow < 8.0 lacks textbbox: estimate the default font's extent,
            # anchored at the same origin the text is drawn at.
            font_size = 10
            text_bbox = (
                text_xy[0],
                text_xy[1],
                text_xy[0] + len(label) * font_size * 0.6,
                text_xy[1] + font_size * 1.2,
            )
        draw.rectangle(text_bbox, fill=(0, 0, 0, 180))
        draw.text(text_xy, label, fill=(255, 255, 255, 255))
    return result_image
def image_to_base64(image):
    """Encode *image* as a PNG ``data:`` URI (usable as an ``<img src>``).

    Args:
        image: any object with a PIL-style ``save(fp, format=...)`` method.

    Returns:
        ``"data:image/png;base64,<payload>"``.
    """
    with io.BytesIO() as sink:
        image.save(sink, format="PNG")
        raw_png = sink.getvalue()
    return "data:image/png;base64," + base64.b64encode(raw_png).decode()
# Shared CSS interpolated into both the upload page and the error page.
# This is a plain (non-f) string, so its braces are single; the pages that
# embed it are f-strings and must double their own literal braces.
COMMON_STYLES = """
body {
font-family: system-ui, -apple-system, sans-serif;
background: #f0f2f5;
margin: 0;
padding: 20px;
color: #1a1a1a;
}
::-webkit-scrollbar {
width: 8px;
height: 8px;
}
::-webkit-scrollbar-track {
background: transparent;
}
::-webkit-scrollbar-thumb {
background-color: rgba(156, 163, 175, 0.5);
border-radius: 4px;
}
.container {
max-width: 1200px;
margin: 0 auto;
background: white;
padding: 20px;
border-radius: 10px;
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
}
.button {
background: #2d2d2d;
color: white;
border: none;
padding: 12px 30px;
border-radius: 8px;
cursor: pointer;
font-size: 1.1em;
transition: all 0.3s ease;
position: relative;
}
.button:hover {
background: #404040;
}
@keyframes progress {
0% { width: 0; }
100% { width: 100%; }
}
.button-progress {
position: absolute;
bottom: 0;
left: 0;
height: 4px;
background: rgba(255, 255, 255, 0.5);
width: 0;
}
.button:active .button-progress {
animation: progress 2s linear forwards;
}
img {
max-width: 100%;
height: auto;
border-radius: 8px;
}
@keyframes blink {
0% { opacity: 1; }
50% { opacity: 0; }
100% { opacity: 1; }
}
#loading {
display: none;
color: white;
margin-top: 10px;
animation: blink 1s infinite;
text-align: center;
}
"""
@app.get("/", response_class=HTMLResponse)
async def main():
    """Serve the upload form: patient name, X-ray image and confidence threshold.

    Registered on ``app`` as ``GET /``. Without the decorator the server had no
    index route. ``response_class=HTMLResponse`` makes FastAPI send the returned
    str as HTML rather than JSON-encoding it. The form POSTs multipart data to
    ``/analyze``. Literal CSS braces are doubled because the page is an f-string.

    Returns:
        The HTML document as a string.
    """
    content = f"""
<!DOCTYPE html>
<html>
<head>
<title>Fracture Detection</title>
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<style>
{COMMON_STYLES}
.upload-section {{
background: #2d2d2d;
padding: 40px;
border-radius: 12px;
margin: 20px 0;
text-align: center;
border: 2px dashed #404040;
transition: all 0.3s ease;
color: white;
}}
.upload-section:hover {{
border-color: #555;
}}
input[type="file"] {{
font-size: 1.1em;
margin: 20px 0;
color: white;
}}
input[type="file"]::file-selector-button {{
font-size: 1em;
padding: 10px 20px;
border-radius: 8px;
border: 1px solid #404040;
background: #2d2d2d;
color: white;
transition: all 0.3s ease;
cursor: pointer;
}}
input[type="file"]::file-selector-button:hover {{
background: #404040;
}}
.confidence-slider {{
width: 100%;
max-width: 300px;
margin: 20px auto;
}}
input[type="range"] {{
width: 100%;
height: 8px;
border-radius: 4px;
background: #404040;
outline: none;
transition: all 0.3s ease;
-webkit-appearance: none;
}}
input[type="range"]::-webkit-slider-thumb {{
-webkit-appearance: none;
width: 20px;
height: 20px;
border-radius: 50%;
background: white;
cursor: pointer;
border: none;
}}
.input-field {{
margin-bottom: 20px;
}}
.input-field label {{
display: block;
margin-bottom: 5px;
font-size: 1.1em;
}}
.input-field input[type="text"] {{
width: calc(100% - 20px);
padding: 10px;
border-radius: 5px;
border: 1px solid #ccc;
background: #fff;
color: #1a1a1a;
font-size: 1em;
}}
</style>
</head>
<body>
<div class="container">
<div class="upload-section">
<form action="/analyze" method="post" enctype="multipart/form-data" onsubmit="document.getElementById('loading').style.display = 'block';">
<div class="input-field">
<label for="patient_name">Patient Name:</label>
<input type="text" id="patient_name" name="patient_name" required>
</div>
<div>
<input type="file" name="file" accept="image/*" required>
</div>
<div class="confidence-slider">
<label for="threshold">Confidence Threshold: <span id="thresholdValue">0.60</span></label>
<input type="range" id="threshold" name="threshold"
min="0" max="1" step="0.05" value="0.60"
oninput="document.getElementById('thresholdValue').textContent = parseFloat(this.value).toFixed(2)">
</div>
<button type="submit" class="button">
Analyze & Generate PDF
<div class="button-progress"></div>
</button>
<div id="loading">Loading...</div>
</form>
</div>
</div>
</body>
</html>
"""
    return content
@app.post("/analyze")
async def analyze_file(patient_name: str = Form(...), file: UploadFile = File(...), threshold: float = Form(0.6)):
    """Run the three models on an uploaded X-ray and return a PDF report.

    Registered on ``app`` as ``POST /analyze``, which is the target of the
    upload form.

    Args:
        patient_name: free-text name from the form. It is printed (escaped) in
            the report and used, sanitised, in the download filename.
        file: uploaded image. It is decoded with Pillow and converted to RGB.
        threshold: minimum ``score`` an object-detection hit needs in order to
            be boxed on the image and listed in the report.

    Returns:
        A ``StreamingResponse`` carrying the PDF as an attachment, or an HTML
        error page with status 500 if any step raises.
    """
    try:
        contents = await file.read()
        image = Image.open(io.BytesIO(contents)).convert("RGB")  # Ensure RGB for PDF
        # NOTE(review): inference runs synchronously inside this coroutine and
        # blocks the event loop until all three pipelines have finished.
        predictions_watcher = models["KnochenWächter"](image)
        predictions_master = models["RöntgenMeister"](image)
        predictions_locator = models["KnochenAuge"](image)
        filtered_preds = [p for p in predictions_locator if p['score'] >= threshold]
        result_image = draw_boxes(image, filtered_preds) if filtered_preds else image
        buffer = _build_report_pdf(
            patient_name, predictions_watcher, predictions_master, filtered_preds, result_image
        )
        return StreamingResponse(
            buffer,
            media_type="application/pdf",
            headers={"Content-Disposition": _report_content_disposition(patient_name)},
        )
    except Exception as e:
        # Top-level request boundary: log the traceback server-side (it was
        # previously discarded), then show the user an error page.
        logging.getLogger(__name__).exception("Fracture analysis failed")
        return _render_error_page(str(e))


def _pdf_text(value):
    """Escape ``&``, ``<`` and ``>`` so *value* renders literally in a ReportLab Paragraph.

    Paragraph text is parsed as XML-like markup. An unescaped ``<`` or ``&`` in
    user input would otherwise break the markup parser or inject formatting
    tags.
    """
    return xml_escape(str(value))


def _report_content_disposition(patient_name):
    """Build the ``Content-Disposition`` header value for the report download.

    Spaces become underscores, as before. In the plain ``filename`` parameter,
    every run of characters other than ASCII letters, digits, ``.``, ``_`` and
    ``-`` is replaced by ``_``. This stops quotes, semicolons or line breaks in
    the name from corrupting the header, and stops non-Latin-1 names from
    failing header encoding. The full UTF-8 name is sent percent-encoded in
    ``filename*`` (RFC 6266 / RFC 5987).
    """
    stem = f"Fracture_Report_{patient_name.replace(' ', '_')}"
    ascii_stem = re.sub(r"[^A-Za-z0-9._-]+", "_", stem)
    encoded = quote(f"{stem}.pdf", safe="")
    return f"attachment; filename=\"{ascii_stem}.pdf\"; filename*=UTF-8''{encoded}"


def _fit_image(flowable, max_width, max_height):
    """Shrink a ReportLab image flowable proportionally to fit max_width x max_height.

    Images that already fit are left at their natural size.
    """
    width, height = flowable.drawWidth, flowable.drawHeight
    scale = min(1.0, max_width / width, max_height / height)
    flowable.drawWidth = width * scale
    flowable.drawHeight = height * scale


def _build_report_pdf(patient_name, predictions_watcher, predictions_master,
                      filtered_preds, result_image):
    """Lay out the fracture report and return it as a rewound ``BytesIO``.

    Args:
        patient_name: raw patient name. It is escaped here before use in markup.
        predictions_watcher: KnochenWächter classifier output (dicts with
            ``label`` and ``score``).
        predictions_master: RöntgenMeister classifier output (same shape).
        filtered_preds: object-detection hits that passed the threshold.
        result_image: PIL image to embed (annotated when hits exist).
    """
    buffer = io.BytesIO()
    doc = SimpleDocTemplate(buffer, pagesize=letter)
    styles = getSampleStyleSheet()
    centered_style = ParagraphStyle(
        name='Centered',
        parent=styles['Normal'],
        alignment=TA_CENTER,
        fontSize=12,
        leading=14
    )
    heading_style = ParagraphStyle(
        name='Heading',
        parent=styles['h1'],
        alignment=TA_CENTER,
        fontSize=24,
        spaceAfter=20
    )
    subheading_style = ParagraphStyle(
        name='SubHeading',
        parent=styles['h2'],
        alignment=TA_CENTER,
        fontSize=16,
        spaceAfter=10
    )
    report_text_style = ParagraphStyle(
        name='ReportText',
        parent=styles['Normal'],
        alignment=TA_CENTER,
        fontSize=12,
        spaceAfter=5
    )
    story = []

    def add_classifier_section(title, predictions):
        # One heading plus one "label: score" line per classifier prediction.
        story.append(Paragraph(f"<b>{title} Results:</b>", subheading_style))
        for pred in predictions:
            story.append(Paragraph(
                f"{_pdf_text(translate_label(pred['label']))}: {pred['score']:.1%}",
                report_text_style
            ))

    story.append(Paragraph("<b>Fracture Detection Report</b>", heading_style))
    story.append(Spacer(1, 0.2 * inch))
    story.append(Paragraph(f"<b>Patient Name:</b> {_pdf_text(patient_name)}", subheading_style))
    story.append(Spacer(1, 0.4 * inch))
    add_classifier_section("KnochenWächter", predictions_watcher)
    story.append(Spacer(1, 0.2 * inch))
    add_classifier_section("RöntgenMeister", predictions_master)
    story.append(Spacer(1, 0.4 * inch))

    # Analyzed image
    story.append(Paragraph("<b>X-ray Image Analysis:</b>", subheading_style))
    img_buffer = io.BytesIO()
    result_image.save(img_buffer, format="PNG")
    img_buffer.seek(0)
    img_rl = ReportLabImage(img_buffer)
    # Cap height as well as width. A tall (portrait) X-ray that is scaled only
    # by width can be taller than the page frame, which makes doc.build fail.
    _fit_image(img_rl, max_width=5 * inch, max_height=6 * inch)
    img_rl.hAlign = 'CENTER'
    story.append(img_rl)
    story.append(Spacer(1, 0.4 * inch))

    # Final report text based on object detection
    if filtered_preds:
        story.append(Paragraph(
            "<b>The X-ray image analysis shows potential fracture localization.</b>",
            report_text_style
        ))
        for pred in filtered_preds:
            score = pred['score']
            # NOTE(review): derived from the score only, not measured; see draw_boxes.
            temp = 36.5 + (score * 2.5)
            story.append(Paragraph(
                f"Detection: {_pdf_text(translate_label(pred['label']))} with {score:.1%} confidence ({temp:.1f}°C)",
                report_text_style
            ))
    else:
        story.append(Paragraph(
            "<b>Based on object localization analysis, no fracture was detected with sufficient confidence.</b>",
            report_text_style
        ))
    story.append(Spacer(1, 0.2 * inch))
    story.append(Paragraph("This is an automatically generated report and should be reviewed by a medical professional.", centered_style))
    doc.build(story)
    buffer.seek(0)
    return buffer


def _render_error_page(message):
    """Return an HTML error page showing *message*, HTML-escaped, with status 500.

    The message comes from an exception and may contain markup-like text, for
    example Pillow's ``<_io.BytesIO ...>`` repr. It must not be interpolated as
    raw HTML.
    """
    return HTMLResponse(f"""
<!DOCTYPE html>
<html>
<head>
<title>Error</title>
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<style>
{COMMON_STYLES}
.error-box {{
background: #fee2e2;
border: 1px solid #ef4444;
padding: 20px;
border-radius: 8px;
margin: 20px 0;
}}
</style>
</head>
<body>
<div class="container">
<div class="error-box">
<h3>Error</h3>
<p>{html.escape(message)}</p>
</div>
<a href="/" class="button back-button">
← Back
<div class="button-progress"></div>
</a>
</div>
</body>
</html>
""", status_code=500)
if __name__ == "__main__":
    # Script entry point: serve `app` on all interfaces (0.0.0.0), not just
    # localhost. NOTE(review): 7860 is presumably the port the hosting platform
    # (e.g. a Hugging Face Space) forwards — confirm before changing it.
    uvicorn.run(app, port=7860, host="0.0.0.0")