| """ |
| FastAPI Backend AI Service converted from Gradio app |
| Provides OpenAI-compatible chat completion endpoints |
| """ |
|
|
| import asyncio |
| import logging |
| import time |
| import json |
| from contextlib import asynccontextmanager |
| from typing import List, Dict, Any, Optional, AsyncGenerator, Union |
|
|
| from fastapi import FastAPI, HTTPException, Depends, Request |
| from fastapi.responses import StreamingResponse, JSONResponse |
| from fastapi.middleware.cors import CORSMiddleware |
| from pydantic import BaseModel, Field, field_validator |
| from huggingface_hub import InferenceClient |
| import uvicorn |
| import requests |
| from PIL import Image |
|
|
| |
# transformers is an optional dependency. When the import fails, the names
# are still bound (to None) and `transformers_available` records the outcome,
# so later code can test them instead of hitting a NameError.
try:
    from transformers import pipeline, AutoTokenizer
    transformers_available = True
except ImportError:
    transformers_available = False
    pipeline = None
    AutoTokenizer = None
|
|
| |
# Configure root logging at INFO on import and create the module-level logger
# used by every function in this file.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
|
| |
class TextContent(BaseModel):
    """Text part of a multi-part chat message; ``type`` must be ``"text"``."""

    type: str = Field("text", description="Content type")
    text: str = Field(..., description="Text content")

    @field_validator("type")
    @classmethod
    def validate_type(cls, v: str) -> str:
        """Accept only the literal ``"text"`` discriminator."""
        if v == "text":
            return v
        raise ValueError("Type must be 'text'")
|
|
class ImageContent(BaseModel):
    """Image part of a multi-part chat message; ``type`` must be ``"image"``."""

    type: str = Field("image", description="Content type")
    url: str = Field(..., description="Image URL")

    @field_validator("type")
    @classmethod
    def validate_type(cls, v: str) -> str:
        """Accept only the literal ``"image"`` discriminator."""
        if v == "image":
            return v
        raise ValueError("Type must be 'image'")
|
|
| |
class ChatMessage(BaseModel):
    """One conversation turn: a role plus plain-string or multi-part content."""

    role: str = Field(..., description="The role of the message author")
    content: Union[str, List[Union[TextContent, ImageContent]]] = Field(
        ...,
        description="The content of the message - either string or list of content items",
    )

    @field_validator("role")
    @classmethod
    def validate_role(cls, v: str) -> str:
        """Reject any role other than system, user or assistant."""
        if v in ("system", "user", "assistant"):
            return v
        raise ValueError("Role must be one of: system, user, assistant")
|
|
class ChatCompletionRequest(BaseModel):
    """Request body for ``POST /v1/chat/completions``."""

    model: str = Field("gemma-3n-E4B-it-GGUF", description="The model to use for completion")
    messages: List[ChatMessage] = Field(..., description="List of messages in the conversation")
    # Bounds below are enforced by pydantic at request-validation time.
    max_tokens: Optional[int] = Field(512, ge=1, le=2048, description="Maximum tokens to generate")
    temperature: Optional[float] = Field(0.7, ge=0.0, le=2.0, description="Sampling temperature")
    stream: Optional[bool] = Field(False, description="Whether to stream responses")
    top_p: Optional[float] = Field(0.95, ge=0.0, le=1.0, description="Top-p sampling")
|
|
class ChatCompletionChoice(BaseModel):
    """A single generated alternative inside a chat completion response."""

    # Position of this choice in the response's `choices` list.
    index: int
    # The generated message.
    message: ChatMessage
    # Why generation ended; the endpoints in this file always send "stop".
    finish_reason: str
|
|
class ChatCompletionResponse(BaseModel):
    """Non-streaming response body for the chat completions endpoint."""

    id: str
    object: str = "chat.completion"
    created: int
    model: str
    choices: list[ChatCompletionChoice]
|
|
class ChatCompletionChunk(BaseModel):
    """One streamed chunk of a chat completion; choices are free-form dicts."""

    id: str
    object: str = "chat.completion.chunk"
    created: int
    model: str
    choices: list[dict[str, Any]]
|
|
class HealthResponse(BaseModel):
    """Response body of the ``/health`` endpoint."""

    # "healthy" or "unhealthy", as set by health_check.
    status: str
    # Identifier of the configured text model.
    model: str
    # Service version string.
    version: str
