"""FastAPI service exposing chat-completion endpoints backed by a causal LM.

The model and tokenizer are loaded once at application startup (see
``lifespan``) and shared by every request through module-level globals.
"""

import json
import logging
import re
import threading
from contextlib import asynccontextmanager
from typing import Any, Dict, List, Optional

import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Model tried first at startup, and the small model used if it cannot load.
PRIMARY_MODEL_NAME = "Qwen/Qwen2.5-7B-Instruct"
FALLBACK_MODEL_NAME = "gpt2"

# Prompts longer than this many tokens keep only their most recent tokens.
MAX_PROMPT_TOKENS = 2048

# Generation defaults; also substituted when a request sends an explicit null.
DEFAULT_MAX_TOKENS = 512
DEFAULT_TEMPERATURE = 0.7
DEFAULT_TOP_P = 0.9

# Roles understood by the API, mapped to the label used in the plain prompt.
_ROLE_LABELS = {"system": "System", "user": "User", "assistant": "Assistant"}

# Global variables for model and tokenizer (populated by ``lifespan``).
model = None
tokenizer = None

# One generation at a time: requests share a single model instance, and
# generation runs in worker threads so the event loop stays responsive.
_generation_lock = threading.Lock()


# Request/Response models
class ChatMessage(BaseModel):
    """A single conversation turn."""

    role: str  # "system", "user", "assistant"
    content: str


class ChatRequest(BaseModel):
    """Body accepted by both chat endpoints."""

    messages: List[ChatMessage]
    max_tokens: Optional[int] = 512
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 0.9
    stop: Optional[List[str]] = None


class ChatResponse(BaseModel):
    """Body returned by ``/chat/completions``."""

    content: str
    finish_reason: str
    usage: Dict[str, int]


class ChatStreamChunk(BaseModel):
    """One server-sent event emitted by ``/chat/stream``."""

    content: str
    finish_reason: Optional[str] = None
    usage: Optional[Dict[str, int]] = None


