Spaces:
Sleeping
Sleeping
| from flask import Flask, request, jsonify, render_template | |
| from flask_cors import CORS | |
| import os, io, tempfile, requests | |
| import numpy as np | |
| from PIL import Image | |
| import cv2 | |
| import tensorflow as tf | |
| import torch | |
| from torchvision import models, transforms | |
| # ======================= | |
| # LLaMA CPP (CPU FAST) | |
| # ======================= | |
| from llama_cpp import Llama | |
# =====================
# APP CONFIG
# =====================
app = Flask(__name__)
CORS(app)  # flask-cors with its default options, applied to the whole app

# Uploaded images are saved here only for the duration of a prediction request.
UPLOAD_FOLDER = "uploads"
os.makedirs(UPLOAD_FOLDER, exist_ok=True)

# Index i of every classifier's output vector maps to CLASS_LABELS[i].
CLASS_LABELS = "benign malignant normal".split()
# Lower-case file extensions (without the dot) accepted for uploads.
ALLOWED_EXT = {"jpg", "jpeg", "png"}
# Inference device name (CPU only).
device = "cpu"
# =====================
# HUGGING FACE MODELS
# =====================
# Every model file lives in the same Hugging Face repo; HF_MODELS maps the
# value of predict()'s "model" form field to the download URL of its weights.
HF_BASE = "https://huggingface.co/mani880740255/skin_care_tflite/resolve/main/"
HF_MODELS = {
    key: HF_BASE + filename
    for key, filename in (
        ("tflite", "skin_model_quantized.tflite"),
        # The file name is already URL-encoded ("%20" is a space).
        ("mobilenetv2", "skin_cancer_mobilenetv2%20(1).h5"),
        ("b3", "efficientnet_b3_skin_cancer.pth"),
    )
}
# =====================
# TINYLLAMA GGUF CONFIG
# =====================
# The GGUF path and CPU thread count can be overridden through environment
# variables; without them the original hard-coded values are used.
LLM_PATH = os.environ.get("LLM_PATH", "tinyllama-1.1b-chat-v1.0.Q2_K.gguf")
LLM_THREADS = int(os.environ.get("LLM_THREADS", "4"))

print("🔄 Loading TinyLlama GGUF (CPU)...")
# Loaded once at import time; every chat request shares this instance.
llm = Llama(
    model_path=LLM_PATH,
    n_ctx=512,  # context window: prompt + generated tokens must fit in 512
    n_threads=LLM_THREADS,
    n_batch=128,
    verbose=False,
)
print("✅ TinyLlama loaded")

# Instructions placed in the system turn of every chat prompt.
SYSTEM_PROMPT = (
    "You are a skin health assistant. "
    "Do not diagnose diseases. "
    "Explain in simple language. "
    "Give general precautions. "
    "Always recommend consulting a dermatologist. "
    "Add a medical disclaimer."
)
# =====================
# HELPERS
# =====================
def allowed_file(name, allowed=None):
    """Return True if *name* ends in an accepted file extension.

    Args:
        name: Client-supplied filename. It may be None or empty (werkzeug's
            FileStorage.filename can be None), in which case False is returned.
        allowed: Optional collection of lower-case extensions without the dot;
            defaults to ALLOWED_EXT.

    Returns:
        True when the text after the last "." (case-insensitive) is allowed.
    """
    if not name or "." not in name:
        return False
    if allowed is None:
        allowed = ALLOWED_EXT
    return name.rsplit(".", 1)[1].lower() in allowed
# Raw bytes of files already downloaded in this process, keyed by URL.
# The callers in this file only pass HF_MODELS URLs, so the cache stays bounded.
_DOWNLOAD_CACHE = {}


def download_file(url, timeout=60):
    """Fetch *url* (once per process) and return its content as a fresh BytesIO.

    Model weights are tens of megabytes. Without the cache, every prediction
    request would download them again.

    Args:
        url: HTTP(S) URL to fetch.
        timeout: Seconds passed to requests for connect/read, so a stalled
            server cannot hang the request forever.

    Returns:
        A new io.BytesIO positioned at offset 0. It is a new object per call,
        so callers may .read() it freely.

    Raises:
        Exception: if the server answers with a status other than 200.
        requests.RequestException: on network errors or timeouts.
    """
    data = _DOWNLOAD_CACHE.get(url)
    if data is None:
        r = requests.get(url, timeout=timeout)
        if r.status_code != 200:
            raise Exception(f"Model download failed: {url}")
        data = r.content
        _DOWNLOAD_CACHE[url] = data
    return io.BytesIO(data)
