| |
| import os |
| import json |
| import logging |
| import sys |
| from fastapi import FastAPI, HTTPException, Request, Form |
| from fastapi.responses import HTMLResponse, JSONResponse, FileResponse |
| from fastapi.staticfiles import StaticFiles |
| from fastapi.templating import Jinja2Templates |
| from pydantic import BaseModel |
| from datetime import datetime |
| from datasets import Dataset, load_dataset, concatenate_datasets |
| from typing import Dict, Optional, Any, List |
| import uuid |
| import re |
| import html |
| from urllib.parse import urlparse |
| from starlette.middleware.base import BaseHTTPMiddleware |
| from huggingface_hub import HfApi |
| from huggingface_hub.utils import RepositoryNotFoundError |
|
|
|
|
| |
# Log at INFO level; `logger` is the module-wide logger used throughout this file.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Directory handed to Jinja2Templates for HTML rendering.
templates_dir = "templates"

# Generated SBOM files are written to OUTPUT_DIR; the three limits below are
# passed to perform_cleanup() / used by the cleanup middleware.
OUTPUT_DIR = "/tmp/aibom_output"
MAX_AGE_DAYS = 7        # age limit handed to the cleanup routine
MAX_FILES = 1000        # file-count limit handed to the cleanup routine
CLEANUP_INTERVAL = 100  # cleanup runs once every this many requests

# Hugging Face dataset used as the usage log, and the token read from the
# environment to access it (None when the variable is unset).
HF_REPO = "aetheris-ai/aisbom-usage-log"
HF_TOKEN = os.environ.get("HF_TOKEN")
| |
|
|
| |
# ASGI application object; the middlewares and routes below are registered on it.
app = FastAPI(
    title="AI SBOM Generator API",
)
|
|
| |
# Resolve the three DoS-protection middlewares. The rate_limiting module is
# looked up in three places (package path, relative import, plain module);
# when none can be imported, inert stand-ins with the same names are defined
# so the add_middleware() calls further down still work.
try:
    from src.aibom_generator.rate_limiting import (
        RateLimitMiddleware,
        ConcurrencyLimitMiddleware,
        RequestSizeLimitMiddleware,
    )
    logger.info("Successfully imported rate_limiting from src.aibom_generator")
except ImportError:
    try:
        from .rate_limiting import (
            RateLimitMiddleware,
            ConcurrencyLimitMiddleware,
            RequestSizeLimitMiddleware,
        )
        logger.info("Successfully imported rate_limiting with relative import")
    except ImportError:
        try:
            from rate_limiting import (
                RateLimitMiddleware,
                ConcurrencyLimitMiddleware,
                RequestSizeLimitMiddleware,
            )
            logger.info("Successfully imported rate_limiting from current directory")
        except ImportError:
            logger.error("Could not import rate_limiting, DoS protection disabled")

            class _PassthroughMiddleware(BaseHTTPMiddleware):
                """Stand-in middleware: forwards every request unchanged.

                Accepts and ignores the keyword options the real middlewares
                take. An exception raised downstream is logged under the
                concrete class name and turned into a JSON 500 response.
                """

                def __init__(self, app, **kwargs):
                    super().__init__(app)

                async def dispatch(self, request, call_next):
                    try:
                        return await call_next(request)
                    except Exception as e:
                        logger.error(f"Error in {type(self).__name__}: {str(e)}")
                        return JSONResponse(
                            status_code=500,
                            content={"detail": f"Internal server error: {str(e)}"},
                        )

            class RateLimitMiddleware(_PassthroughMiddleware):
                """No-op replacement used when rate_limiting is unavailable."""

            class ConcurrencyLimitMiddleware(_PassthroughMiddleware):
                """No-op replacement used when rate_limiting is unavailable."""

            class RequestSizeLimitMiddleware(_PassthroughMiddleware):
                """No-op replacement used when rate_limiting is unavailable."""
# Resolve verify_recaptcha from the first importable location. If the captcha
# module cannot be found anywhere, a stand-in that accepts every token is
# defined instead (a warning is logged at import time and on every call).
try:
    from src.aibom_generator.captcha import verify_recaptcha
    logger.info("Successfully imported captcha from src.aibom_generator")
except ImportError:
    try:
        from .captcha import verify_recaptcha
        logger.info("Successfully imported captcha with relative import")
    except ImportError:
        try:
            from captcha import verify_recaptcha
            logger.info("Successfully imported captcha from current directory")
        except ImportError:
            logger.warning("Could not import captcha module, CAPTCHA verification disabled")

            def verify_recaptcha(response_token=None):
                """Stand-in verifier: ignores the token and always reports success."""
                logger.warning("Using dummy CAPTCHA verification (always succeeds)")
                return True
|
|
|
|
|
|
| |
# Rate limiting for the SBOM-generation endpoints.
# NOTE(review): rate_limit_window is presumably in seconds - confirm against
# the rate_limiting module.
app.add_middleware(
    RateLimitMiddleware,
    rate_limit_per_minute=10,
    rate_limit_window=60,
    protected_routes=["/generate", "/api/generate", "/api/generate-with-report"],
)

# Cap on simultaneously running generation requests.
app.add_middleware(
    ConcurrencyLimitMiddleware,
    max_concurrent_requests=5,
    timeout=5.0,
    protected_routes=["/generate", "/api/generate", "/api/generate-with-report"],
)

# Upper bound on request size (1024 * 1024; presumably bytes, i.e. 1 MiB).
app.add_middleware(
    RequestSizeLimitMiddleware,
    max_content_length=1024 * 1024,
)
|
|
|
|
| |
class StatusResponse(BaseModel):
    """Response schema of the /status endpoint."""

    # Service state string (the endpoint in this file reports "operational").
    status: str
    # Version string of the API.
    version: str
    # Version string of the SBOM generator.
    generator_version: str