def _load_model(
    model_name: str,
    *,
    trust_remote_code: bool = False,
    use_fast: bool = True,
) -> None:
    """Load ``model_name`` and publish it through the module globals.

    The globals are assigned only after both the tokenizer and the model
    loaded, so a failure never leaves a tokenizer without its model.
    """
    global model, tokenizer

    use_cuda = torch.cuda.is_available()
    loaded_tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        trust_remote_code=trust_remote_code,
        use_fast=use_fast,
    )
    loaded_model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16 if use_cuda else torch.float32,
        device_map="auto" if use_cuda else None,
        trust_remote_code=trust_remote_code,
    )

    # Set pad token if not present
    if loaded_tokenizer.pad_token is None:
        loaded_tokenizer.pad_token = loaded_tokenizer.eos_token

    tokenizer = loaded_tokenizer
    model = loaded_model


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the model on startup and release it on shutdown.

    Tries ``PRIMARY_MODEL_NAME`` first and falls back to
    ``FALLBACK_MODEL_NAME``; if both fail the fallback's exception propagates
    and the application does not start.
    """
    global model, tokenizer

    logger.info("Loading model and tokenizer...")
    try:
        # NOTE(security): trust_remote_code=True executes Python code shipped
        # in the model repository. Only keep it for repositories you trust.
        # The slow tokenizer is requested for compatibility.
        _load_model(PRIMARY_MODEL_NAME, trust_remote_code=True, use_fast=False)
        logger.info("Model loaded successfully: %s", PRIMARY_MODEL_NAME)
    except Exception as e:
        logger.error("Failed to load model: %s", e)
        logger.info("Attempting fallback to GPT-2...")
        try:
            _load_model(FALLBACK_MODEL_NAME)
            logger.info(
                "Fallback model loaded successfully: %s", FALLBACK_MODEL_NAME
            )
        except Exception as fallback_error:
            logger.error("Fallback model also failed: %s", fallback_error)
            raise

    yield

    # Cleanup
    logger.info("Shutting down...")
    model = None
    tokenizer = None
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


# Initialize FastAPI app
app = FastAPI(
    title="Custom Chat Model API",
    description="API for fine-tuned chat model",
    version="1.0.0",
    lifespan=lifespan,
)


def format_messages(messages: List[ChatMessage]) -> str:
    """Format messages into a prompt string.

    Messages whose role is not system/user/assistant are ignored. When the
    loaded tokenizer provides a chat template, that template builds the
    prompt; otherwise (no tokenizer, no template, no usable messages, or the
    template rejecting the conversation) a plain ``"Role: content"``
    transcript ending in ``"Assistant:"`` is returned.
    """
    known = [message for message in messages if message.role in _ROLE_LABELS]

    if known and tokenizer is not None and getattr(tokenizer, "chat_template", None):
        try:
            return tokenizer.apply_chat_template(
                [{"role": m.role, "content": m.content} for m in known],
                tokenize=False,
                add_generation_prompt=True,
            )
        except Exception as e:
            # Templates may reject some conversations (e.g. unexpected role
            # order); the plain format below accepts any sequence.
            logger.warning(
                "Chat template failed (%s); using plain prompt format", e
            )

    transcript = "".join(
        f"{_ROLE_LABELS[message.role]}: {message.content}\n" for message in known
    )
    # Add assistant prompt for completion
    return transcript + "Assistant:"


def _truncate_at_stop(text: str, stop: Optional[List[str]]) -> tuple[str, bool]:
    """Cut ``text`` at the earliest occurrence of any stop string.

    Returns the (possibly shortened) text and whether a stop string was
    found. Empty stop strings are ignored.
    """
    if not stop:
        return text, False
    positions = [
        position
        for position in (text.find(marker) for marker in stop if marker)
        if position != -1
    ]
    if not positions:
        return text, False
    return text[: min(positions)], True


def _run_generation(
    prompt: str,
    max_tokens: Optional[int] = DEFAULT_MAX_TOKENS,
    temperature: Optional[float] = DEFAULT_TEMPERATURE,
    top_p: Optional[float] = DEFAULT_TOP_P,
    stop: Optional[List[str]] = None,
) -> tuple[str, Dict[str, int], str]:
    """Generate a completion and report why generation ended.

    Returns ``(text, usage, finish_reason)`` where ``finish_reason`` is
    ``"length"`` when all ``max_tokens`` new tokens were produced without an
    end-of-sequence token or stop string, and ``"stop"`` otherwise.
    """
    # Requests may carry explicit nulls for the optional fields.
    max_tokens = DEFAULT_MAX_TOKENS if max_tokens is None else max_tokens
    temperature = DEFAULT_TEMPERATURE if temperature is None else temperature
    top_p = DEFAULT_TOP_P if top_p is None else top_p

    # Handle device placement more robustly
    device = next(model.parameters()).device

    # Keep the END of an over-long prompt: it holds the latest turn and the
    # generation prompt, which right-side truncation would discard.
    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs["input_ids"][:, -MAX_PROMPT_TOKENS:].to(device)
    attention_mask = inputs["attention_mask"][:, -MAX_PROMPT_TOKENS:].to(device)
    input_length = input_ids.shape[1]

    generation_kwargs: Dict[str, Any] = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "max_new_tokens": max_tokens,
        "pad_token_id": tokenizer.pad_token_id,
        "eos_token_id": tokenizer.eos_token_id,
        "repetition_penalty": 1.1,
    }
    if temperature > 0:
        generation_kwargs.update(
            do_sample=True, temperature=temperature, top_p=top_p
        )
    else:
        # Sampling needs a strictly positive temperature; treat 0 as a
        # request for deterministic (greedy) decoding instead of failing.
        generation_kwargs["do_sample"] = False

    with _generation_lock, torch.no_grad():
        outputs = model.generate(**generation_kwargs)

    # Decode only the generated part
    generated_ids = outputs[0][input_length:]
    response = tokenizer.decode(generated_ids, skip_special_tokens=True)
    response, hit_stop = _truncate_at_stop(response, stop)
    response = response.strip()

    generated_count = len(generated_ids)
    ended_with_eos = (
        generated_count > 0
        and tokenizer.eos_token_id is not None
        and generated_ids[-1].item() == tokenizer.eos_token_id
    )
    if hit_stop or ended_with_eos or generated_count < max_tokens:
        finish_reason = "stop"
    else:
        finish_reason = "length"

    # Count the tokens of the text actually returned, without letting the
    # tokenizer add special tokens that were never generated.
    output_tokens = len(tokenizer.encode(response, add_special_tokens=False))
    usage = {
        "input_tokens": input_length,
        "output_tokens": output_tokens,
        "total_tokens": input_length + output_tokens,
    }
    return response, usage, finish_reason


def generate_response(
    prompt: str,
    max_tokens: int = 512,
    temperature: float = 0.7,
    top_p: float = 0.9,
    stop: Optional[List[str]] = None,
) -> tuple[str, Dict[str, int]]:
    """Generate response using the loaded model.

    Returns the stripped completion text and a usage dict with
    ``input_tokens``, ``output_tokens`` and ``total_tokens``.
    """
    response, usage, _finish_reason = _run_generation(
        prompt=prompt,
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        stop=stop,
    )
    return response, usage


def _ensure_model_loaded() -> None:
    """Raise HTTP 503 unless both the model and the tokenizer are loaded."""
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model not loaded")


def _dump_model(instance: BaseModel) -> Dict[str, Any]:
    """Return ``instance`` as a dict under both Pydantic v1 and v2."""
    if hasattr(instance, "model_dump"):
        return instance.model_dump()
    return instance.dict()


@app.get("/")
async def root():
    """Return a static banner showing that the service is up."""
    return {"message": "Custom Chat Model API", "status": "running"}


@app.get("/health")
async def health_check():
    """Report liveness and whether a model is currently loaded."""
    return {"status": "healthy", "model_loaded": model is not None}


@app.post("/chat/completions", response_model=ChatResponse)
async def chat_completions(request: ChatRequest):
    """Main chat completion endpoint.

    Raises HTTP 503 when no model is loaded and HTTP 500 when prompt
    formatting or generation fails.
    """
    _ensure_model_loaded()

    try:
        # Format messages into prompt
        prompt = format_messages(request.messages)

        # Generation is blocking; run it in a worker thread so other
        # requests (e.g. /health) are still served meanwhile.
        response_content, usage, finish_reason = await run_in_threadpool(
            _run_generation,
            prompt=prompt,
            max_tokens=request.max_tokens,
            temperature=request.temperature,
            top_p=request.top_p,
            stop=request.stop,
        )
    except Exception as e:
        logger.exception("Error in chat completion: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e

    return ChatResponse(
        content=response_content,
        finish_reason=finish_reason,
        usage=usage,
    )


@app.post("/chat/stream")
async def chat_stream(request: ChatRequest):
    """Streaming chat completion endpoint (server-sent events).

    Streaming is simulated: the full completion is generated first and then
    emitted in whitespace-preserving chunks, followed by a final chunk with
    usage information and a ``[DONE]`` marker. Generating before the
    response starts lets failures surface as HTTP 500 instead of a broken
    stream.
    """
    _ensure_model_loaded()

    try:
        prompt = format_messages(request.messages)
        response_content, usage, finish_reason = await run_in_threadpool(
            _run_generation,
            prompt=prompt,
            max_tokens=request.max_tokens,
            temperature=request.temperature,
            top_p=request.top_p,
            stop=request.stop,
        )
    except Exception as e:
        logger.exception("Error in streaming: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e

    def generate_stream():
        # Each chunk is a word plus its trailing whitespace, so joining all
        # chunks reproduces the response exactly (newlines included).
        for piece in re.findall(r"\S+\s*", response_content):
            chunk = ChatStreamChunk(content=piece, finish_reason=None)
            yield f"data: {json.dumps(_dump_model(chunk))}\n\n"

        # Final chunk with usage info
        final_chunk = ChatStreamChunk(
            content="",
            finish_reason=finish_reason,
            usage=usage,
        )
        yield f"data: {json.dumps(_dump_model(final_chunk))}\n\n"
        yield "data: [DONE]\n\n"

    return StreamingResponse(
        generate_stream(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache"},
    )


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)