# =====================
# IMAGE MODELS
# =====================
def predict_tflite(img_path):
    """Classify an image with the TFLite model at HF_MODELS["tflite"].

    Args:
        img_path: Path of an image file readable by OpenCV.

    Returns:
        (label, confidence, probabilities): the top CLASS_LABELS entry, its
        score as a float, and the full output vector as a list.

    Raises:
        ValueError: if OpenCV cannot decode the image.
    """
    model_bytes = download_file(HF_MODELS["tflite"])
    interpreter = tf.lite.Interpreter(model_content=model_bytes.read())
    interpreter.allocate_tensors()
    input_info = interpreter.get_input_details()[0]
    output_info = interpreter.get_output_details()[0]

    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise ValueError(f"Could not read image: {img_path}")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    # The batch is built as (1, H, W, 3). Use the model's declared H/W when
    # they are static, otherwise the previous 224x224 default.
    shape = input_info["shape"]
    if len(shape) == 4 and shape[1] > 0 and shape[2] > 0:
        height, width = int(shape[1]), int(shape[2])
    else:
        height, width = 224, 224
    img = cv2.resize(img, (width, height))  # cv2 takes (width, height)
    img = img.astype("float32") / 255.0
    img = np.expand_dims(img, axis=0)

    # A fully integer-quantized model declares an integer input tensor with a
    # (scale, zero_point) pair: map the [0, 1] floats into that domain.
    # Float models (and integer inputs without quantization params) are fed
    # exactly as before.
    input_dtype = input_info["dtype"]
    if np.issubdtype(input_dtype, np.integer):
        scale, zero_point = input_info["quantization"]
        if scale:
            limits = np.iinfo(input_dtype)
            img = np.clip(np.round(img / scale + zero_point), limits.min, limits.max)
            img = img.astype(input_dtype)

    interpreter.set_tensor(input_info["index"], img)
    interpreter.invoke()
    preds = interpreter.get_tensor(output_info["index"])[0]

    # Dequantize integer outputs so confidences are real-valued scores.
    out_scale, out_zero_point = output_info["quantization"]
    if np.issubdtype(output_info["dtype"], np.integer) and out_scale:
        preds = (preds.astype(np.float32) - out_zero_point) * out_scale

    idx = int(np.argmax(preds))
    return CLASS_LABELS[idx], float(preds[idx]), preds.tolist()
def predict_keras(img_path):
    """Classify an image with the Keras MobileNetV2 model at HF_MODELS["mobilenetv2"].

    Args:
        img_path: Path of an image file readable by Pillow.

    Returns:
        (label, confidence, probabilities): the top CLASS_LABELS entry, its
        score as a float, and the full output vector as a list.
    """
    model_bytes = download_file(HF_MODELS["mobilenetv2"])
    # load_model is given a filesystem path here, so spill the .h5 bytes to a
    # temp file. mkstemp creates the file before anything can fail, so the
    # finally clause always removes it, even if the write itself raises.
    fd, tmp_path = tempfile.mkstemp(suffix=".h5")
    try:
        with os.fdopen(fd, "wb") as tmp:
            tmp.write(model_bytes.read())
        # Only inference is needed, so skip restoring the training config.
        model = tf.keras.models.load_model(tmp_path, compile=False)

        # Close the upload's file handle as soon as the pixels are decoded.
        with Image.open(img_path) as im:
            img = im.convert("RGB").resize((224, 224))
        img = np.array(img) / 255.0
        img = np.expand_dims(img, axis=0)

        preds = model.predict(img, verbose=0)[0]
        idx = int(np.argmax(preds))
        return CLASS_LABELS[idx], float(preds[idx]), preds.tolist()
    finally:
        os.remove(tmp_path)
def predict_b3(img_path):
    """Classify an image with the EfficientNet-B3 model at HF_MODELS["b3"].

    Args:
        img_path: Path of an image file readable by Pillow.

    Returns:
        (label, confidence, probabilities): the top CLASS_LABELS entry, its
        softmax probability, and the full probability vector as a list.
    """
    model_bytes = download_file(HF_MODELS["b3"])
    model = models.efficientnet_b3(weights=None)
    # Swap the default classification head for a 3-output linear layer
    # (1536 input features), one output per CLASS_LABELS entry.
    model.classifier[1] = torch.nn.Linear(1536, 3)
    # torch.load accepts file-like objects, so the downloaded BytesIO is used
    # directly instead of round-tripping through a temporary file.
    # NOTE(security): torch.load unpickles its input. This is tolerable only
    # because the URL is a fixed repo; prefer weights_only=True if the
    # installed torch version supports it.
    model.load_state_dict(torch.load(model_bytes, map_location="cpu"))
    model.eval()

    # NOTE(review): no mean/std normalization is applied; confirm this matches
    # the preprocessing used when the model was trained.
    transform = transforms.Compose([
        transforms.Resize((300, 300)),
        transforms.ToTensor(),
    ])
    # Close the upload's file handle as soon as the tensor is built.
    with Image.open(img_path) as im:
        batch = transform(im.convert("RGB")).unsqueeze(0)

    with torch.no_grad():
        probs = torch.softmax(model(batch), dim=1)[0]
    idx = int(torch.argmax(probs))
    return CLASS_LABELS[idx], float(probs[idx]), probs.tolist()
