# Scraped page header (Hugging Face Space status banner): "Spaces: Running on Zero"
"""
FastAPI Endpoint for GASM-LLM Integration

This module provides a FastAPI service that an OpenAI custom GPT (CustomGPT)
can call to access GASM-enhanced language processing capabilities.
"""
import asyncio
import json
import logging
import os
import time
from contextlib import asynccontextmanager
from datetime import datetime
from typing import Any, Dict, List, Optional, Union

import torch
from fastapi import BackgroundTasks, Depends, FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field

from gasm.core import GASM
from gasm.utils import check_se3_invariance
from gasm_llm_layer import GASMEnhancedLLM, GASMTokenEmbedding
# Module-wide logging setup: records at INFO and above are emitted.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)

# Shared GASM-LLM model. ``lifespan`` assigns it at startup and clears it at
# shutdown; ``None`` means no model is currently available.
model_instance = None
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Lifespan manager for the FastAPI app.

    On startup, builds the shared ``GASMEnhancedLLM`` and stores it in the
    module-global ``model_instance``. A load failure is logged (with
    traceback) and leaves ``model_instance`` as ``None``, so the app still
    starts; ``get_model`` then answers 503. On shutdown the reference is
    dropped.

    Args:
        app: The FastAPI application (required by the lifespan protocol; unused).
    """
    global model_instance

    # Startup
    logger.info("Loading GASM-LLM model...")
    try:
        model_instance = GASMEnhancedLLM(
            base_model_name="distilbert-base-uncased",
            gasm_hidden_dim=256,
            gasm_output_dim=128,
            enable_geometry=True,
        )
        # The model is only used for inference. It exposes torch
        # ``.parameters()``, so presumably it is an ``nn.Module``; eval()
        # disables training-only behaviour such as dropout.
        model_instance.eval()
        logger.info("Model loaded successfully")
    except Exception as e:
        # Keep serving (health reports "unhealthy") rather than crash startup.
        logger.exception(f"Failed to load model: {e}")
        model_instance = None

    try:
        yield
    finally:
        # Shutdown: runs even when the app exits because of an error.
        logger.info("Shutting down...")
        model_instance = None
# Create FastAPI app
app = FastAPI(
    title="GASM-LLM API",
    description="API for GASM-enhanced Large Language Model processing",
    version="1.0.0",
    lifespan=lifespan,
)

# Add CORS middleware.
# Any origin may call the API, but credentials are disabled. With a wildcard
# origin and allow_credentials=True, Starlette reflects the caller's Origin and
# allows credentials, which lets any website make credentialed requests. No
# endpoint in this module reads cookies or auth headers, so nothing needs them.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Pydantic models for request/response
class TextProcessingRequest(BaseModel):
    """Request body for ``POST /process``: a single text plus processing switches."""

    text: str = Field(
        default=...,
        description="Text to process",
        min_length=1,
        max_length=10000,
    )
    enable_geometry: bool = Field(default=True, description="Enable geometric processing")
    return_embeddings: bool = Field(default=False, description="Return raw embeddings")
    return_geometry: bool = Field(default=False, description="Return geometric information")
    max_length: int = Field(default=512, description="Maximum sequence length", ge=1, le=2048)
    # NOTE(review): ``model_config`` is a reserved attribute name in Pydantic v2.
    # This file uses v1-style ``min_items``/``max_items`` elsewhere, so it
    # presumably targets v1. Rename with an alias before upgrading.
    model_config: Optional[Dict[str, Any]] = Field(
        default=None,
        description="Model configuration overrides",
    )
class GeometricAnalysisRequest(BaseModel):
    """Request model for geometric analysis (``POST /analyze``)."""

    # Same bounds as TextProcessingRequest.text: an empty text has nothing to
    # analyze, and an unbounded one lets a single request monopolise the model.
    text: str = Field(
        ...,
        description="Text to analyze geometrically",
        min_length=1,
        max_length=10000,
    )
    # Values other than the three listed are accepted, but the analyze handler
    # then fills in no analysis section.
    analysis_type: str = Field("full", description="Type of analysis: 'full', 'curvature', 'invariance'")
    num_invariance_tests: int = Field(10, description="Number of invariance tests", ge=1, le=100)
    tolerance: float = Field(1e-3, description="Tolerance for invariance tests", ge=1e-6, le=1e-1)
class ComparisonRequest(BaseModel):
    """Request model for comparing geometric vs standard processing (``POST /compare``)."""

    # Same bounds as TextProcessingRequest.text. Each comparison encodes the
    # text twice, so an unbounded input is doubly expensive.
    text: str = Field(
        ...,
        description="Text to compare",
        min_length=1,
        max_length=10000,
    )
    # Recognised names: "embedding_norm", "attention_patterns",
    # "geometric_consistency". Unknown names are ignored by the handler.
    # Pydantic copies field defaults per instance, so this list is not shared.
    metrics: List[str] = Field(
        ["embedding_norm", "attention_patterns", "geometric_consistency"],
        description="Metrics to compare",
    )
class BatchProcessingRequest(BaseModel):
    """Request body for ``POST /batch``: between 1 and 100 texts, encoded one at a time."""

    texts: List[str] = Field(
        default=...,
        description="List of texts to process",
        min_items=1,
        max_items=100,
    )
    enable_geometry: bool = Field(default=True, description="Enable geometric processing")
    # When False, per-text results are returned in addition to the summary.
    return_summary: bool = Field(default=True, description="Return summary statistics")
class TextProcessingResponse(BaseModel):
    """
    Response model for text processing (``POST /process``).

    On failure ``success`` is False, ``error`` carries the message and the
    statistics dicts are empty.
    """

    success: bool
    timestamp: str  # ISO-8601 time at which the request started
    processing_time: float  # seconds
    text_length: int  # number of characters in the input text
    model_info: Dict[str, Any]
    # Mixed values: "shape" is a list of ints, the remaining entries are floats.
    # Declaring this Dict[str, float] made validation reject every result.
    embedding_stats: Dict[str, Any]
    geometric_stats: Optional[Dict[str, Any]] = None
    # Nested lists mirroring the embedding tensor. Its rank is not fixed here
    # (presumably (batch, seq, hidden), i.e. 3-D; TODO confirm), so only the
    # outer level is typed. List[List[float]] rejected 3-D data.
    embeddings: Optional[List[Any]] = None
    geometric_info: Optional[Dict[str, Any]] = None
    error: Optional[str] = None
class GeometricAnalysisResponse(BaseModel):
    """Payload returned by ``POST /analyze``; sections absent from the analysis stay ``None``."""

    success: bool
    timestamp: str
    analysis_type: str
    # Filled for analysis_type "full" or "curvature".
    curvature_analysis: Optional[Dict[str, Any]] = Field(default=None)
    # Filled for analysis_type "full" or "invariance".
    invariance_results: Optional[Dict[str, Any]] = Field(default=None)
    geometric_properties: Optional[Dict[str, Any]] = Field(default=None)
    error: Optional[str] = Field(default=None)
class ComparisonResponse(BaseModel):
    """Payload returned by ``POST /compare``: per-mode stats plus the requested metrics."""

    success: bool
    timestamp: str
    geometric_results: Dict[str, Any]  # stats with geometry enabled
    standard_results: Dict[str, Any]  # stats with geometry disabled
    comparison_metrics: Dict[str, Any]
    error: Optional[str] = Field(default=None)
class BatchProcessingResponse(BaseModel):
    """Payload returned by ``POST /batch``."""

    success: bool
    timestamp: str
    num_texts: int
    processing_times: List[float]  # one entry per input text, in order
    batch_summary: Dict[str, Any]
    # Only populated when the request sets return_summary=False.
    individual_results: Optional[List[Dict[str, Any]]] = Field(default=None)
    error: Optional[str] = Field(default=None)
class HealthResponse(BaseModel):
    """Payload returned by ``GET /health``."""

    status: str  # "healthy" when a model is loaded, otherwise "unhealthy"
    model_loaded: bool
    device: str  # torch device name the service would use ("cuda" or "cpu")
    memory_usage: Dict[str, Any]
    uptime: str
def get_model():
    """
    FastAPI dependency that yields the shared GASM-LLM model.

    Returns:
        The module-level ``model_instance``.

    Raises:
        HTTPException: status 503 when no model is loaded.
    """
    # Read-only access to the module global, so no ``global`` statement is needed.
    loaded = model_instance
    if loaded is not None:
        return loaded
    raise HTTPException(status_code=503, detail="Model not loaded")
@app.get("/")
async def root():
    """
    Root endpoint: describe the API and list its routes.

    Registered on ``GET /``; without the route decorator this handler (and
    the discovery it provides to CustomGPT) was unreachable.
    """
    return {
        "message": "GASM-LLM API",
        "version": "1.0.0",
        "description": "API for GASM-enhanced Large Language Model processing",
        "endpoints": {
            "process": "POST /process - Process text with geometric enhancement",
            "analyze": "POST /analyze - Perform geometric analysis",
            "compare": "POST /compare - Compare geometric vs standard processing",
            "batch": "POST /batch - Process multiple texts",
            "health": "GET /health - Health check",
            "info": "GET /info - Model information"
        }
    }
# Reference point for the uptime reported by /health: the moment this module
# was imported, i.e. roughly when the server loaded the app.
_SERVICE_STARTED_AT = datetime.now()


@app.get("/health", response_model=HealthResponse)
async def health_check():
    """
    Health check endpoint.

    Reports whether the model is loaded, the torch device in use, GPU memory
    counters (when CUDA is available), system memory (when ``psutil`` is
    installed) and the time elapsed since the service started.
    """
    memory_info = {}
    if torch.cuda.is_available():
        memory_info["gpu_memory"] = {
            "allocated": torch.cuda.memory_allocated(),
            "reserved": torch.cuda.memory_reserved(),
            "max_allocated": torch.cuda.max_memory_allocated()
        }

    # System memory (simplified). psutil is optional, so a missing package
    # must not turn the health check itself into a 500.
    try:
        import psutil
    except ImportError:
        psutil = None
    if psutil is not None:
        vm = psutil.virtual_memory()  # one snapshot so the three figures agree
        memory_info["system_memory"] = {
            "used": vm.used,
            "total": vm.total,
            "percent": vm.percent
        }

    loaded = model_instance is not None
    return HealthResponse(
        status="healthy" if loaded else "unhealthy",
        model_loaded=loaded,
        device=str(torch.device("cuda" if torch.cuda.is_available() else "cpu")),
        memory_usage=memory_info,
        # Elapsed time (e.g. "1:02:03.456789") rather than the current clock
        # time, which is what this field reported before.
        uptime=str(datetime.now() - _SERVICE_STARTED_AT)
    )
@app.get("/info")
async def model_info(model: GASMEnhancedLLM = Depends(get_model)):
    """
    Get model information.

    Returns the base model name, the current geometry flag, the device of the
    first parameter, parameter counts, the parameter size in MiB, and the
    GASM hyper-parameters (``None`` for any that cannot be found).
    """
    # Materialise once instead of walking the parameter iterator three times.
    params = list(model.parameters())
    total_parameters = sum(p.numel() for p in params)
    trainable_parameters = sum(p.numel() for p in params if p.requires_grad)
    size_bytes = sum(p.numel() * p.element_size() for p in params)

    # Tolerates a missing ``gasm_embedding`` or a ``gasm_embedding`` without
    # ``.gasm``. The latter used to raise AttributeError.
    gasm = getattr(getattr(model, 'gasm_embedding', None), 'gasm', None)

    return {
        "model_name": model.base_model_name,
        "geometry_enabled": model.enable_geometry,
        "device": str(params[0].device) if params else None,
        "total_parameters": total_parameters,
        "trainable_parameters": trainable_parameters,
        "model_size_mb": size_bytes / (1024 * 1024),
        "gasm_config": {
            "hidden_dim": getattr(gasm, 'hidden_dim', None),
            "output_dim": getattr(gasm, 'output_dim', None),
            "max_iterations": getattr(gasm, 'max_iterations', None),
        }
    }
@app.post("/process", response_model=TextProcessingResponse)
async def process_text(
    request: TextProcessingRequest,
    model: GASMEnhancedLLM = Depends(get_model)
):
    """
    Process text with GASM-enhanced LLM.

    Encodes ``request.text`` with geometry enabled or disabled as requested.
    Returns summary statistics of ``outputs['last_hidden_state']``, plus,
    optionally, the raw embeddings and a summary of
    ``outputs['geometric_info']``.

    Errors are not raised. They are logged and reported as ``success=False``
    with the exception text in ``error``.

    NOTE(review): ``request.max_length`` and ``request.model_config`` are
    accepted by the request model but not forwarded to the model here.
    """
    start_time = datetime.now()
    # Monotonic clock for durations; the wall clock can jump (NTP, DST).
    start_clock = time.perf_counter()
    # The model object is shared by every request. Remember its mode so this
    # request's choice does not leak into later ones (e.g. what /info reports).
    previous_geometry = model.enable_geometry
    try:
        # Configure model
        model.enable_geometry = request.enable_geometry

        # Inference only: skip building an autograd graph.
        with torch.no_grad():
            outputs = model.encode_text(
                request.text,
                return_geometry=request.return_geometry
            )

        processing_time = time.perf_counter() - start_clock

        # Extract embeddings
        embeddings = outputs['last_hidden_state']
        embedding_stats = {
            "shape": list(embeddings.shape),
            "mean": float(embeddings.mean()),
            "std": float(embeddings.std()),
            "min": float(embeddings.min()),
            "max": float(embeddings.max()),
            "norm": float(torch.norm(embeddings))
        }

        response = TextProcessingResponse(
            success=True,
            timestamp=start_time.isoformat(),
            processing_time=processing_time,
            text_length=len(request.text),
            model_info={
                "model_name": model.base_model_name,
                "geometry_enabled": request.enable_geometry,
                "device": str(next(model.parameters()).device)
            },
            embedding_stats=embedding_stats
        )

        if request.return_embeddings:
            # Tensor.tolist() yields the same Python floats without a NumPy hop
            # (which also fails for dtypes NumPy lacks, such as bfloat16).
            response.embeddings = embeddings.detach().cpu().tolist()

        # Add geometric information if available
        if request.return_geometry and 'geometric_info' in outputs:
            geometric_info = outputs['geometric_info']
            if geometric_info:
                response.geometric_info = {
                    "num_sequences": len(geometric_info),
                    "has_curvature": any('output' in info for info in geometric_info),
                    "has_constraints": any('constraints' in info for info in geometric_info),
                    "has_relations": any('relations' in info for info in geometric_info)
                }

        return response
    except Exception as e:
        logger.exception(f"Error processing text: {e}")
        return TextProcessingResponse(
            success=False,
            timestamp=start_time.isoformat(),
            processing_time=time.perf_counter() - start_clock,
            text_length=len(request.text),
            model_info={},
            embedding_stats={},
            error=str(e)
        )
    finally:
        model.enable_geometry = previous_geometry
@app.post("/analyze", response_model=GeometricAnalysisResponse)
async def analyze_geometry(
    request: GeometricAnalysisRequest,
    model: GASMEnhancedLLM = Depends(get_model)
):
    """
    Perform geometric analysis of text.

    ``analysis_type`` selects the sections:

    * ``"curvature"`` / ``"full"``: per-sequence statistics of the norm of
      each ``geometric_info[i]['output']`` along dim 1.
    * ``"invariance"`` / ``"full"``: an SE(3) invariance check via
      ``check_se3_invariance``.

    Any other value returns a success response with no sections. Errors are
    logged and reported as ``success=False``.
    """
    start_time = datetime.now()
    # Shared model: restore its geometry flag once this request is done.
    previous_geometry = model.enable_geometry
    try:
        # Geometry must be on for geometric_info to be produced.
        model.enable_geometry = True
        with torch.no_grad():
            outputs = model.encode_text(request.text, return_geometry=True)

        response = GeometricAnalysisResponse(
            success=True,
            timestamp=start_time.isoformat(),
            analysis_type=request.analysis_type
        )

        if request.analysis_type in ("full", "curvature"):
            geometric_info = outputs.get('geometric_info', [])
            if geometric_info:
                curvature_stats = []
                for info in geometric_info:
                    if 'output' not in info:
                        continue
                    curvature_norm = torch.norm(info['output'], dim=1)
                    curvature_stats.append({
                        "mean": float(curvature_norm.mean()),
                        # std() of a single value is NaN, and the JSON encoder
                        # (allow_nan=False) would turn that into a 500.
                        "std": float(curvature_norm.std()) if curvature_norm.numel() > 1 else 0.0,
                        "min": float(curvature_norm.min()),
                        "max": float(curvature_norm.max())
                    })
                response.curvature_analysis = {
                    "per_sequence": curvature_stats,
                    "global_stats": {
                        "num_sequences": len(curvature_stats),
                        "avg_mean_curvature": sum(s["mean"] for s in curvature_stats) / len(curvature_stats) if curvature_stats else 0
                    }
                }

        if request.analysis_type in ("full", "invariance"):
            # SE(3) invariance analysis.
            # NOTE(review): this tests a freshly constructed GASM with random
            # weights on random inputs, not the loaded model's own GASM layer,
            # so the result describes the architecture rather than this model.
            try:
                hidden_size = model.base_model.config.hidden_size
                test_points = torch.randn(10, 3)
                test_features = torch.randn(10, hidden_size)
                test_relations = torch.randn(10, 10, 16)
                gasm_model = GASM(
                    feature_dim=hidden_size,
                    hidden_dim=256,
                    output_dim=3
                )
                is_invariant = check_se3_invariance(
                    gasm_model,
                    test_points,
                    test_features,
                    test_relations,
                    num_tests=request.num_invariance_tests,
                    tolerance=request.tolerance
                )
                response.invariance_results = {
                    "is_invariant": is_invariant,
                    "num_tests": request.num_invariance_tests,
                    "tolerance": request.tolerance,
                    "test_type": "SE(3) invariance"
                }
            except Exception as e:
                # Best effort: report the failure without failing the whole analysis.
                response.invariance_results = {
                    "is_invariant": None,
                    "error": str(e)
                }

        return response
    except Exception as e:
        logger.exception(f"Error in geometric analysis: {e}")
        return GeometricAnalysisResponse(
            success=False,
            timestamp=start_time.isoformat(),
            analysis_type=request.analysis_type,
            error=str(e)
        )
    finally:
        model.enable_geometry = previous_geometry
@app.post("/compare", response_model=ComparisonResponse)
async def compare_processing(
    request: ComparisonRequest,
    model: GASMEnhancedLLM = Depends(get_model)
):
    """
    Compare geometric vs standard processing.

    Encodes ``request.text`` twice, once with geometry enabled and once
    without. Returns embedding statistics for both runs and whichever
    metrics in ``request.metrics`` are recognised. Errors are logged and
    reported as ``success=False``.
    """
    start_time = datetime.now()
    # This handler toggles the shared model's geometry flag twice. Restore the
    # original value afterwards instead of leaving geometry disabled for
    # everyone.
    previous_geometry = model.enable_geometry
    try:
        with torch.no_grad():
            model.enable_geometry = True
            geometric_outputs = model.encode_text(request.text, return_geometry=True)

            model.enable_geometry = False
            standard_outputs = model.encode_text(request.text, return_geometry=False)

        geometric_embeddings = geometric_outputs['last_hidden_state']
        standard_embeddings = standard_outputs['last_hidden_state']
        # Reused by several metrics and by the per-mode stats below.
        geometric_norm = torch.norm(geometric_embeddings)
        standard_norm = torch.norm(standard_embeddings)

        comparison_metrics = {}

        if "embedding_norm" in request.metrics:
            comparison_metrics["embedding_norm"] = {
                "geometric": float(geometric_norm),
                "standard": float(standard_norm),
                "ratio": float(geometric_norm / standard_norm)
            }

        if "attention_patterns" in request.metrics:
            # Simplified attention pattern comparison: softmax over the
            # token-token dot products of each embedding set.
            geo_attention = torch.softmax(geometric_embeddings @ geometric_embeddings.transpose(-2, -1), dim=-1)
            std_attention = torch.softmax(standard_embeddings @ standard_embeddings.transpose(-2, -1), dim=-1)
            comparison_metrics["attention_patterns"] = {
                "geometric_entropy": float(torch.sum(-geo_attention * torch.log(geo_attention + 1e-9))),
                "standard_entropy": float(torch.sum(-std_attention * torch.log(std_attention + 1e-9))),
                "pattern_difference": float(torch.norm(geo_attention - std_attention))
            }

        if "geometric_consistency" in request.metrics:
            difference_norm = torch.norm(geometric_embeddings - standard_embeddings)
            comparison_metrics["geometric_consistency"] = {
                "has_geometric_info": 'geometric_info' in geometric_outputs,
                "embedding_difference": float(difference_norm),
                "relative_change": float(difference_norm / standard_norm)
            }

        return ComparisonResponse(
            success=True,
            timestamp=start_time.isoformat(),
            geometric_results={
                "embedding_stats": {
                    "shape": list(geometric_embeddings.shape),
                    "mean": float(geometric_embeddings.mean()),
                    "std": float(geometric_embeddings.std()),
                    "norm": float(geometric_norm)
                }
            },
            standard_results={
                "embedding_stats": {
                    "shape": list(standard_embeddings.shape),
                    "mean": float(standard_embeddings.mean()),
                    "std": float(standard_embeddings.std()),
                    "norm": float(standard_norm)
                }
            },
            comparison_metrics=comparison_metrics
        )
    except Exception as e:
        logger.exception(f"Error in comparison: {e}")
        return ComparisonResponse(
            success=False,
            timestamp=start_time.isoformat(),
            geometric_results={},
            standard_results={},
            comparison_metrics={},
            error=str(e)
        )
    finally:
        model.enable_geometry = previous_geometry
@app.post("/batch", response_model=BatchProcessingResponse)
async def batch_process(
    request: BatchProcessingRequest,
    model: GASMEnhancedLLM = Depends(get_model)
):
    """
    Process multiple texts in batch.

    Texts are encoded one at a time (without geometric info) and timed
    individually. The response always carries the per-text times and a
    batch summary; per-text results are included only when
    ``return_summary`` is False. Errors are logged and reported as
    ``success=False``.
    """
    start_time = datetime.now()
    # Shared model: restore its geometry flag once this request is done.
    previous_geometry = model.enable_geometry
    try:
        model.enable_geometry = request.enable_geometry

        processing_times = []
        individual_results = []

        with torch.no_grad():
            for i, text in enumerate(request.texts):
                # Monotonic clock: wall-clock deltas can be zero or negative.
                text_start = time.perf_counter()
                outputs = model.encode_text(text, return_geometry=False)
                embeddings = outputs['last_hidden_state']
                processing_time = time.perf_counter() - text_start
                processing_times.append(processing_time)

                if not request.return_summary:
                    individual_results.append({
                        "text_index": i,
                        "text_length": len(text),
                        "processing_time": processing_time,
                        "embedding_norm": float(torch.norm(embeddings))
                    })

        # request.texts has at least one item (min_items=1), so num_texts > 0.
        num_texts = len(request.texts)
        total_time = sum(processing_times)
        total_characters = sum(len(text) for text in request.texts)
        batch_summary = {
            "total_texts": num_texts,
            "total_processing_time": total_time,
            "average_processing_time": total_time / num_texts,
            # A zero total used to raise ZeroDivisionError and fail the batch.
            "texts_per_second": num_texts / total_time if total_time > 0 else None,
            "geometry_enabled": request.enable_geometry,
            "total_characters": total_characters,
            "average_text_length": total_characters / num_texts
        }

        return BatchProcessingResponse(
            success=True,
            timestamp=start_time.isoformat(),
            num_texts=num_texts,
            processing_times=processing_times,
            batch_summary=batch_summary,
            individual_results=individual_results if not request.return_summary else None
        )
    except Exception as e:
        logger.exception(f"Error in batch processing: {e}")
        return BatchProcessingResponse(
            success=False,
            timestamp=start_time.isoformat(),
            num_texts=len(request.texts),
            processing_times=[],
            batch_summary={},
            error=str(e)
        )
    finally:
        model.enable_geometry = previous_geometry
# Error handlers
@app.exception_handler(HTTPException)
async def http_exception_handler(request, exc):
    """
    Render an ``HTTPException`` as ``{"error": detail, "timestamp": ...}``.

    Keeps the exception's status code and any headers attached to it. The
    headers were previously dropped.
    """
    return JSONResponse(
        status_code=exc.status_code,
        content={"error": exc.detail, "timestamp": datetime.now().isoformat()},
        headers=getattr(exc, "headers", None)
    )
@app.exception_handler(Exception)
async def general_exception_handler(request, exc):
    """
    Catch-all handler: log the exception with its traceback and return a
    generic 500 body without exposing internal details to the client.
    """
    # exc_info=exc attaches the traceback; the message alone hides where it failed.
    logger.error(f"Unhandled exception: {exc}", exc_info=exc)
    return JSONResponse(
        status_code=500,
        content={"error": "Internal server error", "timestamp": datetime.now().isoformat()}
    )
# OpenAPI customization for CustomGPT
def custom_openapi():
    """
    Custom OpenAPI schema for CustomGPT integration.

    Builds the schema once, caches it on ``app.openapi_schema``, and adds an
    ``x-logo`` entry. FastAPI calls ``app.openapi()`` synchronously when
    serving ``/openapi.json``, so this must be a plain function; as an
    ``async def`` it would hand back an un-awaited coroutine.
    """
    from fastapi.openapi.utils import get_openapi

    if app.openapi_schema:
        return app.openapi_schema

    openapi_schema = get_openapi(
        title="GASM-LLM API",
        version="1.0.0",
        description="API for GASM-enhanced Large Language Model processing with geometric inference capabilities",
        routes=app.routes,
    )

    # Add custom metadata for CustomGPT.
    # TODO: replace the placeholder "your-username" in this URL.
    openapi_schema["info"]["x-logo"] = {
        "url": "https://huggingface.co/spaces/your-username/gasm-llm/resolve/main/logo.png"
    }

    app.openapi_schema = openapi_schema
    return app.openapi_schema


# Install the override so /openapi.json actually uses the customised schema.
app.openapi = custom_openapi
if __name__ == "__main__":
    # Development launcher. The import string below assumes this module is
    # importable as ``fastapi_endpoint``; reload=True is a development setting.
    import uvicorn

    server_options = {
        "host": "0.0.0.0",
        "port": 8000,
        "reload": True,
        "log_level": "info",
    }
    uvicorn.run("fastapi_endpoint:app", **server_options)