|
|
| |
# Template engine used by the HTML endpoints.
templates = Jinja2Templates(directory=templates_dir)

# The output directory must exist before StaticFiles is mounted on it.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Expose generated SBOM files under /output/<filename>.
app.mount(
    "/output",
    StaticFiles(directory=OUTPUT_DIR),
    name="output",
)

# Number of requests seen so far; cleanup_middleware uses it to decide when
# to run the periodic cleanup.
request_counter = 0
|
|
| |
# Resolve perform_cleanup from cleanup_utils; when the module is not
# importable, equivalent helpers are defined inline.
try:
    from src.aibom_generator.cleanup_utils import perform_cleanup
    logger.info("Successfully imported cleanup_utils")
except ImportError:
    try:
        from cleanup_utils import perform_cleanup
        logger.info("Successfully imported cleanup_utils from current directory")
    except ImportError:
        logger.error("Could not import cleanup_utils, defining functions inline")

        def cleanup_old_files(directory, max_age_days=7):
            """Remove files older than max_age_days from the specified directory.

            Only regular files directly inside ``directory`` are considered.
            Age is compared in seconds, so a file is removed as soon as it is
            older than ``max_age_days`` days (not only after a further full
            day has elapsed). A file that cannot be inspected or removed is
            logged and skipped; it does not abort the pass.

            Returns the number of files removed (0 if the directory is
            missing or cannot be listed).
            """
            if not os.path.exists(directory):
                logger.warning(f"Directory does not exist: {directory}")
                return 0

            try:
                filenames = os.listdir(directory)
            except OSError as e:
                logger.error(f"Error during cleanup of directory {directory}: {e}")
                return 0

            # Anything last modified before this epoch timestamp is too old.
            cutoff = datetime.now().timestamp() - max_age_days * 86400
            removed_count = 0

            for filename in filenames:
                file_path = os.path.join(directory, filename)
                try:
                    if not os.path.isfile(file_path):
                        continue
                    if os.path.getmtime(file_path) >= cutoff:
                        continue
                    os.remove(file_path)
                except OSError as e:
                    # Covers files that vanished or became unreadable mid-scan.
                    logger.error(f"Error removing file {file_path}: {e}")
                    continue
                removed_count += 1
                logger.info(f"Removed old file: {file_path}")

            logger.info(f"Cleanup completed: removed {removed_count} files older than {max_age_days} days from {directory}")
            return removed_count

        def limit_file_count(directory, max_files=1000):
            """Ensure no more than max_files are kept in the directory (removes oldest first).

            Files are ordered by modification time and the oldest ones beyond
            the limit are deleted. ``max_files <= 0`` removes every file.

            Returns the number of files removed (0 if the directory is
            missing or cannot be listed).
            """
            if not os.path.exists(directory):
                logger.warning(f"Directory does not exist: {directory}")
                return 0

            try:
                filenames = os.listdir(directory)
            except OSError as e:
                logger.error(f"Error during file count limiting in directory {directory}: {e}")
                return 0

            files = []
            for filename in filenames:
                file_path = os.path.join(directory, filename)
                try:
                    if os.path.isfile(file_path):
                        files.append((file_path, os.path.getmtime(file_path)))
                except OSError as e:
                    logger.error(f"Error during file count limiting in directory {directory}: {e}")

            # Oldest first, so the surplus sits at the front of the list.
            files.sort(key=lambda entry: entry[1])

            # Slice by explicit count: files[:-max_files] would be empty for
            # max_files == 0 and therefore remove nothing.
            excess = len(files) - max(max_files, 0)
            files_to_remove = files[:excess] if excess > 0 else []

            removed_count = 0
            for file_path, _ in files_to_remove:
                try:
                    os.remove(file_path)
                except OSError as e:
                    logger.error(f"Error removing file {file_path}: {e}")
                    continue
                removed_count += 1
                logger.info(f"Removed excess file: {file_path}")

            logger.info(f"File count limit enforced: removed {removed_count} oldest files from {directory}, keeping max {max_files}")
            return removed_count

        def perform_cleanup(directory, max_age_days=7, max_files=1000):
            """Perform both time-based and count-based cleanup.

            Returns the total number of files removed by the two passes.
            """
            time_removed = cleanup_old_files(directory, max_age_days)
            count_removed = limit_file_count(directory, max_files)
            return time_removed + count_removed


# Run one cleanup pass at import time so stale files from a previous run are
# pruned before the first request is served.
try:
    removed = perform_cleanup(OUTPUT_DIR, MAX_AGE_DAYS, MAX_FILES)
    logger.info(f"Initial cleanup removed {removed} files")
except Exception as e:
    logger.error(f"Error during initial cleanup: {e}")
|
|
| |
@app.middleware("http")
async def cleanup_middleware(request, call_next):
    """Count requests and run the output-directory cleanup periodically.

    Every CLEANUP_INTERVAL-th request triggers perform_cleanup() before the
    request is passed on. A cleanup failure is logged and never prevents the
    request from being handled.
    """
    global request_counter

    request_counter += 1
    cleanup_due = (request_counter % CLEANUP_INTERVAL) == 0

    if cleanup_due:
        logger.info(f"Running scheduled cleanup after {request_counter} requests")
        try:
            removed = perform_cleanup(OUTPUT_DIR, MAX_AGE_DAYS, MAX_FILES)
            logger.info(f"Scheduled cleanup removed {removed} files")
        except Exception as e:
            logger.error(f"Error during scheduled cleanup: {e}")

    # Hand the request to the rest of the stack.
    return await call_next(request)