# =====================
# CHATBOT (FAST)
# =====================
def llm_chat_response(user_message, prediction=None, confidence=None):
    """Generate a short chatbot reply with the shared TinyLlama model.

    Args:
        user_message: The user's question.
        prediction: Optional class label from an image model. It is added to
            the system context only when *confidence* is also given.
        confidence: Optional score, presumably a fraction in [0, 1]. It is
            multiplied by 100 and shown as a percentage. A non-numeric value
            is ignored and only the label is mentioned.

    Returns:
        The generated text with surrounding whitespace removed.
    """
    context = ""
    if prediction and confidence is not None:
        try:
            context = f"AI result: {prediction} ({float(confidence) * 100:.1f}%)."
        except (TypeError, ValueError):
            # Client-supplied junk (e.g. a string) must not crash the reply.
            context = f"AI result: {prediction}."

    system_text = f"{SYSTEM_PROMPT}\n{context}" if context else SYSTEM_PROMPT
    # TinyLlama-1.1B-Chat-v1.0 chat template: each turn is introduced by its
    # role tag and closed by "</s>"; the reply follows "<|assistant|>".
    prompt = (
        f"<|system|>\n{system_text}</s>\n"
        f"<|user|>\n{user_message}</s>\n"
        "<|assistant|>\n"
    )
    output = llm(
        prompt,
        max_tokens=120,
        temperature=0.2,
        top_p=0.9,
        # Stop at end-of-turn or if the model starts writing another turn.
        stop=["</s>", "<|user|>", "<|system|>"],
    )
    return output["choices"][0]["text"].strip()
# =====================
# ROUTES
# =====================
# NOTE(review): the view was previously defined but never registered with the
# app; the path "/" is assumed — confirm against the front end.
@app.route("/")
def home():
    """Serve the front-end page rendered from the index.html template."""
    return render_template("index.html")
# NOTE(review): previously never registered; path and method are assumed —
# confirm against the front end (templates/index.html).
@app.route("/predict", methods=["POST"])
def predict():
    """Classify an uploaded image with the requested model.

    Expects multipart form data with an ``image`` file and a ``model`` field
    naming one of the HF_MODELS keys.

    Returns:
        JSON with ``model_used``, ``prediction``, ``confidence`` and a
        per-label ``probabilities`` map. Returns 400 on missing or invalid
        input, and 500 (JSON) if inference fails.
    """
    if "image" not in request.files or "model" not in request.form:
        return jsonify({"error": "image + model required"}), 400

    model_choice = request.form["model"]
    file = request.files["image"]
    filename = file.filename or ""  # werkzeug may report None
    if model_choice not in HF_MODELS or not allowed_file(filename):
        return jsonify({"error": "invalid model or file"}), 400

    # Never build the path from the client-supplied filename. It may contain
    # "../" (path traversal), and two concurrent uploads with the same name
    # would overwrite and delete each other's file. Use a unique server-side
    # name and keep only the already-validated extension.
    ext = filename.rsplit(".", 1)[1].lower()
    fd, path = tempfile.mkstemp(suffix="." + ext, dir=UPLOAD_FOLDER)
    os.close(fd)
    try:
        file.save(path)
        if model_choice == "tflite":
            pred, conf, probs = predict_tflite(path)
        elif model_choice == "mobilenetv2":
            pred, conf, probs = predict_keras(path)
        else:
            pred, conf, probs = predict_b3(path)
    except Exception:
        # API boundary: keep the traceback in the log and answer in JSON
        # instead of Flask's HTML error page.
        app.logger.exception("Prediction failed (model=%s)", model_choice)
        return jsonify({"error": "prediction failed"}), 500
    finally:
        os.remove(path)

    return jsonify({
        "model_used": model_choice,
        "prediction": pred,
        "confidence": conf,
        "probabilities": dict(zip(CLASS_LABELS, probs)),
    })
# NOTE(review): previously never registered; path and method are assumed —
# confirm against the front end (templates/index.html).
@app.route("/chat", methods=["POST"])
def chat():
    """Answer a skin-health question with the local LLM.

    Expects a JSON object with a ``message`` string and optional
    ``prediction`` / ``confidence`` fields that are forwarded to
    llm_chat_response().

    Returns:
        JSON with ``reply`` and ``disclaimer``. If no usable message is sent,
        the response contains only a prompt asking for a question.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        # Missing or invalid JSON, or a non-object payload: treat as empty.
        data = {}

    message = data.get("message")
    user_msg = message.strip() if isinstance(message, str) else ""
    if not user_msg:
        return jsonify({"reply": "Please ask a skin health related question."})

    confidence = data.get("confidence")
    try:
        confidence = float(confidence) if confidence is not None else None
    except (TypeError, ValueError):
        # A non-numeric value would break the percentage formatting downstream.
        confidence = None

    reply = llm_chat_response(user_msg, data.get("prediction"), confidence)
    return jsonify({
        "reply": reply,
        "disclaimer": "⚠️ This chatbot is for educational purposes only and not a medical diagnosis.",
    })
# =====================
# LOCAL RUN
# =====================
if __name__ == "__main__":
    # Bind to every interface, not just localhost; port 7860 is presumably
    # what the hosting platform expects — confirm before changing it.
    app.run(port=7860, host="0.0.0.0")