|
|
class ModelInfo(BaseModel):
    """Description of one model as listed by ``/v1/models``."""

    # Model identifier.
    id: str
    object: str = "model"
    # Integer timestamp; list_models fills it with int(time.time()) at request time.
    created: int
    owned_by: str = "huggingface"
|
|
class ModelsResponse(BaseModel):
    """Envelope returned by ``/v1/models``: a typed list of model entries."""

    object: str = "list"
    data: list[ModelInfo]
|
|
class CompletionRequest(BaseModel):
    """Request body for ``POST /v1/completions``."""

    prompt: str = Field(..., description="The prompt to complete")
    max_tokens: Optional[int] = Field(512, ge=1, le=2048)
    temperature: Optional[float] = Field(0.7, ge=0.0, le=2.0)
|
|
| |
# Module-level service state. `lifespan` populates the three handles at
# startup and resets them to None at shutdown.
inference_client: Optional[InferenceClient] = None
# image-to-text pipeline; stays None when transformers is missing or loading fails.
image_text_pipeline = None
# Model id handed to InferenceClient and AutoTokenizer during startup.
current_model = "gemma-3n-E4B-it-GGUF"
# Model id used for the image-to-text pipeline.
vision_model = "Salesforce/blip-image-captioning-base"
# Loaded during startup when possible; not read anywhere else in the visible code.
tokenizer = None
|
|
| |
async def download_image(url: str) -> Image.Image:
    """Download an image over HTTP and open it with Pillow.

    The blocking ``requests`` call runs on a worker thread so the event loop
    is not stalled while the download is in progress.

    Args:
        url: Location of the image to fetch.

    Returns:
        The opened ``PIL.Image.Image``.

    Raises:
        HTTPException: 400 when the download or image decoding fails for any
            reason (network error, non-2xx status, unreadable image data).
    """
    # Local import: ``requests.compat`` does not export ``BytesIO``, so the
    # previous ``requests.compat.BytesIO`` lookup always failed.
    from io import BytesIO

    def _fetch() -> Image.Image:
        # NOTE(review): `url` is used as given, with no allow-list or scheme
        # check in this function - confirm callers only pass trusted URLs.
        response = requests.get(url, timeout=10)
        response.raise_for_status()
        return Image.open(BytesIO(response.content))

    try:
        return await asyncio.to_thread(_fetch)
    except Exception as e:
        logger.error(f"Failed to download image from {url}: {e}")
        raise HTTPException(status_code=400, detail=f"Failed to download image: {str(e)}") from e
|
|
def extract_text_and_images(content: Union[str, List[Any]]) -> tuple[str, List[str]]:
    """Split message content into its joined text and its image URLs.

    A plain string is returned unchanged with no URLs. For a list, items
    lacking a ``type`` attribute are ignored; ``"text"`` items contribute
    their ``text`` and ``"image"`` items their ``url``.

    Returns:
        ``(text, image_urls)`` where ``text`` is the text parts joined by a
        single space.
    """
    if isinstance(content, str):
        return content, []

    typed_items = [entry for entry in content if hasattr(entry, "type")]
    texts: List[str] = [
        str(entry.text)
        for entry in typed_items
        if entry.type == "text" and hasattr(entry, "text")
    ]
    urls: List[str] = [
        str(entry.url)
        for entry in typed_items
        if entry.type == "image" and hasattr(entry, "url")
    ]
    return " ".join(texts), urls
