# Dermo_AI / app.py
# Source: Hugging Face Space by mani880740255 ("Update app.py", commit 327b96a, verified)
from flask import Flask, request, jsonify, render_template
from flask_cors import CORS
import os, io, tempfile, requests
import numpy as np
from PIL import Image
import cv2
import tensorflow as tf
import torch
from torchvision import models, transforms
# =======================
# LLaMA CPP (CPU FAST)
# =======================
from llama_cpp import Llama
# =====================
# APP CONFIG
# =====================
# Flask app with CORS enabled so a browser frontend served from another
# origin can call the JSON endpoints.
app = Flask(__name__)
CORS(app)
# Uploaded images are staged here briefly and removed after inference.
UPLOAD_FOLDER = "uploads"
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
# Index order must match the output order of all three models below.
CLASS_LABELS = ["benign", "malignant", "normal"]
ALLOWED_EXT = {"jpg", "jpeg", "png"}
device = "cpu"  # all inference in this app runs on CPU
# =====================
# HUGGING FACE MODELS
# =====================
# Direct-download ("resolve") URLs for the three weight files. Each
# predict_* helper fetches its file on every request (no local cache).
HF_BASE = "https://huggingface.co/mani880740255/skin_care_tflite/resolve/main/"
HF_MODELS = {
    "tflite": HF_BASE + "skin_model_quantized.tflite",
    "mobilenetv2": HF_BASE + "skin_cancer_mobilenetv2%20(1).h5",
    "b3": HF_BASE + "efficientnet_b3_skin_cancer.pth"
}
# =====================
# TINYLLAMA GGUF CONFIG
# =====================
# Path to a local 2-bit quantized TinyLlama chat model in GGUF format
# (expected to ship alongside app.py in the Space repo).
LLM_PATH = "tinyllama-1.1b-chat-v1.0.Q2_K.gguf"
print("🔄 Loading TinyLlama GGUF (CPU)...")
# Loaded once at import time and reused by every /chat request.
llm = Llama(
    model_path=LLM_PATH,
    n_ctx=512,      # small context window keeps CPU inference fast
    n_threads=4,
    n_batch=128,
    verbose=False
)
print("✅ TinyLlama loaded")
# System prompt prepended to every chat turn: keeps the bot advisory-only
# (no diagnoses) and forces a dermatologist referral plus a disclaimer.
SYSTEM_PROMPT = (
    "You are a skin health assistant. "
    "Do not diagnose diseases. "
    "Explain in simple language. "
    "Give general precautions. "
    "Always recommend consulting a dermatologist. "
    "Add a medical disclaimer."
)
# =====================
# HELPERS
# =====================
def allowed_file(name):
    """Return True when *name* has an extension listed in ALLOWED_EXT."""
    if "." not in name:
        return False
    extension = name.rsplit(".", 1)[1].lower()
    return extension in ALLOWED_EXT
def download_file(url):
    """Download *url* and return the body as an in-memory stream.

    Args:
        url: Direct ("resolve") URL of a model file on Hugging Face.

    Returns:
        io.BytesIO positioned at the start of the downloaded bytes.

    Raises:
        Exception: if the server responds with a non-200 status.
        requests.RequestException: on network failure or timeout.
    """
    # A timeout is essential: without one a stalled download would hang
    # the request worker indefinitely.
    r = requests.get(url, timeout=120)
    if r.status_code != 200:
        raise Exception(f"Model download failed: {url}")
    return io.BytesIO(r.content)
# =====================
# IMAGE MODELS
# =====================
def predict_tflite(img_path):
    """Classify a skin image with the quantized TFLite model.

    Downloads the model from Hugging Face on every call (no caching),
    builds an interpreter, and runs one 224x224 RGB inference.

    Args:
        img_path: Path to a jpg/jpeg/png image on disk.

    Returns:
        (label, confidence, probabilities): label is one of CLASS_LABELS,
        confidence is the top-class score, probabilities is the full
        output vector as a list.

    Raises:
        ValueError: if the image cannot be decoded by OpenCV.
    """
    model_bytes = download_file(HF_MODELS["tflite"])
    interpreter = tf.lite.Interpreter(model_content=model_bytes.read())
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    # cv2.imread returns None (it does not raise) for unreadable files;
    # fail with a clear error instead of crashing inside cvtColor.
    img = cv2.imread(img_path)
    if img is None:
        raise ValueError(f"Could not read image: {img_path}")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, (224, 224))
    img = img.astype("float32") / 255.0
    img = np.expand_dims(img, axis=0)

    interpreter.set_tensor(input_details[0]["index"], img)
    interpreter.invoke()
    preds = interpreter.get_tensor(output_details[0]["index"])[0]
    idx = int(np.argmax(preds))
    return CLASS_LABELS[idx], float(preds[idx]), preds.tolist()
def predict_keras(img_path):
    """Classify a skin image with the MobileNetV2 Keras (.h5) model.

    The weights are fetched from Hugging Face and spooled to a temporary
    .h5 file (Keras loads from a real path), the model runs one
    224x224 RGB prediction, and the temp file is removed afterwards.
    Returns (label, confidence, probability_list).
    """
    weights = download_file(HF_MODELS["mobilenetv2"])
    with tempfile.NamedTemporaryFile(suffix=".h5", delete=False) as handle:
        handle.write(weights.read())
        weights_path = handle.name
    try:
        net = tf.keras.models.load_model(weights_path)

        # Preprocess: RGB, 224x224, scaled to [0, 1], batch of one.
        pil_img = Image.open(img_path).convert("RGB").resize((224, 224))
        batch = np.expand_dims(np.array(pil_img) / 255.0, axis=0)

        scores = net.predict(batch, verbose=0)[0]
        best = int(np.argmax(scores))
        return CLASS_LABELS[best], float(scores[best]), scores.tolist()
    finally:
        # Always clean up the temporary weights file.
        os.remove(weights_path)
