"""Skin-lesion classification API with a TinyLlama chat assistant.

Exposes three routes:
  * GET  /         — serves the front-end template.
  * POST /predict  — classifies an uploaded image with one of three models
                     (TFLite, Keras MobileNetV2, or PyTorch EfficientNet-B3)
                     whose weights are fetched (and cached) from Hugging Face.
  * POST /chat     — answers skin-health questions via a local GGUF TinyLlama.
"""

from flask import Flask, request, jsonify, render_template
from flask_cors import CORS
import os
import io
import tempfile
import uuid

import requests
import numpy as np
from PIL import Image
import cv2
import tensorflow as tf
import torch
from torchvision import models, transforms

# =======================
# LLaMA CPP (CPU FAST)
# =======================
from llama_cpp import Llama

# =====================
# APP CONFIG
# =====================
app = Flask(__name__)
CORS(app)

UPLOAD_FOLDER = "uploads"
os.makedirs(UPLOAD_FOLDER, exist_ok=True)

# Index order must match the training label order of all three models.
CLASS_LABELS = ["benign", "malignant", "normal"]
ALLOWED_EXT = {"jpg", "jpeg", "png"}
device = "cpu"

# =====================
# HUGGING FACE MODELS
# =====================
HF_BASE = "https://huggingface.co/mani880740255/skin_care_tflite/resolve/main/"
HF_MODELS = {
    "tflite": HF_BASE + "skin_model_quantized.tflite",
    "mobilenetv2": HF_BASE + "skin_cancer_mobilenetv2%20(1).h5",
    "b3": HF_BASE + "efficientnet_b3_skin_cancer.pth",
}

# =====================
# TINYLLAMA GGUF CONFIG
# =====================
LLM_PATH = "tinyllama-1.1b-chat-v1.0.Q2_K.gguf"

print("🔄 Loading TinyLlama GGUF (CPU)...")
llm = Llama(
    model_path=LLM_PATH,
    n_ctx=512,
    n_threads=4,
    n_batch=128,
    verbose=False,
)
print("✅ TinyLlama loaded")

SYSTEM_PROMPT = (
    "You are a skin health assistant. "
    "Do not diagnose diseases. "
    "Explain in simple language. "
    "Give general precautions. "
    "Always recommend consulting a dermatologist. "
    "Add a medical disclaimer."
)

# FIX: previously every /predict request re-downloaded the full model
# weights from Hugging Face. Cache the raw bytes per URL so each model is
# fetched at most once per process.
_MODEL_CACHE: dict[str, bytes] = {}


# =====================
# HELPERS
# =====================
def allowed_file(name):
    """Return True if *name* has an extension listed in ALLOWED_EXT."""
    return "." in name and name.rsplit(".", 1)[1].lower() in ALLOWED_EXT


def download_file(url):
    """Fetch *url* (with an in-process cache) and return a fresh BytesIO.

    Raises:
        Exception: if the HTTP response status is not 200.
    """
    if url not in _MODEL_CACHE:
        # FIX: requests.get without a timeout can hang a worker forever.
        r = requests.get(url, timeout=120)
        if r.status_code != 200:
            raise Exception(f"Model download failed: {url}")
        _MODEL_CACHE[url] = r.content
    return io.BytesIO(_MODEL_CACHE[url])


# =====================
# IMAGE MODELS
# =====================
def predict_tflite(img_path):
    """Classify *img_path* with the quantized TFLite model.

    Returns:
        (label, confidence, probabilities) where probabilities is a list
        ordered like CLASS_LABELS.
    """
    model_bytes = download_file(HF_MODELS["tflite"])
    interpreter = tf.lite.Interpreter(model_content=model_bytes.read())
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    img = cv2.imread(img_path)
    if img is None:
        # FIX: cv2.imread returns None on unreadable files; fail with a
        # clear message instead of a cryptic downstream cv2 error.
        raise ValueError(f"Could not read image: {img_path}")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, (224, 224))
    img = img.astype("float32") / 255.0
    img = np.expand_dims(img, axis=0)

    interpreter.set_tensor(input_details[0]["index"], img)
    interpreter.invoke()
    preds = interpreter.get_tensor(output_details[0]["index"])[0]

    idx = int(np.argmax(preds))
    return CLASS_LABELS[idx], float(preds[idx]), preds.tolist()


def predict_keras(img_path):
    """Classify *img_path* with the Keras MobileNetV2 model.

    The .h5 weights must be written to a real file because
    tf.keras.models.load_model cannot read from an in-memory buffer.
    """
    model_bytes = download_file(HF_MODELS["mobilenetv2"])
    with tempfile.NamedTemporaryFile(suffix=".h5", delete=False) as tmp:
        tmp.write(model_bytes.read())
        tmp_path = tmp.name
    try:
        model = tf.keras.models.load_model(tmp_path)
        img = Image.open(img_path).convert("RGB")
        img = img.resize((224, 224))
        img = np.array(img) / 255.0
        img = np.expand_dims(img, axis=0)
        preds = model.predict(img, verbose=0)[0]
        idx = int(np.argmax(preds))
        return CLASS_LABELS[idx], float(preds[idx]), preds.tolist()
    finally:
        os.remove(tmp_path)