|
|
def has_images(messages: List[ChatMessage]) -> bool:
    """Return True if any message with list content contains an image item."""
    return any(
        getattr(part, "type", None) == "image"
        for msg in messages
        if isinstance(msg.content, list)
        for part in msg.content
    )
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage startup and shutdown of the module-level model handles.

    Startup creates the inference client (mandatory) and then tries to load
    the image-to-text pipeline and the tokenizer (both optional: a failure
    is logged and the handle is left as ``None``). Shutdown resets all three
    handles to ``None``.

    Args:
        app: The FastAPI application this lifespan is attached to (unused).

    Raises:
        RuntimeError: If the startup block fails outside the optional
            loaders' own handlers, chained to the original exception.
    """
    global inference_client, tokenizer, image_text_pipeline

    logger.info("🚀 Starting AI Backend Service...")
    try:
        inference_client = InferenceClient(model=current_model)
        logger.info(f"✅ Initialized inference client with model: {current_model}")

        # Optional: image captioning needs transformers to be installed.
        if transformers_available and pipeline:
            try:
                logger.info(f"🖼️ Initializing image captioning pipeline with model: {vision_model}")
                image_text_pipeline = pipeline("image-to-text", model=vision_model)
                logger.info("✅ Image captioning pipeline loaded successfully")
            except Exception as e:
                logger.warning(f"⚠️ Could not load image captioning pipeline: {e}")
                image_text_pipeline = None
        else:
            logger.warning("⚠️ Transformers not available, image processing disabled")
            image_text_pipeline = None

        # Optional: the tokenizer is best-effort as well.
        if transformers_available and AutoTokenizer:
            try:
                tokenizer = AutoTokenizer.from_pretrained(current_model)
                logger.info("✅ Tokenizer loaded successfully")
            except Exception as e:
                logger.warning(f"⚠️ Could not load tokenizer: {e}")
                tokenizer = None
        else:
            logger.info("⚠️ Tokenizer initialization skipped")

    except Exception as e:
        logger.error(f"❌ Failed to initialize inference client: {e}")
        # Chain explicitly so the original traceback stays attached.
        raise RuntimeError(f"Service initialization failed: {e}") from e

    try:
        yield
    finally:
        # Runs even when an exception is thrown into the generator at the
        # yield point, so the handles are always released.
        logger.info("🔄 Shutting down AI Backend Service...")
        inference_client = None
        tokenizer = None
        image_text_pipeline = None
|
|
| |
# The ASGI application. `lifespan` is wired in so the model handles are
# created before the first request and released at shutdown.
app = FastAPI(
    title="AI Backend Service",
    description="OpenAI-compatible chat completion API powered by HuggingFace",
    version="1.0.0",
    lifespan=lifespan
)
|
|
| |
# CORS: any origin, method and header is accepted. Credentials are disabled
# because FastAPI's CORS documentation states that allow_origins must not be
# ["*"] when allow_credentials is True; combining the two would let any site
# issue credentialed cross-origin requests. No endpoint in the visible code
# relies on cookies or auth headers. To support credentialed browser clients,
# list the trusted origins explicitly and set allow_credentials=True.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
def get_inference_client() -> InferenceClient:
    """FastAPI dependency returning the shared inference client.

    Raises:
        HTTPException: 503 while the client has not been initialized.
    """
    if inference_client is not None:
        return inference_client
    raise HTTPException(status_code=503, detail="Service not ready - inference client not initialized")
|
|
def convert_messages_to_prompt(messages: List[ChatMessage]) -> str:
    """Flatten chat messages into one newline-separated prompt string.

    Each message becomes ``"<Speaker>: <text>"`` with system/user/assistant
    mapped to System/Human/Assistant; messages with any other role are
    dropped. For list content only the text parts are used. A trailing
    ``"Assistant:"`` line cues the model to answer.
    """
    speaker_labels = {"system": "System", "user": "Human", "assistant": "Assistant"}
    lines: List[str] = []

    for msg in messages:
        body = msg.content
        if not isinstance(body, str):
            # Multi-part content: keep the text, discard the image URLs.
            body, _ = extract_text_and_images(body)

        label = speaker_labels.get(msg.role)
        if label is not None:
            lines.append(f"{label}: {body}")

    lines.append("Assistant:")
    return "\n".join(lines)
|
|
async def generate_multimodal_response(
    messages: List[ChatMessage],
    request: ChatCompletionRequest
) -> str:
    """Answer a multimodal request by captioning the first attached image.

    The most recent user message whose content is a list is selected. Its
    first image URL is passed to the module-level ``image_text_pipeline`` on
    a worker thread, and the caption is wrapped in a templated reply that
    also echoes the user's text when there is any.

    Args:
        messages: Conversation history; searched from newest to oldest.
        request: The originating completion request. Not read by this
            function; kept so the call signature stays stable.

    Returns:
        The reply text. Pipeline or formatting failures yield an apologetic
        fallback string rather than an exception.

    Raises:
        HTTPException: 503 when the pipeline is not initialized; 400 when no
            user message with list content exists or it carries no image URL.
    """
    if not image_text_pipeline:
        raise HTTPException(status_code=503, detail="Image processing not available - pipeline not initialized")

    try:
        # Newest user message that uses the multi-part content form.
        last_user_message = next(
            (
                message
                for message in reversed(messages)
                if message.role == "user" and isinstance(message.content, list)
            ),
            None,
        )
        if not last_user_message:
            raise HTTPException(status_code=400, detail="No user message with images found")

        text_content, image_urls = extract_text_and_images(last_user_message.content)
        if not image_urls:
            raise HTTPException(status_code=400, detail="No images found in the message")

        # Only the first image is captioned; any further URLs are ignored.
        image_url = image_urls[0]

        logger.info(f"🖼️ Processing image: {image_url}")
        try:
            # NOTE(review): the caller-supplied URL is handed to the pipeline
            # as-is (presumably fetched by it) - confirm that is acceptable.
            result = await asyncio.to_thread(lambda: image_text_pipeline(image_url))

            if not (result and hasattr(result, '__len__') and len(result) > 0):
                return f"I can see there's an image at {image_url}, but cannot process it right now."

            first_result = result[0]
            if hasattr(first_result, 'get'):
                generated_text = first_result.get('generated_text', f'I can see an image at {image_url}.')
            else:
                generated_text = str(first_result)

            if not text_content:
                return f"I can see: {generated_text}"

            response = f"Looking at this image, I can see: {generated_text}. "
            if "what" in text_content.lower() or "?" in text_content:
                response += f"Regarding your question '{text_content}': Based on what I can see, this appears to be {generated_text.lower()}."
            else:
                response += f"You mentioned: {text_content}"
            return response

        except Exception as pipeline_error:
            logger.warning(f"Pipeline error: {pipeline_error}")
            return f"I can see there's an image at {image_url}. The image appears to contain visual content that I'm having trouble processing right now."

    except HTTPException:
        # The 400s raised above are deliberate client errors. Without this
        # clause the catch-all below turned them into a normal reply string.
        raise
    except Exception as e:
        logger.error(f"Error in multimodal generation: {e}")
        return f"I'm having trouble processing the image. Error: {str(e)}"
|
|
def generate_response_safe(client: InferenceClient, prompt: str, max_tokens: int, temperature: float, top_p: float) -> str:
    """Generate text with one retry and never raise.

    The first attempt passes ``top_p`` and stop sequences. If it fails, a
    second attempt is made without them. If that fails too, a fixed apology
    string is returned.

    Returns:
        The stripped generated text, or a fixed apology when the model
        returns nothing or both attempts fail.
    """
    empty_reply = "I apologize, but I couldn't generate a response."
    shared_kwargs = {
        "prompt": prompt,
        "max_new_tokens": max_tokens,
        "temperature": temperature,
        "return_full_text": False,
    }

    # Attempt 1: full parameter set.
    try:
        generated = client.text_generation(
            top_p=top_p,
            stop=["Human:", "System:"],
            **shared_kwargs,
        )
        if generated:
            return generated.strip()
        return empty_reply
    except Exception as e:
        logger.warning(f"text_generation failed: {e}")

    # Attempt 2: reduced parameter set.
    try:
        generated = client.text_generation(**shared_kwargs)
        if generated:
            return generated.strip()
        return empty_reply
    except Exception as e2:
        logger.error(f"All generation methods failed: {e2}")
        return "I apologize, but I'm having trouble generating a response right now. Please try again."
|
|
async def generate_streaming_response(
    client: InferenceClient,
    prompt: str,
    request: ChatCompletionRequest
) -> AsyncGenerator[str, None]:
    """Yield a chat completion as ``data: ...`` event frames.

    The full response is generated first on a worker thread, then replayed
    word by word with a short pause between frames. The stream ends with a
    ``finish_reason="stop"`` frame followed by ``data: [DONE]``. On failure
    a single frame with ``finish_reason="error"`` is emitted instead.
    """
    request_id = f"chatcmpl-{int(time.time())}"
    created = int(time.time())

    def _frame(delta: Dict[str, Any], finish_reason: Optional[str]) -> str:
        """Serialize one chunk as an event frame."""
        payload = ChatCompletionChunk(
            id=request_id,
            created=created,
            model=request.model,
            choices=[{
                "index": 0,
                "delta": delta,
                "finish_reason": finish_reason
            }]
        )
        return f"data: {payload.model_dump_json()}\n\n"

    try:
        full_text = await asyncio.to_thread(
            generate_response_safe,
            client,
            prompt,
            request.max_tokens or 512,
            request.temperature or 0.7,
            request.top_p or 0.95
        )

        pieces = full_text.split() if full_text else ["No", "response", "generated"]
        for position, piece in enumerate(pieces):
            # Every word after the first carries its leading space.
            content = piece if position == 0 else f" {piece}"
            yield _frame({"content": content}, None)
            await asyncio.sleep(0.05)

        yield _frame({}, "stop")
        yield "data: [DONE]\n\n"

    except Exception as e:
        logger.error(f"Error in streaming generation: {e}")
        error_chunk: Dict[str, Any] = {
            "id": request_id,
            "object": "chat.completion.chunk",
            "created": created,
            "model": request.model,
            "choices": [{
                "index": 0,
                "delta": {},
                "finish_reason": "error"
            }],
            "error": str(e)
        }
        yield f"data: {json.dumps(error_chunk)}\n\n"
|
|
@app.get("/", response_class=JSONResponse)
async def root() -> Dict[str, Any]:
    """Return a short service banner with the main endpoint paths."""
    endpoint_paths = {
        "health": "/health",
        "models": "/v1/models",
        "chat_completions": "/v1/chat/completions"
    }
    banner: Dict[str, Any] = {
        "message": "AI Backend Service is running!",
        "version": "1.0.0",
        "endpoints": endpoint_paths
    }
    return banner
|
|
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Report "healthy" when the inference client is set, else "unhealthy"."""
    service_state = "unhealthy"
    if inference_client:
        service_state = "healthy"
    return HealthResponse(status=service_state, model=current_model, version="1.0.0")