|
|
|
|
| |
| |
| |
# 'owner/model': two segments of letters, digits, '.', '-' and '_'.
HF_ID_REGEX = re.compile(r"^[a-zA-Z0-9\.\-\_]+/[a-zA-Z0-9\.\-\_]+$")

# A single owner or model segment, used for the URL form.
_HF_SEGMENT_REGEX = re.compile(r"[a-zA-Z0-9\.\-\_]+")


def is_valid_hf_input(input_str: str) -> bool:
    """Check whether the input is a valid Hugging Face model ID or URL.

    Accepted forms:
      * ``owner/model``
      * ``http(s)://huggingface.co/owner/model[/...]`` - only the host and the
        first two path segments are checked.

    Empty input and input longer than 200 characters are rejected.

    Matching uses ``fullmatch``: with ``match`` the ``$`` anchor also matches
    just before a trailing newline, so ``"owner/model\\n"`` would be accepted.
    """
    if not input_str or len(input_str) > 200:
        return False

    if not input_str.startswith(("http://", "https://")):
        return HF_ID_REGEX.fullmatch(input_str) is not None

    try:
        parsed = urlparse(input_str)
    except ValueError:
        # urlparse rejects some malformed URLs (e.g. a broken IPv6 host).
        return False

    if parsed.netloc != "huggingface.co":
        return False

    path_parts = parsed.path.strip("/").split("/")
    if len(path_parts) < 2:
        return False

    # fullmatch also rejects empty segments, since the pattern needs 1+ chars.
    return all(_HF_SEGMENT_REGEX.fullmatch(part) for part in path_parts[:2])
|
|
| def _normalise_model_id(raw_id: str) -> str: |
| """ |
| Accept either validated 'owner/model' or a validated full URL like |
| 'https://huggingface.co/owner/model'. Return 'owner/model'. |
| Assumes input has already been validated by is_valid_hf_input. |
| """ |
| if raw_id.startswith(("http://", "https://") ): |
| path = urlparse(raw_id).path.lstrip("/") |
| parts = path.split("/") |
| |
| return f"{parts[0]}/{parts[1]}" |
| return raw_id |
|
|
| |
|
|
|
|
| |
def log_sbom_generation(model_id: str):
    """Log a successful SBOM generation event to the Hugging Face dataset.

    Appends one row (UTC timestamp, event "generated", normalised model ID)
    to the HF_REPO dataset by loading the current 'train' split,
    concatenating the new row and pushing the result back. Logging is
    best-effort: every failure is logged and swallowed, and nothing is done
    when HF_TOKEN is not set.

    The existing data is replaced by the new row only when the dataset is
    reported as missing or empty (FileNotFoundError and its subclasses). Any
    other load or concatenation failure - a network error, a column mismatch -
    skips the push, because pushing just the new row would overwrite the log
    collected so far.
    """
    if not HF_TOKEN:
        logger.warning("HF_TOKEN not set. Skipping SBOM generation logging.")
        return

    try:
        normalized_model_id_for_log = _normalise_model_id(model_id)
        log_data = {
            "timestamp": [datetime.utcnow().isoformat()],
            "event": ["generated"],
            "model_id": [normalized_model_id_for_log],
        }
        ds_new_log = Dataset.from_dict(log_data)

        try:
            # The log repo holds plain data, so remote code execution
            # (trust_remote_code) is neither needed nor wanted here.
            existing_ds = load_dataset(HF_REPO, token=HF_TOKEN, split="train")
        except FileNotFoundError as load_err:
            logger.info(f"Could not load existing dataset {HF_REPO} (may not exist yet): {load_err}. Pushing new dataset.")
            ds_to_push = ds_new_log
        else:
            if len(existing_ds) == 0:
                logger.info(f"Dataset {HF_REPO} is empty. Pushing initial data.")
                ds_to_push = ds_new_log
            else:
                if set(existing_ds.column_names) != set(log_data.keys()):
                    logger.warning(f"Dataset {HF_REPO} has unexpected columns {existing_ds.column_names} vs {list(log_data.keys())}. Appending new log anyway, structure might differ.")
                ds_to_push = concatenate_datasets([existing_ds, ds_new_log])

        ds_to_push.push_to_hub(HF_REPO, token=HF_TOKEN, private=True)
        logger.info(f"Successfully logged SBOM generation for {normalized_model_id_for_log} to {HF_REPO}")

    except Exception as e:
        logger.error(f"Failed to log SBOM generation to {HF_REPO}: {e}")
|
|
def get_sbom_count() -> str:
    """Retrieve the total count of generated SBOMs from the Hugging Face dataset.

    Returns the number of rows in the 'train' split of HF_REPO formatted with
    thousands separators (e.g. "1,234"), or "N/A" when HF_TOKEN is not set or
    the dataset cannot be loaded.
    """
    if not HF_TOKEN:
        logger.warning("HF_TOKEN not set. Cannot retrieve SBOM count.")
        return "N/A"

    try:
        # The log repo holds plain data, so remote code execution
        # (trust_remote_code) is neither needed nor wanted here.
        ds = load_dataset(HF_REPO, token=HF_TOKEN, split="train")
        count = len(ds)
    except Exception as e:
        logger.error(f"Failed to retrieve SBOM count from {HF_REPO}: {e}")
        return "N/A"

    logger.info(f"Retrieved SBOM count: {count} from {HF_REPO}")
    return f"{count:,}"
| |
|
|
@app.on_event("startup")
async def startup_event():
    """Make sure the output directory exists and log the registered routes."""
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    logger.info(f"Output directory ready at {OUTPUT_DIR}")

    registered_paths = [route.path for route in app.routes]
    logger.info(f"Registered routes: {registered_paths}")
