Instructions to use agcaabdurrahim/tumor_model with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Keras
How to use agcaabdurrahim/tumor_model with Keras:
  ```python
  # Available backend options are: "jax", "torch", "tensorflow".
  import os
  os.environ["KERAS_BACKEND"] = "jax"

  import keras

  model = keras.saving.load_model("hf://agcaabdurrahim/tumor_model")
  ```
- Notebooks
- Google Colab
- Kaggle
| from fastapi import FastAPI, File, UploadFile | |
| from fastapi.responses import FileResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| import io | |
| import matplotlib | |
| matplotlib.use('Agg') # Set the backend to Agg before importing pyplot | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| from PIL import Image | |
| from tensorflow.keras.models import load_model | |
| import tempfile | |
| import random | |
| import os | |
app = FastAPI()

# Allowed CORS origins, as a comma-separated list in ALLOWED_ORIGINS.
# The default "*" keeps the previous allow-any-origin behavior; an empty value
# allows no cross-origin callers.
# TODO: set ALLOWED_ORIGINS to the frontend domain in production.
_allowed_origins = [
    origin.strip()
    for origin in os.environ.get("ALLOWED_ORIGINS", "*").split(",")
    if origin.strip()
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_allowed_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
    expose_headers=["*"],
)

# Resolve the model relative to this file (as get_random_image does for Testing/)
# so startup does not depend on the current working directory; MODEL_PATH overrides.
MODEL_PATH = os.environ.get(
    "MODEL_PATH",
    os.path.join(os.path.dirname(os.path.abspath(__file__)), "my_model.keras"),
)
model = load_model(MODEL_PATH)
def predict_and_plot(img):
    """Classify a PIL image and render the input plus class probabilities to a PNG.

    The image is resized to 299x299, converted to 3-channel RGB, scaled to [0, 1]
    and passed to the module-level ``model``.

    Args:
        img: A PIL ``Image`` in any mode Pillow can convert to RGB.

    Returns:
        ``(png_path, prediction_text)``. ``png_path`` is a temporary file created
        with ``delete=False``, so the caller is responsible for removing it.
        ``prediction_text`` looks like ``"Prediction: glioma (97.12%)"``.
    """
    # Order must match the model's output indices.
    labels = ['glioma', 'meningioma', 'notumor', 'pituitary']

    resized_img = img.resize((299, 299))
    # convert('RGB') yields exactly what the old stack/slice code produced for
    # L, RGB and RGBA images, and also handles P/LA/CMYK inputs correctly.
    img_array = np.asarray(resized_img.convert('RGB'))
    img_array = np.expand_dims(img_array, axis=0) / 255.0

    # Predict before creating the figure so a model failure cannot leak a figure.
    probs = list(model.predict(img_array)[0])
    max_prob_idx = int(np.argmax(probs))
    prediction_text = f"Prediction: {labels[max_prob_idx]} ({probs[max_prob_idx]:.2%})"

    fig = plt.figure(figsize=(16, 12))
    try:
        ax_img = fig.add_subplot(2, 1, 1)
        ax_img.imshow(resized_img)
        ax_img.set_title('Input Image', fontsize=16, pad=20)
        ax_img.axis('off')

        ax_bar = fig.add_subplot(2, 1, 2)
        bars = ax_bar.barh(labels, probs)
        ax_bar.set_xlabel('Probability', fontsize=14)
        ax_bar.set_title('Prediction Probabilities', fontsize=16, pad=20)
        ax_bar.bar_label(bars, fmt='%.2f', fontsize=12)
        ax_bar.set_xlim(0, 1.1)  # headroom so bar labels near 1.0 are not clipped
        fig.tight_layout()

        tmp = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
        try:
            # Write through the handle and close it (the old code leaked the fd).
            with tmp:
                fig.savefig(tmp, format='png', dpi=300, bbox_inches='tight')
        except Exception:
            os.remove(tmp.name)  # don't leave a half-written file behind
            raise
    finally:
        # pyplot keeps every open figure alive; always release it.
        plt.close(fig)

    return tmp.name, prediction_text
def get_random_image(testing_dir=None):
    """Return the full path of a randomly chosen sample image.

    A class subdirectory is chosen uniformly at random among those containing at
    least one image. An image is then chosen uniformly within that subdirectory.

    Args:
        testing_dir: Directory whose immediate subdirectories hold sample images.
            Defaults to the ``Testing`` directory next to this file.

    Returns:
        Full path to a ``.jpg``/``.jpeg``/``.png`` file (case-insensitive match).

    Raises:
        FileNotFoundError: If no subdirectory contains any image.
        OSError: If ``testing_dir`` cannot be listed.
    """
    # No per-call reseed: the module-level RNG is already seeded from OS entropy.
    # The old "Random seed" log printed getstate()[1][0], which is the constant
    # 2147483648 right after any seed() call, not an actual seed.
    if testing_dir is None:
        testing_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'Testing')
    print(f"Testing directory: {testing_dir}")

    images_by_dir = {}
    for entry in os.listdir(testing_dir):
        class_dir = os.path.join(testing_dir, entry)
        if not os.path.isdir(class_dir):
            continue
        images = [
            name for name in os.listdir(class_dir)
            if name.lower().endswith(('.jpg', '.jpeg', '.png'))
        ]
        # Skip empty classes so one empty folder can't make requests fail at random.
        if images:
            images_by_dir[class_dir] = images

    if not images_by_dir:
        raise FileNotFoundError(f"No images found in any subdirectory of {testing_dir}")

    chosen_dir = random.choice(list(images_by_dir))
    full_path = os.path.join(chosen_dir, random.choice(images_by_dir[chosen_dir]))
    print(f"Selected image: {full_path}")
    return full_path
