Yash goyal commited on
Commit
d65b1c4
·
verified ·
1 Parent(s): b5cfaf3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +130 -165
app.py CHANGED
@@ -1,26 +1,26 @@
1
- from flask import Flask, render_template, request, redirect, url_for, session, send_file
2
  import tensorflow as tf
3
  import numpy as np
4
  from PIL import Image
 
5
  import pickle
6
- import io
7
  import os
8
- import matplotlib.pyplot as plt
9
  from reportlab.lib.pagesizes import A4
10
  from reportlab.lib import colors
11
- from reportlab.pdfgen import canvas
12
- from reportlab.lib.units import inch
13
  from datetime import datetime
14
  import logging
15
 
16
  app = Flask(__name__)
17
- app.secret_key = "e3f6f40bb8b2471b9f07c4025d845be9" # Replace with secure key if needed
18
 
19
- # Paths
20
  MODEL_PATH = "skin_lesion_model.h5"
21
  HISTORY_PATH = "training_history.pkl"
22
  PLOT_PATH = "/tmp/static/training_plot.png"
23
- LOGO_PATH = "static/logo.jpg" # Logo in static folder
 
24
  IMG_SIZE = (224, 224)
25
  CONFIDENCE_THRESHOLD = 0.30
26
 
@@ -35,183 +35,148 @@ label_map = {
35
  7: "Squamous cell carcinoma"
36
  }
37
 
38
- # Logging setup
39
  logging.basicConfig(level=logging.INFO)
40
  logger = logging.getLogger(__name__)
41
 
42
  # Load model
43
- try:
44
- logger.info("Loading model from %s", MODEL_PATH)
45
- model = tf.keras.models.load_model(MODEL_PATH)
46
- except Exception as e:
47
- logger.error("Failed to load model: %s", str(e))
48
- raise
49
-
50
- # Load training history and generate plot
51
- history_dict = {}
52
  if os.path.exists(HISTORY_PATH):
53
- try:
54
- with open(HISTORY_PATH, "rb") as f:
55
- history_dict = pickle.load(f)
 
56
  os.makedirs("/tmp/static", exist_ok=True)
57
- if "accuracy" in history_dict and "val_accuracy" in history_dict:
58
- plt.plot(history_dict['accuracy'], label='Train Accuracy')
59
- plt.plot(history_dict['val_accuracy'], label='Val Accuracy')
60
- plt.xlabel('Epochs')
61
- plt.ylabel('Accuracy')
62
- plt.title('Training History')
63
- plt.legend()
64
- plt.grid(True)
65
- plt.savefig(PLOT_PATH)
66
- plt.close()
67
- logger.info("Training plot saved at %s", PLOT_PATH)
68
- except Exception as e:
69
- logger.error("Failed to process training history: %s", str(e))
70
 
71
  def preprocess_image(image_bytes):
72
- try:
73
- image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
74
- image = image.resize(IMG_SIZE)
75
- image_array = tf.keras.utils.img_to_array(image)
76
- image_array = np.expand_dims(image_array, axis=0)
77
- return image_array / 255.0
78
- except Exception as e:
79
- logger.error("Image preprocessing failed: %s", str(e))
80
- raise
81
-
82
- def generate_pdf(report_data, filepath):
83
- c = canvas.Canvas(filepath, pagesize=A4)
84
- width, height = A4
85
-
86
- # Add logo if exists
87
- try:
88
- if os.path.exists(LOGO_PATH):
89
- c.drawImage(LOGO_PATH, 50, height - 100, width=80, preserveAspectRatio=True, mask='auto')
90
- except Exception as e:
91
- logger.warning("Could not load logo: %s", str(e))
92
 