|
|
@app.get("/", response_class=HTMLResponse)
async def root(request: Request):
    """Render the landing page.

    If index.html fails to render, error.html is rendered with the failure
    message instead; if that fails too, an HTTP 500 carrying the original
    rendering error is raised.
    """
    sbom_count = get_sbom_count()
    page_context = {"request": request, "sbom_count": sbom_count}

    try:
        return templates.TemplateResponse("index.html", page_context)
    except Exception as e:
        logger.error(f"Error rendering template: {str(e)}")
        failure_text = f"Template rendering error: {str(e)}"
        error_context = {
            "request": request,
            "error": failure_text,
            "sbom_count": sbom_count,
        }

        try:
            return templates.TemplateResponse("error.html", error_context)
        except Exception as template_err:
            # Even the error page is broken; report the first failure.
            logger.error(f"Error rendering error template: {template_err}")
            raise HTTPException(status_code=500, detail=failure_text)
|
|
@app.get("/status", response_model=StatusResponse)
async def get_status():
    """Report a fixed status: operational, API and generator version 1.0.0."""
    payload = {
        "status": "operational",
        "version": "1.0.0",
        "generator_version": "1.0.0",
    }
    return StatusResponse(**payload)
|
|
| |
def import_utils():
    """Locate calculate_completeness_score, trying several import paths.

    The directory two levels above this file is appended to sys.path, then
    `utils`, `src.aibom_generator.utils` and `aibom_generator.utils` are tried
    in that order. Returns the first function found, or None when none of the
    modules can be imported or any other error occurs.
    """
    try:
        grandparent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        sys.path.append(grandparent_dir)

        try:
            from utils import calculate_completeness_score as scorer
        except ImportError:
            pass
        else:
            logger.info("Imported utils.calculate_completeness_score directly")
            return scorer

        try:
            from src.aibom_generator.utils import calculate_completeness_score as scorer
        except ImportError:
            pass
        else:
            logger.info("Imported src.aibom_generator.utils.calculate_completeness_score")
            return scorer

        try:
            from aibom_generator.utils import calculate_completeness_score as scorer
        except ImportError:
            pass
        else:
            logger.info("Imported aibom_generator.utils.calculate_completeness_score")
            return scorer

        logger.warning("Could not import calculate_completeness_score, using default implementation")
        return None
    except Exception as e:
        logger.error(f"Error importing utils: {str(e)}")
        return None


# Resolved once at import time; None when no scoring implementation was found.
calculate_completeness_score = import_utils()
|
|
| |
def create_comprehensive_completeness_score(aibom=None):
    """Return a completeness_score dict carrying every attribute the UI expects.

    When an AIBOM is given and calculate_completeness_score is available, the
    real score is computed (validate=True, use_best_practices=True). If that
    is not possible or the calculation raises, a fixed placeholder structure
    is returned instead.
    """
    if aibom and calculate_completeness_score:
        try:
            return calculate_completeness_score(aibom, validate=True, use_best_practices=True)
        except Exception as e:
            logger.error(f"Error calculating completeness score: {str(e)}")

    # Placeholder data: (field name, tier, present?) in display order.
    # field_checklist, field_tiers and missing_fields are all derived from it.
    catalogue = (
        ("bomFormat", "critical", True),
        ("specVersion", "critical", True),
        ("serialNumber", "critical", True),
        ("version", "critical", True),
        ("metadata.timestamp", "important", True),
        ("metadata.tools", "important", True),
        ("metadata.authors", "important", True),
        ("metadata.component", "important", True),
        ("component.type", "important", True),
        ("component.name", "critical", True),
        ("component.bom-ref", "important", True),
        ("component.purl", "important", True),
        ("component.description", "important", True),
        ("component.licenses", "important", True),
        ("modelCard.modelParameters", "important", True),
        ("modelCard.quantitativeAnalysis", "important", False),
        ("modelCard.considerations", "important", True),
        ("externalReferences", "supplementary", True),
        ("name", "critical", True),
        ("downloadLocation", "critical", True),
        ("primaryPurpose", "critical", True),
        ("suppliedBy", "critical", True),
        ("energyConsumption", "important", False),
        ("hyperparameter", "important", True),
        ("limitation", "important", True),
        ("safetyRiskAssessment", "important", False),
        ("typeOfModel", "important", True),
        ("modelExplainability", "supplementary", False),
        ("standardCompliance", "supplementary", False),
        ("domain", "supplementary", True),
        ("energyQuantity", "supplementary", False),
        ("energyUnit", "supplementary", False),
        ("informationAboutTraining", "supplementary", True),
        ("informationAboutApplication", "supplementary", True),
        ("metric", "supplementary", False),
        ("metricDecisionThreshold", "supplementary", False),
        ("modelDataPreprocessing", "supplementary", False),
        ("autonomyType", "supplementary", False),
        ("useSensitivePersonalInformation", "supplementary", False),
    )

    # Star rating shown next to the tick/cross, keyed by tier.
    stars = {"critical": "★★★", "important": "★★", "supplementary": "★"}

    field_checklist = {
        field: f"{'✔' if present else '✘'} {stars[tier]}"
        for field, tier, present in catalogue
    }
    field_tiers = {field: tier for field, tier, _ in catalogue}
    missing_fields = {
        tier: [field for field, field_tier, present in catalogue
               if field_tier == tier and not present]
        for tier in stars
    }

    return {
        "total_score": 75.5,
        "section_scores": {
            "required_fields": 20,
            "metadata": 15,
            "component_basic": 18,
            "component_model_card": 15,
            "external_references": 7.5,
        },
        "max_scores": {
            "required_fields": 20,
            "metadata": 20,
            "component_basic": 20,
            "component_model_card": 30,
            "external_references": 10,
        },
        "field_checklist": field_checklist,
        "field_tiers": field_tiers,
        "missing_fields": missing_fields,
        "completeness_profile": {
            "name": "standard",
            "description": "Comprehensive fields for proper documentation",
            "satisfied": True,
        },
        "penalty_applied": False,
        "penalty_reason": None,
        "recommendations": [
            {
                "priority": "medium",
                "field": "modelCard.quantitativeAnalysis",
                "message": "Missing important field: modelCard.quantitativeAnalysis",
                "recommendation": "Add quantitative analysis information to the model card",
            },
            {
                "priority": "medium",
                "field": "energyConsumption",
                "message": "Missing important field: energyConsumption - helpful for environmental impact assessment",
                "recommendation": "Consider documenting energy consumption metrics for better transparency",
            },
            {
                "priority": "medium",
                "field": "safetyRiskAssessment",
                "message": "Missing important field: safetyRiskAssessment",
                "recommendation": "Add safety risk assessment information to improve documentation",
            },
        ],
    }
