Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, Request | |
| from fastapi.templating import Jinja2Templates | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.responses import JSONResponse, FileResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from optimum.neuron import utils | |
| import logging | |
| import sys | |
| import os | |
| import httpx | |
# Process-wide logging: INFO and above go to stdout as
# "<time> - <logger name> - <level> - <message>".
# (logging.basicConfig is a no-op if the root logger already has handlers.)
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(
    level=logging.INFO,
    format=_LOG_FORMAT,
    handlers=[logging.StreamHandler(sys.stdout)],
)

# Module logger used by every handler below.
logger = logging.getLogger(__name__)
# The FastAPI application that the handlers and static mount attach to.
app = FastAPI()

# Wide-open CORS policy: every origin, method and header is accepted.
# NOTE(review): a "*" origin together with allow_credentials=True is not a
# combination the CORS spec permits literally; depending on the Starlette
# version this either drops credential support or echoes back any caller's
# Origin (allowing credentialed requests from any site). Nothing visible here
# reads cookies or auth headers — confirm whether credentials are needed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
# Asset directories are resolved relative to this source file, not the
# process's working directory.
_app_root = os.path.dirname(os.path.abspath(__file__))

static_dir = os.path.join(_app_root, "static")
logger.info(f"Static directory path: {static_dir}")

templates_dir = os.path.join(_app_root, "templates")
logger.info(f"Templates directory path: {templates_dir}")

# Expose ./static under the /static URL prefix and load Jinja2 templates
# from ./templates.
app.mount("/static", StaticFiles(directory=static_dir), name="static")
templates = Jinja2Templates(directory=templates_dir)
async def health_check():
    """Report that the service is up.

    Returns:
        dict: always ``{"status": "healthy"}``.
    """
    logger.info("Health check endpoint called")
    payload = {"status": "healthy"}
    return payload
async def home(request: Request):
    """Render ``index.html`` with the request and its base URL.

    When the ``SPACE_ID`` environment variable is set (i.e. running on
    Hugging Face Spaces), occurrences of ``http://`` in the base URL are
    rewritten to ``https://``; otherwise the request's own scheme is kept.

    Args:
        request: The incoming request; passed to the template and used to
            derive ``base_url``.

    Returns:
        The ``TemplateResponse`` for ``index.html``.
    """
    logger.info("Home page requested")
    running_on_spaces = os.getenv("SPACE_ID") is not None
    url = str(request.base_url)
    # NOTE(review): presumably Spaces terminates TLS in front of the app, so
    # the request is seen as plain http while browsers use https — confirm.
    url = url.replace("http://", "https://") if running_on_spaces else url
    return templates.TemplateResponse(
        "index.html",
        dict(request=request, base_url=url),
    )
async def get_model_list():
    """Return the de-duplicated list of models found in the hub cache.

    Calls ``utils.get_hub_cached_models(mode="inference")`` and turns each
    ``(architecture, org, model_id)`` entry into ``{"id", "name", "type"}``,
    where ``id`` and ``name`` are both ``"org/model_id"`` and ``type`` is the
    architecture. If the same ``org/model_id`` occurs more than once, only the
    first occurrence is kept.

    Returns:
        JSONResponse: the list of model dicts, or status 500 with
        ``{"error": <message>, "type": <exception class name>}`` on failure.
    """
    import asyncio
    import functools

    logger.info("Fetching model list")
    try:
        # Only whether a token is configured is logged, never its value.
        logger.info(f"HF_TOKEN present: {bool(os.getenv('HF_TOKEN'))}")
        # get_hub_cached_models is synchronous (it is not awaitable); calling it
        # directly inside this async handler would block the event loop, and
        # with it every other request, until it returns. Run it in the default
        # thread pool instead.
        loop = asyncio.get_running_loop()
        model_list = await loop.run_in_executor(
            None, functools.partial(utils.get_hub_cached_models, mode="inference")
        )
        logger.info(f"Found {len(model_list)} models")

        models = []
        seen_models = set()
        for architecture, org, model_id in model_list:
            full_model_id = f"{org}/{model_id}"
            if full_model_id in seen_models:
                continue
            seen_models.add(full_model_id)
            models.append({
                "id": full_model_id,
                "name": full_model_id,
                "type": architecture,
            })

        logger.info(f"Returning {len(models)} unique models")
        return JSONResponse(content=models)
    except Exception as e:
        # Boundary handler: log the traceback and report the failure to the client.
        logger.error(f"Error fetching models: {str(e)}")
        logger.error("Full error details:", exc_info=True)
        return JSONResponse(
            status_code=500,
            content={"error": str(e), "type": str(type(e).__name__)},
        )
