Satyam-Singh commited on
Commit
9da3d7c
·
verified ·
1 Parent(s): 7d6648f

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +181 -0
app.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import io
3
+ import base64
4
+ import json
5
+ import datetime
6
+ import importlib.util
7
+ from typing import List, Optional
8
+
9
+ from fastapi import FastAPI, Request
10
+ from fastapi.responses import JSONResponse, FileResponse, HTMLResponse
11
+ from fastapi.staticfiles import StaticFiles
12
+ from pydantic import BaseModel
13
+ from PIL import Image
14
+
15
+
16
+ def _load_rmmm_module():
17
+ """Dynamically load the project's RMMM module so we can reuse model
18
+ loading and inference code from `XRaySwinGen-RMMM/app.py`.
19
+ """
20
+ root = os.path.dirname(__file__)
21
+ module_path = os.path.normpath(os.path.join(root, "XRaySwinGen-RMMM", "app.py"))
22
+ spec = importlib.util.spec_from_file_location("rmmm_module", module_path)
23
+ module = importlib.util.module_from_spec(spec)
24
+ spec.loader.exec_module(module)
25
+ return module
26
+
27
+
28
+ # Load module at startup (this will run top-level initialization in that file)
29
+ RMMM = _load_rmmm_module()
30
+
31
+
32
+ app = FastAPI(title="XRaySwinGen - RMMM API for Static Frontend")
33
+
34
+ # Serve the static frontend contained in XRaySwinGen-RMMM/frontend
35
+ frontend_dir = os.path.normpath(os.path.join(os.path.dirname(__file__), "XRaySwinGen-RMMM", "frontend"))
36
+ app.mount("/", StaticFiles(directory=frontend_dir, html=True), name="frontend")
37
+
38
+
39
+ class ReportRequest(BaseModel):
40
+ image_base64: Optional[str] = None
41
+ filename: Optional[str] = None
42
+ source_id: Optional[str] = None
43
+
44
+
45
+ def pil_to_data_url(img: Image.Image, fmt: str = "PNG") -> str:
46
+ buf = io.BytesIO()
47
+ img.save(buf, format=fmt)
48
+ b = base64.b64encode(buf.getvalue()).decode("ascii")
49
+ return f"data:image/{fmt.lower()};base64,{b}"
50
+
51
+
52
+ @app.get("/api/v1/samples")
53
+ async def get_samples(limit: int = 24):
54
+ """Return a list of sample images (base64) found in XRaySwinGen-RMMM/images"""
55
+ images_dir = os.path.normpath(os.path.join(os.path.dirname(__file__), "XRaySwinGen-RMMM", "images"))
56
+ items = []
57
+ if not os.path.isdir(images_dir):
58
+ return JSONResponse({"items": []})
59
+
60
+ files = [f for f in os.listdir(images_dir) if f.lower().endswith((".jpg", ".jpeg", ".png"))]
61
+ files = sorted(files)[:limit]
62
+ for fname in files:
63
+ path = os.path.join(images_dir, fname)
64
+ try:
65
+ with Image.open(path) as im:
66
+ data_url = pil_to_data_url(im.convert("RGB"), fmt="PNG")
67
+ item = {
68
+ "id": os.path.splitext(fname)[0],
69
+ "title": os.path.splitext(fname)[0],
70
+ "image_base64": data_url,
71
+ }
72
+ items.append(item)
73
+ except Exception:
74
+ continue
75
+
76
+ return JSONResponse({"items": items})
77
+
78
+
79
+ @app.post("/api/v1/report")
80
+ async def create_report(req: ReportRequest):
81
+ """Generate a report from an uploaded image or a sample. The frontend
82
+ expects a JSON response containing fields used by the static UI.
83
+ """
84
+ try:
85
+ # Determine image source
86
+ pil_img = None
87
+ selected_image_path = ""
88
+
89
+ if req.source_id:
90
+ # Try to find a matching file in images/ by id
91
+ images_dir = os.path.normpath(os.path.join(os.path.dirname(__file__), "XRaySwinGen-RMMM", "images"))
92
+ candidates = [f for f in os.listdir(images_dir) if os.path.splitext(f)[0] == req.source_id]
93
+ if candidates:
94
+ selected_image_path = os.path.join(images_dir, candidates[0])
95
+ pil_img = Image.open(selected_image_path).convert("RGB")
96
+
97
+ if req.image_base64 and not pil_img:
98
+ # image_base64 may be a data URL
99
+ data = req.image_base64
100
+ if "," in data:
101
+ data = data.split(",", 1)[1]
102
+ raw = base64.b64decode(data)
103
+ pil_img = Image.open(io.BytesIO(raw)).convert("RGB")
104
+
105
+ if pil_img is None:
106
+ return JSONResponse({"error": "No valid image provided"}, status_code=400)
107
+
108
+ # Run inference using functions from the loaded module
109
+ ai_report = None
110
+ ground_truth = ""
111
+ metrics_html = ""
112
+
113
+ if hasattr(RMMM, "inference_image_pipe_with_state"):
114
+ ai_report, ground_truth, metrics_html = RMMM.inference_image_pipe_with_state(pil_img, selected_image_path)
115
+ else:
116
+ ai_report = RMMM.inference_torch_model_fast(pil_img)
117
+ if hasattr(RMMM, "get_ground_truth_from_filename"):
118
+ ground_truth = RMMM.get_ground_truth_from_filename(selected_image_path)
119
+
120
+ # Annotated image
121
+ annotated_pil = None
122
+ if hasattr(RMMM, "annotate_image"):
123
+ try:
124
+ annotated_pil = RMMM.annotate_image(pil_img.copy(), ai_report or "")
125
+ except Exception:
126
+ annotated_pil = pil_img
127
+ else:
128
+ annotated_pil = pil_img
129
+
130
+ # Explanation
131
+ explanation = ""
132
+ detailed = ""
133
+ step_html = ""
134
+ if hasattr(RMMM, "explain_findings"):
135
+ try:
136
+ explanation, detailed, step_html = RMMM.explain_findings(ai_report or "", ground_truth or "")
137
+ except Exception:
138
+ explanation = ""
139
+
140
+ # Compute numeric metrics if possible
141
+ numeric_metrics = None
142
+ if hasattr(RMMM, "calculate_evaluation_metrics") and ground_truth:
143
+ try:
144
+ numeric_metrics = RMMM.calculate_evaluation_metrics(ai_report or "", ground_truth or "")
145
+ numeric_metrics = {k: v for k, v in numeric_metrics.items() if k in ("bleu4_score", "rougeL_f")}
146
+ except Exception:
147
+ numeric_metrics = None
148
+
149
+ # Build small findings/insights from explanation (best-effort)
150
+ insights = []
151
+ findings = {}
152
+ if explanation:
153
+ for line in explanation.split("\n")[:6]:
154
+ if line.strip():
155
+ insights.append(line.strip())
156
+
157
+ # Convert annotated image to data URL
158
+ annotated_data_url = pil_to_data_url(annotated_pil, fmt="PNG")
159
+
160
+ now = datetime.datetime.utcnow().isoformat() + "Z"
161
+
162
+ response = {
163
+ "created_at": now,
164
+ "report_text": ai_report or "",
165
+ "annotated_image": annotated_data_url,
166
+ "insights": insights,
167
+ "findings": findings,
168
+ "metrics": numeric_metrics,
169
+ "status_chain": [
170
+ {"title": "Uploaded", "detail": req.filename or req.source_id or "uploaded image"},
171
+ {"title": "Preprocessing", "detail": "Resizing, normalization"},
172
+ {"title": "Inference", "detail": "RMMM model inference executed"},
173
+ {"title": "Postprocessing", "detail": "Decoding and annotations"},
174
+ ],
175
+ "source_id": req.source_id or os.path.splitext(req.filename or "")[0],
176
+ }
177
+
178
+ return JSONResponse(response)
179
+
180
+ except Exception as e:
181
+ return JSONResponse({"error": str(e)}, status_code=500)