def predict_b3(img_path):
    """Classify a skin image with the EfficientNet-B3 PyTorch model.

    Builds an uninitialized EfficientNet-B3 with a 3-way classifier
    head, loads fine-tuned weights downloaded from Hugging Face via a
    temporary .pth file, and runs one 300x300 inference on CPU.
    Returns (label, confidence, probability_list).
    """
    weights = download_file(HF_MODELS["b3"])
    net = models.efficientnet_b3(weights=None)
    net.classifier[1] = torch.nn.Linear(1536, 3)  # 3 output classes

    with tempfile.NamedTemporaryFile(suffix=".pth", delete=False) as handle:
        handle.write(weights.read())
        weights_path = handle.name
    try:
        net.load_state_dict(torch.load(weights_path, map_location="cpu"))
        net.eval()

        preprocess = transforms.Compose([
            transforms.Resize((300, 300)),
            transforms.ToTensor(),
        ])
        batch = preprocess(Image.open(img_path).convert("RGB")).unsqueeze(0)

        with torch.no_grad():
            logits = net(batch)
        scores = torch.softmax(logits, dim=1)[0]
        best = int(torch.argmax(scores))
        return CLASS_LABELS[best], float(scores[best]), scores.tolist()
    finally:
        # Always clean up the temporary weights file.
        os.remove(weights_path)
# =====================
# CHATBOT (FAST)
# =====================
def llm_chat_response(user_message, prediction=None, confidence=None):
    """Generate a short TinyLlama reply for a skin-health question.

    Args:
        user_message: The user's chat text.
        prediction: Optional class label from an earlier /predict call.
        confidence: Optional top-class probability in [0, 1].

    Returns:
        The model's reply text, stripped of surrounding whitespace.
    """
    context = ""
    # Explicit None checks: a legitimate confidence of 0.0 is falsy and
    # would otherwise silently drop the context line.
    if prediction is not None and confidence is not None:
        context = f"AI result: {prediction} ({confidence*100:.1f}%)."
    # TinyLlama-Chat prompt template: system / user / assistant markers.
    prompt = f"""
<|system|>
{SYSTEM_PROMPT}
{context}
<|user|>
{user_message}
<|assistant|>
"""
    output = llm(
        prompt,
        max_tokens=120,     # short answers keep CPU latency low
        temperature=0.2,
        top_p=0.9,
        stop=["<|user|>"]   # stop before the model invents the next turn
    )
    return output["choices"][0]["text"].strip()
# =====================
# ROUTES
# =====================
@app.route("/")
def home():
    """Serve the single-page frontend (templates/index.html)."""
    return render_template("index.html")
@app.route("/predict", methods=["POST"])
def predict():
    """Run one of the three skin models on an uploaded image.

    Expects multipart form data:
        image: a jpg/jpeg/png file
        model: one of "tflite", "mobilenetv2", "b3"

    Returns JSON with the model used, predicted label, confidence, and
    the per-class probability map; 400 on bad input, 500 on a
    download/inference failure. The uploaded file is always deleted.
    """
    if "image" not in request.files or "model" not in request.form:
        return jsonify({"error": "image + model required"}), 400
    model_choice = request.form["model"]
    file = request.files["image"]
    if model_choice not in HF_MODELS or not allowed_file(file.filename):
        return jsonify({"error": "invalid model or file"}), 400

    # basename() strips client-supplied directory components so a crafted
    # filename like "../../x.png" cannot escape UPLOAD_FOLDER.
    safe_name = os.path.basename(file.filename)
    if not safe_name:
        return jsonify({"error": "invalid model or file"}), 400
    path = os.path.join(UPLOAD_FOLDER, safe_name)
    file.save(path)
    try:
        if model_choice == "tflite":
            pred, conf, probs = predict_tflite(path)
        elif model_choice == "mobilenetv2":
            pred, conf, probs = predict_keras(path)
        else:
            pred, conf, probs = predict_b3(path)
        return jsonify({
            "model_used": model_choice,
            "prediction": pred,
            "confidence": conf,
            "probabilities": {
                CLASS_LABELS[i]: probs[i] for i in range(3)
            }
        })
    except Exception as e:
        # Surface download/inference errors as JSON instead of an HTML
        # stack trace.
        return jsonify({"error": str(e)}), 500
    finally:
        os.remove(path)
@app.route("/chat", methods=["POST"])
def chat():
    """Chat endpoint backed by TinyLlama.

    Expects JSON: {"message": str, "prediction": str?, "confidence": float?}.
    Returns JSON with the model's reply plus a fixed medical disclaimer.
    """
    # silent=True + fallback: request.get_json() returns None for a
    # non-JSON POST, which would otherwise crash on .get().
    data = request.get_json(silent=True) or {}
    user_msg = str(data.get("message", "")).strip()
    if not user_msg:
        return jsonify({"reply": "Please ask a skin health related question."})
    reply = llm_chat_response(
        user_msg,
        data.get("prediction"),
        data.get("confidence")
    )
    return jsonify({
        "reply": reply,
        "disclaimer": "⚠️ This chatbot is for educational purposes only and not a medical diagnosis."
    })
# =====================
# LOCAL RUN
# =====================
# Port 7860 is the default Hugging Face Spaces port; binding 0.0.0.0
# exposes the server inside the container.
if __name__ == "__main__":
    app.run(host="0.0.0.0", port=7860)