async def get_model_info_endpoint(model_id: str):
    """Look up the cached configurations for ``model_id``.

    Sends a GET to
    ``https://huggingface.co/api/integrations/aws/v1/lookup/<model_id>`` and
    returns the ``cached_configs`` list from the JSON body.

    Args:
        model_id: Identifier appended verbatim to the lookup URL (presumably
            ``"org/name"`` — confirm against the route definition).

    Returns:
        JSONResponse: ``{"configurations": [...]}`` on success (an empty list
        if the body has no ``cached_configs`` key); status 504 if the upstream
        call times out; status 500 with an ``error`` message for any other
        HTTP or unexpected error.
    """
    logger.info(f"Fetching configurations for model: {model_id}")
    try:
        base_url = "https://huggingface.co/api/integrations/aws/v1/lookup"
        api_url = f"{base_url}/{model_id}"
        # httpx timeouts apply per operation, not to the request as a whole:
        # 5s to establish the connection, 15s for each read/write/pool wait.
        timeout = httpx.Timeout(15.0, connect=5.0)
        async with httpx.AsyncClient(timeout=timeout) as client:
            response = await client.get(api_url)
        response.raise_for_status()
        data = response.json()
        configs = data.get("cached_configs", [])
        logger.info(f"Found {len(configs)} configurations for model {model_id}")
        return JSONResponse(content={"configurations": configs})
    # TimeoutException is a subclass of HTTPError, so it must be caught first.
    except httpx.TimeoutException as e:
        logger.error(f"Timeout while fetching configurations for model {model_id}: {str(e)}", exc_info=True)
        return JSONResponse(
            status_code=504,  # Gateway Timeout
            content={"error": "Request timed out while fetching model configurations"},
        )
    except httpx.HTTPError as e:
        # Covers transport errors and non-2xx upstream statuses from raise_for_status().
        logger.error(f"HTTP error fetching configurations for model {model_id}: {str(e)}", exc_info=True)
        return JSONResponse(
            status_code=500,
            content={"error": f"Failed to fetch model configurations: {str(e)}"},
        )
    except Exception as e:
        # e.g. a non-JSON body or an unexpected payload shape.
        logger.error(f"Error fetching configurations for model {model_id}: {str(e)}", exc_info=True)
        return JSONResponse(
            status_code=500,
            content={"error": str(e)},
        )
async def static_files(path: str, request: Request):
    """Serve a file from ``static_dir`` given a client-supplied relative path.

    Args:
        path: Untrusted path, interpreted relative to ``static_dir``.
        request: The incoming request (not used by the body).

    Returns:
        A ``FileResponse`` for a regular file located inside ``static_dir``
        (with ``content-type`` forced for ``.css`` and ``.js``); otherwise a
        404 ``JSONResponse`` ``{"error": "File not found"}``.
    """
    logger.info(f"Static file requested: {path}")
    base_dir = os.path.realpath(static_dir)
    # ``path`` comes from the client: resolve it and refuse anything that ends
    # up outside static_dir, e.g. "../app.py" or an absolute path (which
    # os.path.join would otherwise use in place of static_dir).
    file_path = os.path.realpath(os.path.join(base_dir, path))
    try:
        inside_static_dir = os.path.commonpath([base_dir, file_path]) == base_dir
    except ValueError:
        # commonpath raises when the paths cannot share a root (e.g. different
        # drives on Windows); such a path is certainly outside static_dir.
        inside_static_dir = False

    # isfile rather than exists: a directory passes exists() but is not a
    # file that can be streamed back.
    if not inside_static_dir or not os.path.isfile(file_path):
        return JSONResponse(status_code=404, content={"error": "File not found"})

    response = FileResponse(file_path)
    # Force the content type for stylesheets and scripts.
    if path.endswith('.css'):
        response.headers["content-type"] = "text/css"
    elif path.endswith('.js'):
        response.headers["content-type"] = "application/javascript"
    return response