def predict_b3(img_path):
    """Classify *img_path* with the PyTorch EfficientNet-B3 model."""
    model_bytes = download_file(HF_MODELS["b3"])
    model = models.efficientnet_b3(weights=None)
    # Replace the stock 1000-class head with a 3-class head matching the
    # fine-tuned checkpoint (1536 is EfficientNet-B3's feature width).
    model.classifier[1] = torch.nn.Linear(1536, 3)

    with tempfile.NamedTemporaryFile(suffix=".pth", delete=False) as tmp:
        tmp.write(model_bytes.read())
        tmp_path = tmp.name
    try:
        model.load_state_dict(torch.load(tmp_path, map_location="cpu"))
        model.eval()
        transform = transforms.Compose([
            transforms.Resize((300, 300)),
            transforms.ToTensor(),
        ])
        img = Image.open(img_path).convert("RGB")
        img = transform(img).unsqueeze(0)
        with torch.no_grad():
            out = model(img)
            probs = torch.softmax(out, dim=1)[0]
        idx = int(torch.argmax(probs))
        return CLASS_LABELS[idx], float(probs[idx]), probs.tolist()
    finally:
        os.remove(tmp_path)


# =====================
# CHATBOT (FAST)
# =====================
def llm_chat_response(user_message, prediction=None, confidence=None):
    """Generate a chat reply, optionally grounded in a prior AI prediction.

    Args:
        user_message: the user's question.
        prediction: optional class label from /predict.
        confidence: optional confidence score in [0, 1].
    """
    context = ""
    # FIX: the old `if prediction and confidence:` silently dropped a
    # legitimate confidence of 0.0; test for presence, not truthiness.
    if prediction is not None and confidence is not None:
        context = f"AI result: {prediction} ({confidence*100:.1f}%)."

    prompt = f"""
<|system|>
{SYSTEM_PROMPT}
{context}
<|user|>
{user_message}
<|assistant|>
"""
    output = llm(
        prompt,
        max_tokens=120,
        temperature=0.2,
        top_p=0.9,
        stop=["<|user|>"],
    )
    return output["choices"][0]["text"].strip()


# =====================
# ROUTES
# =====================
@app.route("/")
def home():
    """Serve the single-page front end."""
    return render_template("index.html")


@app.route("/predict", methods=["POST"])
def predict():
    """Classify an uploaded image with the requested model.

    Expects multipart form data with an `image` file and a `model` field
    naming one of the HF_MODELS keys.
    """
    if "image" not in request.files or "model" not in request.form:
        return jsonify({"error": "image + model required"}), 400

    model_choice = request.form["model"]
    file = request.files["image"]

    if model_choice not in HF_MODELS or not allowed_file(file.filename):
        return jsonify({"error": "invalid model or file"}), 400

    # FIX: never trust the client-supplied filename (path traversal, name
    # collisions between concurrent requests). Store the upload under a
    # random name, keeping only the already-validated extension.
    ext = file.filename.rsplit(".", 1)[1].lower()
    path = os.path.join(UPLOAD_FOLDER, f"{uuid.uuid4().hex}.{ext}")
    file.save(path)

    try:
        if model_choice == "tflite":
            pred, conf, probs = predict_tflite(path)
        elif model_choice == "mobilenetv2":
            pred, conf, probs = predict_keras(path)
        else:
            pred, conf, probs = predict_b3(path)

        return jsonify({
            "model_used": model_choice,
            "prediction": pred,
            "confidence": conf,
            "probabilities": {
                CLASS_LABELS[i]: probs[i] for i in range(3)
            },
        })
    finally:
        os.remove(path)


@app.route("/chat", methods=["POST"])
def chat():
    """Answer a skin-health question, optionally using a prior prediction."""
    # FIX: get_json() returns None for non-JSON bodies, which crashed with
    # AttributeError; silent=True plus `or {}` degrades gracefully instead.
    data = request.get_json(silent=True) or {}
    user_msg = data.get("message", "").strip()

    if not user_msg:
        return jsonify({"reply": "Please ask a skin health related question."})

    reply = llm_chat_response(
        user_msg,
        data.get("prediction"),
        data.get("confidence"),
    )
    return jsonify({
        "reply": reply,
        "disclaimer": "⚠️ This chatbot is for educational purposes only and not a medical diagnosis.",
    })


# =====================
# LOCAL RUN
# =====================
if __name__ == "__main__":
    app.run(host="0.0.0.0", port=7860)