|
|
@app.get("/v1/models", response_model=ModelsResponse)
async def list_models():
    """List the text model and, when its pipeline is loaded, the vision model."""
    advertised_ids = [current_model]
    if image_text_pipeline:
        advertised_ids.append(vision_model)

    entries = [
        ModelInfo(id=model_id, created=int(time.time()), owned_by="huggingface")
        for model_id in advertised_ids
    ]
    return ModelsResponse(data=entries)
|
|
@app.post("/v1/chat/completions")
async def create_chat_completion(
    request: ChatCompletionRequest,
    client: InferenceClient = Depends(get_inference_client)
):
    """Create a chat completion (OpenAI-compatible) with multimodal support.

    Requests containing an image item are routed to the image captioning
    path. Text-only requests are flattened to a prompt and either streamed
    (``stream=True``) or answered in one response.

    Args:
        request: Validated chat completion request.
        client: Shared inference client injected by FastAPI.

    Returns:
        A ``StreamingResponse`` of server-sent events when streaming a
        text-only request, otherwise a ``ChatCompletionResponse``.

    Raises:
        HTTPException: 400 for an empty message list, 503 when an image is
            sent but the pipeline is unavailable, 500 for unexpected errors.
    """
    try:
        if not request.messages:
            raise HTTPException(status_code=400, detail="Messages cannot be empty")

        if has_images(request.messages):
            if not image_text_pipeline:
                raise HTTPException(status_code=503, detail="Image processing not available")

            # The multimodal path is never streamed; `stream` is ignored here.
            response_text = await generate_multimodal_response(request.messages, request)
        else:
            prompt = convert_messages_to_prompt(request.messages)
            logger.info(f"Generated prompt: {prompt[:200]}...")

            if request.stream:
                # The generator emits "data: ...\n\n" frames, i.e. server-sent
                # events, so it is advertised as text/event-stream. The old
                # explicit text/plain Content-Type header is dropped because
                # it would override media_type.
                return StreamingResponse(
                    generate_streaming_response(client, prompt, request),
                    media_type="text/event-stream",
                    headers={
                        "Cache-Control": "no-cache",
                        "Connection": "keep-alive"
                    }
                )

            # NOTE(review): `or` also replaces an explicit 0 / 0.0 with the
            # default because those values are falsy - confirm this is wanted.
            response_text = await asyncio.to_thread(
                generate_response_safe,
                client,
                prompt,
                request.max_tokens or 512,
                request.temperature or 0.7,
                request.top_p or 0.95
            )

        response_text = response_text.strip() if response_text else "No response generated."

        return ChatCompletionResponse(
            id=f"chatcmpl-{int(time.time())}",
            created=int(time.time()),
            model=request.model,
            choices=[
                ChatCompletionChoice(
                    index=0,
                    message=ChatMessage(role="assistant", content=response_text),
                    finish_reason="stop"
                )
            ]
        )

    except HTTPException:
        # Keep deliberate 4xx/5xx responses; previously every HTTPException
        # raised above was rewrapped as a 500.
        raise
    except Exception as e:
        logger.error(f"Error in chat completion: {e}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") from e