|
|
def _load_generator_for_form():
    """Instantiate AIBOMGenerator from the first importable location.

    Raises ImportError (after logging) when no location can be imported.
    """
    try:
        from src.aibom_generator.generator import AIBOMGenerator
        return AIBOMGenerator()
    except ImportError:
        pass
    try:
        from aibom_generator.generator import AIBOMGenerator
        return AIBOMGenerator()
    except ImportError:
        pass
    try:
        from generator import AIBOMGenerator
        return AIBOMGenerator()
    except ImportError:
        logger.error("Could not import AIBOMGenerator from any known location")
        raise ImportError("Could not import AIBOMGenerator from any known location")


def _build_download_script(download_url, filename):
    """Return the inline <script> block used by result.html.

    It defines downloadJSON() (bound to the given URL and file name),
    switchTab() and toggleCollapsible().
    """
    return f"""
    <script>
        function downloadJSON() {{
            const a = document.createElement('a');
            a.href = '{download_url}';
            a.download = '{filename}';
            document.body.appendChild(a);
            a.click();
            document.body.removeChild(a);
        }}

        function switchTab(tabId) {{
            // Hide all tabs
            document.querySelectorAll('.tab-content').forEach(tab => {{
                tab.classList.remove('active');
            }});

            // Deactivate all tab buttons
            document.querySelectorAll('.aibom-tab').forEach(button => {{
                button.classList.remove('active');
            }});

            // Show the selected tab
            document.getElementById(tabId).classList.add('active');

            // Activate the clicked button
            event.currentTarget.classList.add('active');
        }}

        function toggleCollapsible(element) {{
            element.classList.toggle('active');
            var content = element.nextElementSibling;
            if (content.style.maxHeight) {{
                content.style.maxHeight = null;
                content.classList.remove('active');
            }} else {{
                content.style.maxHeight = content.scrollHeight + "px";
                content.classList.add('active');
            }}
        }}
    </script>
    """


def _normalise_enhancement_report(report):
    """Guarantee the score entries result.html reads from the report.

    A missing report becomes an all-zero placeholder. Otherwise
    'original_score' and 'final_score' are filled in when absent or None, and
    each gets a 'completeness_score' (defaulting to its 'total_score').
    The given dict is updated in place and returned.
    """
    if report is None:
        return {
            "ai_enhanced": False,
            "ai_model": None,
            "original_score": {"total_score": 0, "completeness_score": 0},
            "final_score": {"total_score": 0, "completeness_score": 0},
            "improvement": 0,
        }

    for key in ("original_score", "final_score"):
        if report.get(key) is None:
            report[key] = {"total_score": 0, "completeness_score": 0}
        elif "completeness_score" not in report[key]:
            report[key]["completeness_score"] = report[key].get("total_score", 0)
    return report