async def get_random_image_endpoint():
    """Serve a randomly chosen sample image from the Testing directory.

    The response is marked non-cacheable so repeated requests can return
    different images. Errors are printed and re-raised.
    """
    import mimetypes

    try:
        random_image_path = get_random_image()
        if not os.path.exists(random_image_path):
            raise Exception(f"Image file not found: {random_image_path}")
        # Samples may be .jpg/.jpeg as well as .png; previously every file was
        # labelled image/png. Fall back to the old value if the type is unknown.
        media_type = mimetypes.guess_type(random_image_path)[0] or "image/png"
        return FileResponse(
            random_image_path,
            media_type=media_type,
            headers={
                "Access-Control-Allow-Origin": "*",
                "Access-Control-Allow-Methods": "GET, OPTIONS",
                "Access-Control-Allow-Headers": "*",
                "Cache-Control": "no-cache, no-store, must-revalidate",
                "Pragma": "no-cache",
                "Expires": "0",
            },
        )
    except Exception as e:
        print(f"Error getting random image: {str(e)}")
        raise
async def predict_image(file: UploadFile = File(...)):
    """Classify an uploaded image and return the rendered result plot as a PNG.

    The plot comes from ``predict_and_plot``, which writes it to a temporary
    file. That file is deleted by a background task once the response is sent.

    Args:
        file: The uploaded image (multipart form field ``file``).

    Returns:
        A ``FileResponse`` streaming the PNG plot.

    Raises:
        HTTPException: 400 if Pillow cannot decode the upload.
    """
    from fastapi import BackgroundTasks, HTTPException

    try:
        contents = await file.read()
        try:
            img = Image.open(io.BytesIO(contents))
            # Image.open is lazy; decode now so corrupt/truncated data fails here.
            img.load()
        except OSError as e:
            # Pillow raises UnidentifiedImageError (an OSError subclass) for
            # non-image data: that is a client error, not a server fault.
            raise HTTPException(
                status_code=400, detail="Uploaded file is not a readable image"
            ) from e
        result_path, _prediction_text = predict_and_plot(img)
    except Exception as e:
        print(f"Error processing image: {str(e)}")
        raise

    # predict_and_plot creates the PNG with delete=False; remove it after the
    # response has been streamed so temp files do not pile up on disk.
    cleanup = BackgroundTasks()
    cleanup.add_task(os.remove, result_path)
    return FileResponse(
        result_path,
        media_type="image/png",
        headers={
            "Access-Control-Allow-Origin": "*",
            "Access-Control-Allow-Methods": "POST, OPTIONS",
            "Access-Control-Allow-Headers": "*",
        },
        background=cleanup,
    )
async def predict_text(file: UploadFile = File(...)):
    """Classify an uploaded image and return the top prediction as JSON.

    Args:
        file: The uploaded image (multipart form field ``file``).

    Returns:
        ``{"prediction": "Prediction: <label> (<prob>%)"}`` for the most likely class.

    Raises:
        HTTPException: 400 if Pillow cannot decode the upload.
    """
    from fastapi import HTTPException

    # Order must match the model's output indices.
    labels = ['glioma', 'meningioma', 'notumor', 'pituitary']
    try:
        contents = await file.read()
        try:
            img = Image.open(io.BytesIO(contents))
            # Image.open is lazy; decode now so corrupt/truncated data fails here.
            img.load()
        except OSError as e:
            # Undecodable uploads are a client error, not a server fault.
            raise HTTPException(
                status_code=400, detail="Uploaded file is not a readable image"
            ) from e

        resized_img = img.resize((299, 299))
        # convert('RGB') matches the old stack/slice result for L, RGB and RGBA
        # images, and also handles P/LA/CMYK inputs that previously broke the model.
        img_array = np.asarray(resized_img.convert('RGB'))
        img_array = np.expand_dims(img_array, axis=0) / 255.0

        probs = model.predict(img_array)[0]
        max_prob_idx = int(np.argmax(probs))
        prediction_text = (
            f"Prediction: {labels[max_prob_idx]} ({float(probs[max_prob_idx]):.2%})"
        )
        return {"prediction": prediction_text}
    except Exception as e:
        print(f"Error processing image: {str(e)}")
        raise
def _serve():
    """Start the API under uvicorn, listening on every interface at port 8000."""
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)


if __name__ == "__main__":
    _serve()