|
|
@app.post("/v1/completions")
async def create_completion(
    request: CompletionRequest,
    client: InferenceClient = Depends(get_inference_client)
) -> Dict[str, Any]:
    """Create a text completion (OpenAI-compatible).

    Args:
        request: Validated completion request.
        client: Shared inference client injected by FastAPI.

    Returns:
        A ``text_completion`` payload with a single choice.

    Raises:
        HTTPException: 400 when the prompt is empty, 500 for unexpected
            errors.
    """
    try:
        if not request.prompt:
            raise HTTPException(status_code=400, detail="Prompt cannot be empty")

        # CompletionRequest has no top_p field, so 0.95 is passed as a fixed value.
        # NOTE(review): `or` also replaces an explicit 0 / 0.0 with the
        # default because those values are falsy - confirm this is wanted.
        response_text = await asyncio.to_thread(
            generate_response_safe,
            client,
            request.prompt,
            request.max_tokens or 512,
            request.temperature or 0.7,
            0.95
        )

        return {
            "id": f"cmpl-{int(time.time())}",
            "object": "text_completion",
            "created": int(time.time()),
            "model": current_model,
            "choices": [{
                "text": response_text,
                "index": 0,
                "finish_reason": "stop"
            }]
        }

    except HTTPException:
        # Let the 400 above through; it used to be rewrapped as a 500.
        raise
    except Exception as e:
        logger.error(f"Error in completion: {e}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") from e