@app.post("/generate", response_class=HTMLResponse)
async def generate_form(
    request: Request,
    model_id: str = Form(...),
    include_inference: bool = Form(False),
    use_best_practices: bool = Form(True),
    g_recaptcha_response: Optional[str] = Form(None),
):
    """Handle the HTML form: validate the model, generate its AI SBOM, render it.

    Steps: CAPTCHA check, input validation, existence check against the
    Hugging Face API, SBOM generation, writing the JSON file to OUTPUT_DIR,
    usage logging, and rendering result.html. Every failure renders
    error.html with a message instead of raising.
    """

    def render_error(message, count, shown_model_id=None):
        # Shared error-page rendering; model_id is only included when known.
        context = {"request": request, "error": message, "sbom_count": count}
        if shown_model_id is not None:
            context["model_id"] = shown_model_id
        return templates.TemplateResponse("error.html", context)

    form_data = await request.form()
    logger.info(f"All form data: {dict(form_data)}")

    if not verify_recaptcha(g_recaptcha_response):
        return render_error("Security verification failed. Please try again.", get_sbom_count())

    sbom_count = get_sbom_count()

    # Escape first so anything echoed back into a template is HTML-safe.
    sanitized_model_id = html.escape(model_id)

    if not is_valid_hf_input(sanitized_model_id):
        error_message = "Invalid input format. Please provide a valid Hugging Face model ID (e.g., 'owner/model') or a full model URL (e.g., 'https://huggingface.co/owner/model')."
        logger.warning(f"Invalid model input format received: {model_id}")
        return render_error(error_message, sbom_count, sanitized_model_id)

    normalized_model_id = _normalise_model_id(sanitized_model_id)

    # Confirm the repository exists before doing any expensive work.
    try:
        hf_api = HfApi()
        logger.info(f"Attempting to fetch model info for: {normalized_model_id}")
        hf_api.model_info(normalized_model_id)
        logger.info(f"Successfully fetched model info for: {normalized_model_id}")
    except RepositoryNotFoundError:
        error_message = f"Error: The provided ID \"{normalized_model_id}\" could not be found on Hugging Face or does not correspond to a model repository."
        logger.warning(f"Repository not found for ID: {normalized_model_id}")
        return render_error(error_message, sbom_count, normalized_model_id)
    except Exception as api_err:
        error_message = f"Error verifying model ID with Hugging Face API: {str(api_err)}"
        logger.error(f"HF API error for {normalized_model_id}: {str(api_err)}")
        return render_error(error_message, sbom_count, normalized_model_id)

    try:
        generator = _load_generator_for_form()

        aibom = generator.generate_aibom(
            model_id=sanitized_model_id,
            include_inference=include_inference,
            use_best_practices=use_best_practices,
        )
        enhancement_report = generator.get_enhancement_report()

        # Persist the SBOM so it can be served from the /output mount.
        filename = f"{normalized_model_id.replace('/', '_')}_ai_sbom.json"
        filepath = os.path.join(OUTPUT_DIR, filename)
        with open(filepath, "w") as f:
            json.dump(aibom, f, indent=2)

        # Record the generation, then refresh the counter shown on the page.
        log_sbom_generation(sanitized_model_id)
        sbom_count = get_sbom_count()

        download_url = f"/output/{filename}"
        download_script = _build_download_script(download_url, filename)

        # Prefer a score supplied by the generator, when it offers one.
        completeness_score = None
        if hasattr(generator, 'get_completeness_score'):
            try:
                completeness_score = generator.get_completeness_score(sanitized_model_id)
                logger.info("Successfully retrieved completeness_score from generator")
            except Exception as e:
                logger.error(f"Completeness score error from generator: {str(e)}")

        # The template needs 'field_checklist'; fall back when it is absent.
        if not isinstance(completeness_score, dict) or 'field_checklist' not in completeness_score:
            logger.info("Using comprehensive completeness_score with field_checklist")
            completeness_score = create_comprehensive_completeness_score(aibom)

        enhancement_report = _normalise_enhancement_report(enhancement_report)

        display_names = {
            "required_fields": "Required Fields",
            "metadata": "Metadata",
            "component_basic": "Component Basic Info",
            "component_model_card": "Model Card",
            "external_references": "External References",
        }

        tooltips = {
            "required_fields": "Basic required fields for a valid AIBOM",
            "metadata": "Information about the AIBOM itself",
            "component_basic": "Basic information about the AI model component",
            "component_model_card": "Detailed model card information",
            "external_references": "Links to external resources",
        }

        weights = {
            "required_fields": 20,
            "metadata": 20,
            "component_basic": 20,
            "component_model_card": 30,
            "external_references": 10,
        }

        return templates.TemplateResponse(
            "result.html",
            {
                "request": request,
                "model_id": normalized_model_id,
                "aibom": aibom,
                "enhancement_report": enhancement_report,
                "completeness_score": completeness_score,
                "download_url": download_url,
                "download_script": download_script,
                "display_names": display_names,
                "tooltips": tooltips,
                "weights": weights,
                "sbom_count": sbom_count,
            },
        )

    except Exception as e:
        logger.error(f"Error generating AI SBOM: {str(e)}")
        return render_error(str(e), get_sbom_count(), normalized_model_id)
|
|
@app.get("/download/{filename}")
async def download_file(filename: str):
    """
    Download a generated AIBOM file.

    This endpoint serves the generated AIBOM JSON files for download.
    Only regular files located directly inside OUTPUT_DIR are served; any
    name that resolves elsewhere (e.g. '..', or a symlink pointing outside)
    or that is not a regular file yields a 404.
    """
    output_root = os.path.realpath(OUTPUT_DIR)
    file_path = os.path.realpath(os.path.join(output_root, filename))

    # The name comes straight from the URL: refuse anything that escapes the
    # output directory, and anything that is not a regular file.
    inside_output_dir = os.path.dirname(file_path) == output_root
    if not inside_output_dir or not os.path.isfile(file_path):
        raise HTTPException(status_code=404, detail="File not found")

    return FileResponse(
        file_path,
        media_type="application/json",
        filename=filename,
    )
|
|
| |
class GenerateRequest(BaseModel):
    """JSON body accepted by the /api/generate* endpoints."""

    # Hugging Face model reference: 'owner/model' or a huggingface.co URL.
    model_id: str
    # Forwarded to the generator's generate_aibom() call.
    include_inference: bool = True
    # Forwarded to the generator and to the completeness scoring.
    use_best_practices: bool = True
    # Optional Hugging Face access token supplied by the caller.
    hf_token: Optional[str] = None
|
|
def _resolve_api_model_id(raw_model_id, hf_token=None):
    """Validate a client-supplied model reference and confirm it exists.

    Returns ``(sanitized_model_id, normalized_model_id)``.

    Raises HTTPException: 400 for a malformed reference, 404 when the
    repository is not found, 500 when the Hugging Face API call fails for
    another reason.
    """
    sanitized_model_id = html.escape(raw_model_id)
    if not is_valid_hf_input(sanitized_model_id):
        raise HTTPException(status_code=400, detail="Invalid model ID format")

    normalized_model_id = _normalise_model_id(sanitized_model_id)

    try:
        # Use the caller's token (None means anonymous), as get_model_score does.
        HfApi(token=hf_token).model_info(normalized_model_id)
    except RepositoryNotFoundError:
        raise HTTPException(status_code=404, detail=f"Model {normalized_model_id} not found on Hugging Face")
    except Exception as api_err:
        raise HTTPException(status_code=500, detail=f"Error verifying model: {str(api_err)}")

    return sanitized_model_id, normalized_model_id