93
- # Title
94
- c.setFillColor(colors.HexColor("#007ACC"))
95
- c.setFont("Helvetica-Bold", 20)
96
- c.drawCentredString(width / 2, height - 80, "Skin Lesion Diagnosis Report")
97
- c.setStrokeColor(colors.HexColor("#007ACC"))
98
- c.setLineWidth(2)
99
- c.line(60, height - 90, width - 60, height - 90)
100
-
101
- # Info box background
102
- c.setFillColor(colors.lightgrey)
103
- c.rect(50, height - 250, width - 100, 140, fill=1, stroke=0)
104
-
105
- # Patient Info
106
- c.setFillColor(colors.black)
107
- c.setFont("Helvetica-Bold", 12)
108
- y = height - 120
109
- spacing = 20
110
-
111
- def draw_field(label, value):
112
- nonlocal y
113
- c.setFont("Helvetica-Bold", 12)
114
- c.drawString(70, y, f"{label}:")
115
- c.setFont("Helvetica", 12)
116
- c.drawString(180, y, value)
117
- y -= spacing
118
-
119
- draw_field("Full Name", report_data.get("name", "N/A"))
120
- draw_field("Email", report_data.get("email", "N/A"))
121
- draw_field("Gender", report_data.get("gender", "N/A"))
122
- draw_field("Age", str(report_data.get("age", "N/A")))
123
-
124
- # Prediction
125
- y -= 20
126
- c.setFont("Helvetica-Bold", 14)
127
- c.setFillColor(colors.HexColor("#007ACC"))
128
- c.drawString(50, y, "AI Diagnosis Result")
129
- c.setFillColor(colors.black)
130
- y -= spacing
131
- draw_field("Prediction", report_data.get("prediction", "N/A"))
132
- draw_field("Confidence", report_data.get("confidence", "N/A"))
133
-
134
- # Optional message
135
- message = report_data.get("message", "")
136
- if message:
137
- y -= 10
138
- c.setFont("Helvetica-Oblique", 11)
139
- c.setFillColor(colors.red)
140
- c.drawString(70, y, message)
141
-
142
- # Timestamp
143
- y -= 40
144
- c.setFont("Helvetica", 10)
145
- c.setFillColor(colors.grey)
146
- c.drawString(50, y, f"Generated on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
147
-
148
- c.save()
149
-
150
- @app.route("/form", methods=["GET"])
151
  def form():
152
- return render_template("form.html", history_plot="/training_plot.png")
 
 
153
 
154
  @app.route("/training_plot.png")
155
  def training_plot():
156
- return send_file(PLOT_PATH, mimetype="image/png")
157
 
158
  @app.route("/predict", methods=["POST"])
159
  def predict():
160
- try:
161
- if "image" not in request.files:
162
- raise ValueError("⚠ No image uploaded.")
163
-
164
- image = request.files["image"].read()
165
- img_array = preprocess_image(image)
166
- prediction = model.predict(img_array)[0]
167
- predicted_index = int(np.argmax(prediction))
168
- confidence = float(prediction[predicted_index])
169
-
170
- name = request.form.get("name")
171
- email = request.form.get("email")
172
- gender = request.form.get("gender")
173
- age = request.form.get("age")
174
-
175
- if confidence < CONFIDENCE_THRESHOLD:
176
- pred_label = "Low confidence"
177
- msg = "⚠ This image is not confidently recognized. Please upload a clearer image."
178
- else:
179
- pred_label = label_map.get(predicted_index, "Unknown")
180
- msg = ""
181
-
182
- session["report"] = {
183
- "name": name,
184
- "email": email,
185
- "gender": gender,
186
- "age": age,
187
- "prediction": pred_label,
188
  "confidence": f"{confidence * 100:.2f}%",
189
- "message": msg
 
 
 
 
 
190
  }
191
 
192
- return redirect(url_for("result"))
193
- except Exception as e:
194
- logger.error("Prediction error: %s", str(e))
195
- return render_template("form.html", history_plot="/training_plot.png", result={
196
- "prediction": "Error",
197
- "confidence": "N/A",
198
- "message": f"An error occurred: {str(e)}"
199
- })
200
-
201
- @app.route("/result")
202
- def result():
203
- report = session.get("report", {})
204
- return render_template("result.html", **report)
205
-
206
- @app.route("/download-report")
207
- def download_report():
208
- report = session.get("report", {})
209
- if not report:
210
  return redirect(url_for("form"))
211
- os.makedirs("/tmp/reports", exist_ok=True)
212
- filepath = "/tmp/reports/report.pdf"
213
- generate_pdf(report, filepath)
214
- return send_file(filepath, as_attachment=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
215
 
216
  if __name__ == "__main__":
217
  app.run(host="0.0.0.0", port=7860)
 
1
+ from flask import Flask, render_template, request, send_file, redirect, url_for, session
2
  import tensorflow as tf
3
  import numpy as np
4
  from PIL import Image
5
+ import matplotlib.pyplot as plt
6
  import pickle
 
7
  import os
8
+ import io
9
  from reportlab.lib.pagesizes import A4
10
  from reportlab.lib import colors
11
+ from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle
12
+ from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer, Table, TableStyle, Image as RLImage
13
  from datetime import datetime
14
  import logging
15
 
16
  app = Flask(__name__)
17
+ app.secret_key = "your_secret_key_here" # Replace with a random string
18
 
 
19
  MODEL_PATH = "skin_lesion_model.h5"
20
  HISTORY_PATH = "training_history.pkl"
21
  PLOT_PATH = "/tmp/static/training_plot.png"
22
+ LOGO_PATH = "static/logo.jpg"
23
+
24
  IMG_SIZE = (224, 224)
25
  CONFIDENCE_THRESHOLD = 0.30
26
 
 
35
  7: "Squamous cell carcinoma"
36
  }
37
 
 
38
  logging.basicConfig(level=logging.INFO)
39
  logger = logging.getLogger(__name__)
40
 
41
  # Load model
42
+ logger.info("Loading model from %s", MODEL_PATH)
43
+ model = tf.keras.models.load_model(MODEL_PATH)
44
+
45
+ # Load and plot training history
 
 
 
 
 
46
  if os.path.exists(HISTORY_PATH):
47
+ with open(HISTORY_PATH, "rb") as f:
48
+ history_dict = pickle.load(f)
49
+
50
+ if "accuracy" in history_dict:
51
  os.makedirs("/tmp/static", exist_ok=True)
52
+ plt.plot(history_dict['accuracy'], label='Train Accuracy')
53
+ plt.plot(history_dict.get('val_accuracy', []), label='Val Accuracy')
54
+ plt.xlabel("Epoch")
55
+ plt.ylabel("Accuracy")
56
+ plt.title("Model Training History")
57
+ plt.legend()
58
+ plt.grid(True)
59
+ plt.savefig(PLOT_PATH)
60
+ plt.close()
61
+ logger.info("Training plot saved at %s", PLOT_PATH)
 
 
 
62
 
63
  def preprocess_image(image_bytes):
64
+ image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
65
+ image = image.resize(IMG_SIZE)
66
+ image_array = tf.keras.utils.img_to_array(image)
67
+ image_array = np.expand_dims(image_array, axis=0)
68
+ return image_array / 255.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
 
70
+ @app.route("/form")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  def form():
72
+ result = session.pop("result", None)
73
+ patient = session.pop("patient", None)
74
+ return render_template("form.html", history_plot="/training_plot.png", result=result, patient=patient)
75
 
76
  @app.route("/training_plot.png")
77
  def training_plot():
78
+ return send_file(PLOT_PATH, mimetype='image/png')
79
 
80
  @app.route("/predict", methods=["POST"])
81
  def predict():
82
+ if "image" not in request.files:
83
+ return redirect(url_for("form"))
84
+
85
+ name = request.form.get("name")
86
+ email = request.form.get("email")
87
+ gender = request.form.get("gender")
88
+ age = request.form.get("age")
89
+ image = request.files["image"].read()
90
+
91
+ img_array = preprocess_image(image)
92
+ prediction = model.predict(img_array)[0]
93
+ predicted_index = int(np.argmax(prediction))
94
+ confidence = float(prediction[predicted_index])
95
+
96
+ if confidence < CONFIDENCE_THRESHOLD:
97
+ result = {
98
+ "prediction": "Low confidence",
 
 
 
 
 
 
 
 
 
 
 
99
  "confidence": f"{confidence * 100:.2f}%",
100
+ "message": "⚠️ Image not confidently classified. Try uploading a clearer image."
101
+ }
102
+ else:
103
+ result = {
104
+ "prediction": label_map.get(predicted_index, "Unknown"),
105
+ "confidence": f"{confidence * 100:.2f}%"
106
  }
107
 
108
+ session["result"] = result
109
+ session["patient"] = {"name": name, "email": email, "gender": gender, "age": age}
110
+ return redirect(url_for("form"))
111
+
112
+ @app.route("/download-pdf")
113
+ def download_pdf():
114
+ patient = session.get("patient")
115
+ result = session.get("result")
116
+ if not patient or not result:
 
 
 
 
 
 
 
 
 
117
  return redirect(url_for("form"))
118
+
119
+ buffer = io.BytesIO()
120
+ doc = SimpleDocTemplate(buffer, pagesize=A4)
121
+
122
+ elements = []
123
+ styles = getSampleStyleSheet()
124
+ styles.add(ParagraphStyle(name="Title", fontSize=20, textColor=colors.HexColor("#007acc"), spaceAfter=16, alignment=1))
125
+ styles.add(ParagraphStyle(name="SectionHeader", fontSize=14, textColor=colors.HexColor("#007acc"), spaceBefore=10, spaceAfter=10))
126
+ styles.add(ParagraphStyle(name="NormalBold", fontSize=12, leading=14, spaceAfter=6, fontName='Helvetica-Bold'))
127
+
128
+ # Logo
129
+ if os.path.exists(LOGO_PATH):
130
+ elements.append(RLImage(LOGO_PATH, width=100, height=50))
131
+ elements.append(Spacer(1, 12))
132
+
133
+ # Title
134
+ elements.append(Paragraph("Skin Lesion Diagnosis Report", styles["Title"]))
135
+ elements.append(Spacer(1, 6))
136
+
137
+ # Patient Info Table
138
+ patient_data = [
139
+ ["Full Name:", patient["name"]],
140
+ ["Email:", patient["email"]],
141
+ ["Gender:", patient["gender"]],
142
+ ["Age:", patient["age"]],
143
+ ]
144
+ patient_table = Table(patient_data, colWidths=[100, 300])
145
+ patient_table.setStyle(TableStyle([
146
+ ("BACKGROUND", (0, 0), (-1, -1), colors.whitesmoke),
147
+ ("FONTNAME", (0, 0), (-1, -1), "Helvetica"),
148
+ ("FONTSIZE", (0, 0), (-1, -1), 11),
149
+ ("BOTTOMPADDING", (0, 0), (-1, -1), 8),
150
+ ("ROWBACKGROUNDS", (0, 0), (-1, -1), [colors.lightgrey, colors.whitesmoke]),
151
+ ]))
152
+ elements.append(Paragraph("Patient Information", styles["SectionHeader"]))
153
+ elements.append(patient_table)
154
+
155
+ # Prediction Info Table
156
+ result_data = [
157
+ ["Prediction:", result["prediction"]],
158
+ ["Confidence:", result["confidence"]],
159
+ ]
160
+ result_table = Table(result_data, colWidths=[100, 300])
161
+ result_table.setStyle(TableStyle([
162
+ ("BACKGROUND", (0, 0), (-1, -1), colors.whitesmoke),
163
+ ("FONTNAME", (0, 0), (-1, -1), "Helvetica"),
164
+ ("FONTSIZE", (0, 0), (-1, -1), 11),
165
+ ("BOTTOMPADDING", (0, 0), (-1, -1), 8),
166
+ ("ROWBACKGROUNDS", (0, 0), (-1, -1), [colors.lightgrey, colors.whitesmoke]),
167
+ ]))
168
+ elements.append(Spacer(1, 16))
169
+ elements.append(Paragraph("AI Diagnosis Result", styles["SectionHeader"]))
170
+ elements.append(result_table)
171
+
172
+ # Footer
173
+ elements.append(Spacer(1, 30))
174
+ date_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
175
+ elements.append(Paragraph(f"<font size='10'>Generated on: {date_str}</font>", styles["Normal"]))
176
+
177
+ doc.build(elements)
178
+ buffer.seek(0)
179
+ return send_file(buffer, as_attachment=True, download_name="diagnosis_report.pdf", mimetype="application/pdf")
180
 
181
  if __name__ == "__main__":
182
  app.run(host="0.0.0.0", port=7860)