|
|
@app.post("/api/response")
async def api_response(request: Request):
    """Echo the ``message`` field of a JSON request body.

    Args:
        request: Raw incoming request; its body must be a JSON object.

    Returns:
        A JSON response holding the received message and an echo string.
        ``"No message provided"`` is used when the key is absent.

    Raises:
        HTTPException: 400 when the body is not valid JSON or not a JSON
            object, 500 for any other failure.
    """
    try:
        try:
            data = await request.json()
        except ValueError as e:
            # Decoding failures (JSONDecodeError, UnicodeDecodeError) are
            # ValueErrors: the client sent a bad body, so answer 400.
            logger.warning(f"Invalid JSON in API request: {e}")
            raise HTTPException(status_code=400, detail="Request body must be valid JSON") from e

        if not isinstance(data, dict):
            # Arrays and scalars have no .get(); this used to end in a 500.
            raise HTTPException(status_code=400, detail="Request body must be a JSON object")

        message = data.get("message", "No message provided")
        response: Dict[str, Any] = {
            "status": "success",
            "received_message": message,
            "response_message": f"You sent: {message}"
        }
        return JSONResponse(content=response)
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error processing API response: {e}")
        raise HTTPException(status_code=500, detail="Internal server error") from e
|
|
@app.exception_handler(Exception)
async def global_exception_handler(request: Any, exc: Exception) -> JSONResponse:
    """Log an unhandled exception and answer with a JSON 500 response."""
    logger.error(f"Unhandled exception: {exc}")
    error_body = {"detail": f"Internal server error: {str(exc)}"}
    return JSONResponse(content=error_body, status_code=500)
|
|
if __name__ == "__main__":
    # Command-line entry point: parse options, apply the model override and
    # start uvicorn.
    import argparse

    parser = argparse.ArgumentParser(description="AI Backend Service")
    parser.add_argument("--host", default="0.0.0.0", help="Host to bind to")
    parser.add_argument("--port", type=int, default=8000, help="Port to bind to")
    parser.add_argument("--model", default=current_model, help="HuggingFace model to use")
    parser.add_argument("--reload", action="store_true", help="Enable auto-reload for development")

    args = parser.parse_args()

    model_overridden = args.model != current_model
    if model_overridden:
        current_model = args.model
        logger.info(f"Using model: {current_model}")

    logger.info(f"🚀 Starting AI Backend Service on {args.host}:{args.port}")

    if args.reload:
        # uvicorn needs an import string for reload. That path imports the
        # module afresh, so the `current_model` assignment above (made in
        # the __main__ module) is not visible to the served application.
        if model_overridden:
            logger.warning("--model is ignored with --reload; the default model will be used")
        uvicorn.run(
            "backend_service:app",
            host=args.host,
            port=args.port,
            reload=True,
            log_level="info"
        )
    else:
        # Serve this module's own `app` object so its lifespan reads the
        # `current_model` set above. Passing the import string here made
        # uvicorn load a second copy of the module and silently drop --model.
        uvicorn.run(
            app,
            host=args.host,
            port=args.port,
            log_level="info"
        )
|
|