def _create_api_generator():
    """Instantiate AIBOMGenerator from the first importable location.

    Raises HTTPException(500) when no location can be imported.
    """
    try:
        from src.aibom_generator.generator import AIBOMGenerator
        return AIBOMGenerator()
    except ImportError:
        pass
    try:
        from aibom_generator.generator import AIBOMGenerator
        return AIBOMGenerator()
    except ImportError:
        pass
    try:
        from generator import AIBOMGenerator
        return AIBOMGenerator()
    except ImportError:
        raise HTTPException(status_code=500, detail="Could not import AIBOMGenerator")


def _save_api_aibom(aibom, normalized_model_id):
    """Write the AIBOM as JSON into OUTPUT_DIR and return the file name."""
    filename = f"{normalized_model_id.replace('/', '_')}_ai_sbom.json"
    filepath = os.path.join(OUTPUT_DIR, filename)
    with open(filepath, "w") as f:
        json.dump(aibom, f, indent=2)
    return filename


def _make_score_machine_readable(completeness_score):
    """Convert a completeness score into its JSON-friendly API form, in place.

    Fractional section scores are rounded to one decimal. Each
    'field_checklist' entry becomes ``{"status": "present"|"missing",
    "importance": <tier or "unknown">}``, and 'field_tiers' is dropped
    because the tier now lives in the checklist.
    """
    section_scores = completeness_score.get("section_scores", {})
    for section, score in section_scores.items():
        if isinstance(score, float) and not score.is_integer():
            section_scores[section] = round(score, 1)

    if "field_checklist" in completeness_score:
        field_tiers = completeness_score.get("field_tiers", {})
        completeness_score["field_checklist"] = {
            field: {
                # A tick in the display string marks the field as present.
                "status": "present" if "✔" in str(value) else "missing",
                "importance": field_tiers.get(field, "unknown"),
            }
            for field, value in completeness_score["field_checklist"].items()
        }
        completeness_score.pop("field_tiers", None)

    return completeness_score


@app.post("/api/generate")
async def api_generate_aibom(request: GenerateRequest):
    """
    Generate an AI SBOM for a specified Hugging Face model.

    This endpoint accepts JSON input and returns JSON output: the AIBOM, the
    normalised model ID, a UTC timestamp, a request ID and the download URL.
    Errors are reported as HTTPException (400, 404 or 500).
    """
    try:
        sanitized_model_id, normalized_model_id = _resolve_api_model_id(
            request.model_id, request.hf_token
        )

        generator = _create_api_generator()
        aibom = generator.generate_aibom(
            model_id=sanitized_model_id,
            include_inference=request.include_inference,
            use_best_practices=request.use_best_practices,
        )
        # Result is not part of this response; the call is kept so the
        # endpoint still fails the same way if the generator cannot report.
        generator.get_enhancement_report()

        filename = _save_api_aibom(aibom, normalized_model_id)
        log_sbom_generation(sanitized_model_id)

        return {
            "aibom": aibom,
            "model_id": normalized_model_id,
            "generated_at": datetime.utcnow().isoformat() + "Z",
            "request_id": str(uuid.uuid4()),
            "download_url": f"/output/{filename}",
        }
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error generating AI SBOM: {str(e)}")


@app.post("/api/generate-with-report")
async def api_generate_with_report(request: GenerateRequest):
    """
    Generate an AI SBOM with a completeness report.
    This endpoint accepts JSON input and returns JSON output with completeness score.
    Errors are reported as HTTPException (400, 404 or 500); a 500 is returned
    up front when no completeness scoring implementation could be imported.
    """
    try:
        sanitized_model_id, normalized_model_id = _resolve_api_model_id(
            request.model_id, request.hf_token
        )

        # Fail early and clearly instead of generating the SBOM and then
        # tripping over a None scorer.
        if calculate_completeness_score is None:
            raise HTTPException(
                status_code=500,
                detail="Error generating AI SBOM: completeness scoring is unavailable",
            )

        generator = _create_api_generator()
        aibom = generator.generate_aibom(
            model_id=sanitized_model_id,
            include_inference=request.include_inference,
            use_best_practices=request.use_best_practices,
        )

        completeness_score = _make_score_machine_readable(
            calculate_completeness_score(
                aibom, validate=True, use_best_practices=request.use_best_practices
            )
        )

        filename = _save_api_aibom(aibom, normalized_model_id)
        log_sbom_generation(sanitized_model_id)

        return {
            "aibom": aibom,
            "model_id": normalized_model_id,
            "generated_at": datetime.utcnow().isoformat() + "Z",
            "request_id": str(uuid.uuid4()),
            "download_url": f"/output/{filename}",
            "completeness_score": completeness_score,
        }
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error generating AI SBOM: {str(e)}")
|
|
|
|
@app.get("/api/models/{model_id:path}/score")
async def get_model_score(
    model_id: str,
    hf_token: Optional[str] = None,
    use_best_practices: bool = True
):
    """
    Get the completeness score for a model without returning the full AIBOM.

    Args:
        model_id: Model identifier taken from the URL path.
        hf_token: Optional token handed to ``HfApi`` and ``AIBOMGenerator``.
        use_best_practices: Forwarded to the generator and to the scorer.

    Returns:
        A dict with the keys ``total_score``, ``section_scores`` and
        ``max_scores``.

    Raises:
        HTTPException: 400 when the model ID fails ``is_valid_hf_input``,
            404 when the Hub reports the repository as missing, and 500 for
            any other failure.
    """
    try:
        # Escape the raw input first; validation and normalisation both work
        # on the escaped value.
        sanitized_model_id = html.escape(model_id)
        if not is_valid_hf_input(sanitized_model_id):
            raise HTTPException(status_code=400, detail="Invalid model ID format")

        normalized_model_id = _normalise_model_id(sanitized_model_id)

        # Existence check only: the metadata returned by the Hub is not used here.
        try:
            HfApi(token=hf_token).model_info(normalized_model_id)
        except RepositoryNotFoundError:
            raise HTTPException(
                status_code=404,
                detail=f"Model {normalized_model_id} not found on Hugging Face",
            )
        except Exception as api_err:
            raise HTTPException(
                status_code=500,
                detail=f"Error verifying model: {str(api_err)}",
            ) from api_err

        try:
            # Try each import path the generator module may be reachable under.
            generator = None
            try:
                from src.aibom_generator.generator import AIBOMGenerator
                generator = AIBOMGenerator(hf_token=hf_token)
            except ImportError:
                try:
                    from aibom_generator.generator import AIBOMGenerator
                    generator = AIBOMGenerator(hf_token=hf_token)
                except ImportError:
                    try:
                        from generator import AIBOMGenerator
                        generator = AIBOMGenerator(hf_token=hf_token)
                    except ImportError:
                        raise HTTPException(status_code=500, detail="Could not import AIBOMGenerator")

            # An AIBOM is still built internally; only the score is returned.
            aibom = generator.generate_aibom(
                model_id=sanitized_model_id,
                include_inference=False,
                use_best_practices=use_best_practices,
            )

            score = calculate_completeness_score(
                aibom, validate=True, use_best_practices=use_best_practices
            )

            log_sbom_generation(normalized_model_id)

            # Round fractional section scores to one decimal place.
            section_scores = score["section_scores"]
            for section, value in section_scores.items():
                if isinstance(value, float) and not value.is_integer():
                    section_scores[section] = round(value, 1)

            return {
                "total_score": score["total_score"],
                "section_scores": section_scores,
                "max_scores": score["max_scores"],
            }
        except HTTPException:
            # Without this clause the generic handler below would re-wrap the
            # import-failure HTTPException and garble its detail message.
            raise
        except Exception as e:
            raise HTTPException(
                status_code=500,
                detail=f"Error calculating model score: {str(e)}",
            ) from e
    except HTTPException:
        # Already carries the intended status code and detail.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing request: {str(e)}") from e
|
|
|
|
| |
# Request body accepted by the batch endpoint.
class BatchRequest(BaseModel):
    # Model identifiers to process; the only field without a default.
    model_ids: List[str]

    # Generation switches, both enabled unless the caller says otherwise.
    include_inference: bool = True
    use_best_practices: bool = True

    # Optional token; None when the caller does not supply one.
    hf_token: Optional[str] = None
|
|
| |
# In-memory registry of batch jobs, keyed by job ID. Entries are written by
# the batch creation endpoint and read back by the batch status endpoint.
batch_jobs: Dict[str, Dict[str, Any]] = {}
|
|
@app.post("/api/batch")
async def batch_generate(request: BatchRequest):
    """
    Start a batch job to generate AIBOMs for multiple models.

    Each requested model ID is HTML-escaped and validated; IDs that fail
    validation are logged and skipped. A job record is stored in
    ``batch_jobs`` under a fresh UUID.

    Args:
        request: Request body; only ``model_ids`` is read by this function.

    Returns:
        A dict with the keys ``job_id``, ``status``, ``model_ids`` and
        ``created_at``.

    Raises:
        HTTPException: 400 when no model ID passes validation, 500 for any
            unexpected failure.
    """
    try:
        valid_model_ids = []
        for model_id in request.model_ids:
            sanitized_id = html.escape(model_id)
            if is_valid_hf_input(sanitized_id):
                valid_model_ids.append(sanitized_id)
            else:
                # The rejected value is untrusted; %r escapes control
                # characters so it cannot inject extra lines into the log.
                logger.warning("Skipping invalid model ID: %r", model_id)

        if not valid_model_ids:
            raise HTTPException(status_code=400, detail="No valid model IDs provided")

        job_id = str(uuid.uuid4())
        created_at = datetime.utcnow().isoformat() + "Z"

        batch_jobs[job_id] = {
            "job_id": job_id,
            "status": "queued",
            "model_ids": valid_model_ids,
            "created_at": created_at,
            "completed": 0,
            "total": len(valid_model_ids),
            "results": {},
        }

        # NOTE(review): nothing in this function schedules work for the job;
        # the stored status is switched to "processing" while the response
        # below still reports "queued" - confirm this is intended.
        batch_jobs[job_id]["status"] = "processing"

        return {
            "job_id": job_id,
            "status": "queued",
            "model_ids": valid_model_ids,
            "created_at": created_at,
        }
    except HTTPException:
        # Already carries the intended status code and detail.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error creating batch job: {str(e)}") from e
|
|
@app.get("/api/batch/{job_id}")
async def get_batch_status(job_id: str):
    """
    Return the stored record of a batch job.

    Raises:
        HTTPException: 404 when ``job_id`` is not a key of ``batch_jobs``.
    """
    try:
        job_record = batch_jobs[job_id]
    except KeyError:
        raise HTTPException(status_code=404, detail=f"Batch job {job_id} not found") from None
    return job_record
|
|
|
|
| |
if __name__ == "__main__":
    import uvicorn

    # Warn at startup when no token is configured; the message spells out
    # what is affected (SBOM count and usage logging).
    token_configured = bool(HF_TOKEN)
    if not token_configured:
        print("Warning: HF_TOKEN environment variable not set. SBOM count will show N/A and logging will be skipped.")

    # Serve the app on all interfaces, port 8000.
    uvicorn.run(app, host="0.0.0.0", port=8000)
|
|