import os
import gc
import re
import time
import asyncio
import logging
import threading
from typing import List, Optional, Dict, Any, Literal
from contextlib import asynccontextmanager

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM, BitsAndBytesConfig
import uvicorn
from fastapi import FastAPI, HTTPException, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Hugging Face repo of the model that is actually loaded and served.
# Used for loading, response metadata and the root endpoint so they cannot drift apart.
MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"

# Maximum prompt length in tokens. Longer prompts are rejected rather than
# right-truncated, because truncation would cut off the assistant header and
# the model would simply continue the user's statement.
MAX_PROMPT_TOKENS = 512

# Hard cap on newly generated tokens, independent of request.max_length (kept small for latency).
MAX_NEW_TOKENS_CAP = 256

# Allowed values for QuestionGenerationRequest.difficulty_level.
DifficultyLevel = Literal["easy", "medium", "hard", "mixed"]

# Global variables for model and tokenizer; populated by load_model() during app startup.
model = None
tokenizer = None
device = None

# model.generate now runs in worker threads. This lock keeps generation one-at-a-time,
# as it effectively was when generate blocked the event loop, so concurrent requests
# cannot stack multiple generations' activations/KV caches on the GPU.
_generation_lock = threading.Lock()

# Chat-template control tokens look like "<|eot_id|>". They are stripped from user text
# so a statement cannot inject extra chat turns into the prompt.
_SPECIAL_TOKEN_RE = re.compile(r"<\|[^|<>]*\|>")

# Leading numbering such as "1. ", "10) ", "Q3: " or "Question 4." at the start of a line.
# Bare digits must be followed by whitespace so "3.5 million...?" is not mangled.
_NUMBERING_RE = re.compile(r"^(?:Q(?:uestion)?\s*\d+\s*[:.)]\s*|\d+[.):]\s+)")

_DIFFICULTY_INSTRUCTIONS = {
    "easy": "simple, straightforward questions that test basic understanding",
    "medium": "questions that require some analysis and comprehension",
    "hard": "complex questions that require deep thinking and reasoning",
    "mixed": "questions with a mix of easy, medium, and hard difficulty",
}


class QuestionGenerationRequest(BaseModel):
    statement: str = Field(..., min_length=1, description="The input statement to generate questions from")
    num_questions: int = Field(default=5, ge=1, le=10, description="Number of questions to generate (1-10)")
    temperature: float = Field(default=0.8, ge=0.1, le=2.0, description="Temperature for generation (0.1-2.0)")
    max_length: int = Field(default=2048, ge=100, le=4096, description="Maximum length of generated text")
    # Literal makes pydantic reject unknown levels with a 422 instead of a KeyError -> 500 later.
    difficulty_level: DifficultyLevel = Field(default="mixed", description="Difficulty level: easy, medium, hard, or mixed")


class QuestionGenerationResponse(BaseModel):
    questions: List[str]
    statement: str
    metadata: Dict[str, Any]


class HealthResponse(BaseModel):
    # Allow field names starting with "model_" (model_loaded) under pydantic v2.
    model_config = {"protected_namespaces": ()}

    status: str
    model_loaded: bool
    device: str
    memory_usage: Dict[str, float]


async def load_model_with_retry(model_name: str, hf_token: Optional[str], max_retries: int = 3, delay: float = 5.0):
    """Load a tokenizer/model pair from the Hugging Face Hub, retrying on failure.

    The target device is taken from the module-level ``device`` global, which
    ``load_model`` sets before calling this. It is only read here, never assigned.
    (Assigning it would make ``device`` a local name and raise UnboundLocalError.)

    Args:
        model_name: Hub repo id. Names containing "t5" are loaded as seq2seq
            models; everything else is loaded as a causal LM.
        hf_token: Hugging Face access token, or None for public models.
        max_retries: Number of attempts; must be >= 1.
        delay: Seconds to wait before the first retry; grows 1.5x per retry.

    Returns:
        A ``(tokenizer, model)`` tuple.

    Raises:
        ValueError: If ``max_retries`` < 1.
        Exception: Whatever the last failed attempt raised.
    """
    if max_retries < 1:
        raise ValueError("max_retries must be at least 1")

    # Accept both "cuda" and "cuda:N" spellings.
    use_cuda = isinstance(device, str) and device.startswith("cuda")
    dtype = torch.float16 if use_cuda else torch.float32
    is_seq2seq = "t5" in model_name.lower()

    for attempt in range(1, max_retries + 1):
        try:
            logger.info(f"Loading model attempt {attempt}/{max_retries}: {model_name}")
            # NOTE(review): trust_remote_code executes Python shipped in the model repo.
            # Llama/T5 do not need it; consider disabling.
            loaded_tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                use_fast=True,
                trust_remote_code=True,
                token=hf_token,
            )

            if is_seq2seq:
                loaded_model = AutoModelForSeq2SeqLM.from_pretrained(
                    model_name,
                    torch_dtype=dtype,
                    device_map="auto" if use_cuda else None,
                    trust_remote_code=True,
                    low_cpu_mem_usage=True,
                    token=hf_token,
                )
            else:
                if use_cuda:
                    torch.cuda.set_device(0)
                loaded_model = AutoModelForCausalLM.from_pretrained(
                    model_name,
                    torch_dtype=dtype,
                    device_map={"": 0} if use_cuda else None,  # Put all parameters on GPU 0
                    trust_remote_code=True,
                    low_cpu_mem_usage=True,
                    use_safetensors=True,  # Force safetensors to avoid CVE-2025-32434
                    token=hf_token,
                    attn_implementation="eager",  # Use eager attention (compatible)
                )

            return loaded_tokenizer, loaded_model
        except Exception as e:
            logger.warning(f"Attempt {attempt} failed: {str(e)}")
            if attempt == max_retries:
                raise
            logger.info(f"Retrying in {delay} seconds...")
            await asyncio.sleep(delay)
            delay *= 1.5  # Exponential backoff


async def load_model():
    """Select the device and load MODEL_NAME into the module-level globals.

    Sets ``device`` to "cuda:0" when CUDA is available, else "cpu", then fills
    ``tokenizer`` and ``model``. Any failure is logged and re-raised, so app
    startup aborts instead of serving without a model.
    """
    global model, tokenizer, device

    logger.info("Starting model loading...")

    if torch.cuda.is_available():
        torch.cuda.set_device(0)
        device = "cuda:0"
    else:
        device = "cpu"

    logger.info(f"Using device: {device}")
    if device.startswith("cuda"):
        logger.info(f"GPU: {torch.cuda.get_device_name()}")
        logger.info(f"VRAM Total: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")

    hf_token = os.getenv("HF_TOKEN")

    try:
        logger.info("Loading model with transformers...")
        loaded_tokenizer, loaded_model = await load_model_with_retry(MODEL_NAME, hf_token)

        # Models loaded with a device_map are already placed (and may carry accelerate
        # hooks), so only move a model that was loaded without one.
        if device.startswith("cuda") and getattr(loaded_model, "hf_device_map", None) is None:
            loaded_model = loaded_model.to(device)

        tokenizer, model = loaded_tokenizer, loaded_model
        logger.info("Model loaded successfully with transformers!")
    except Exception as e:
        logger.error(f"Error loading model: {str(e)}")
        raise  # Stop startup if the model cannot be loaded


async def unload_model():
    """Drop references to the model/tokenizer and release cached GPU memory.

    The globals are set to None rather than deleted, so handlers that read them
    afterwards (e.g. /health) see "not loaded" instead of raising NameError.
    Errors are logged, not raised, so shutdown always completes.
    """
    global model, tokenizer
    try:
        model = None
        tokenizer = None
        # Collect first so freed tensors are returned to the CUDA caching allocator
        # before empty_cache() releases cached blocks to the driver.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        logger.info("Model unloaded successfully")
    except Exception as e:
        logger.error(f"Error unloading model: {str(e)}")


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the model on startup and unload it on shutdown."""
    logger.info("Starting up...")
    await load_model()
    yield
    logger.info("Shutting down...")
    await unload_model()


# Create FastAPI app
app = FastAPI(
    title="Question Generation API",
    description=f"API for generating questions from statements using {MODEL_NAME}",
    version="1.0.0",
    lifespan=lifespan,
)

# Add CORS middleware
# NOTE(review): allow_credentials=True together with a wildcard origin lets any site
# make credentialed requests. This API uses no cookies/auth, so consider setting it to False.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


def create_question_prompt(statement: str, num_questions: int, difficulty_level: str) -> str:
    """Build a Llama-3 chat-format prompt asking for ``num_questions`` questions.

    The prompt already starts with ``<|begin_of_text|>``, so tokenize it with
    ``add_special_tokens=False`` to avoid a duplicate BOS token.

    Args:
        statement: User-supplied text. Chat control tokens (``<|...|>``) are stripped.
        num_questions: How many questions to ask the model for.
        difficulty_level: One of "easy", "medium", "hard", "mixed".

    Raises:
        ValueError: If ``difficulty_level`` is not a known level.
    """
    instruction = _DIFFICULTY_INSTRUCTIONS.get(difficulty_level)
    if instruction is None:
        raise ValueError(
            f"Unknown difficulty_level {difficulty_level!r}; expected one of {sorted(_DIFFICULTY_INSTRUCTIONS)}"
        )

    safe_statement = _SPECIAL_TOKEN_RE.sub("", statement)

    # Llama models work better with chat-style prompts
    return (
        "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
        f"Please generate exactly {num_questions} {instruction} based on this statement:\n\n"
        f'"{safe_statement}"\n\n'
        "Requirements:\n"
        "- Create clear, well-formed questions\n"
        "- Vary question types (what, how, why, when, where)\n"
        "- Number each question (1., 2., 3., etc.)\n"
        "- End each question with a question mark\n"
        "- Focus only on the content of the statement"
        "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
        f"Here are {num_questions} questions based on the statement:\n\n"
    )


def extract_questions(generated_text: str) -> List[str]:
    """Pull questions out of raw model output.

    First pass: lines starting with a digit or "Q". Their numbering prefix
    ("1.", "10)", "Q3:", "Question 4:") is removed, and the line is kept if it
    ends with "?". If that finds nothing, fall back to any line that ends with
    "?" and is longer than 10 characters.
    """
    lines = [line.strip() for line in generated_text.splitlines()]

    questions = []
    for line in lines:
        if not line or not (line[0].isdigit() or line.startswith("Q")):
            continue
        question = _NUMBERING_RE.sub("", line, count=1).strip()
        if question.endswith("?"):
            questions.append(question)

    if not questions:
        questions = [line for line in lines if line.endswith("?") and len(line) > 10]

    return questions


@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Report whether the model is loaded, the device, and CUDA memory usage (GB)."""
    memory_usage = {}
    if torch.cuda.is_available():
        memory_usage = {
            "allocated_gb": torch.cuda.memory_allocated() / 1024**3,
            "reserved_gb": torch.cuda.memory_reserved() / 1024**3,
            "total_gb": torch.cuda.get_device_properties(0).total_memory / 1024**3,
        }

    return HealthResponse(
        status="healthy" if model is not None else "unhealthy",
        model_loaded=model is not None,
        device=device if device else "unknown",
        memory_usage=memory_usage,
    )


def _generate_text(inputs, max_new_tokens: int, temperature: float) -> str:
    """Run model.generate on tokenized ``inputs`` and return only the new text.

    Blocking and GPU-bound, so it is meant to run in a worker thread.
    ``torch.no_grad`` is entered here because grad mode is thread-local.
    """
    model_device = next(model.parameters()).device
    logger.debug(f"Model is on device: {model_device}")
    inputs = {k: v.to(model_device) for k, v in inputs.items()}

    with _generation_lock, torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=0.9,
            do_sample=True,  # Multinomial sampling (not greedy)
            num_beams=1,  # No beam search; early_stopping would not apply
            pad_token_id=tokenizer.eos_token_id,
            use_cache=True,  # Enable KV caching for speed
            repetition_penalty=1.1,
        )

    sequence = outputs[0]
    # Decoder-only models echo the prompt, so slice off the prompt tokens. Slicing the
    # decoded string by len(prompt) is wrong because skip_special_tokens shortens it.
    if not getattr(model.config, "is_encoder_decoder", False):
        sequence = sequence[inputs["input_ids"].shape[-1]:]
    return tokenizer.decode(sequence, skip_special_tokens=True).strip()


@app.post("/generate-questions", response_model=QuestionGenerationResponse)
async def generate_questions(request: QuestionGenerationRequest):
    """Generate ``request.num_questions`` questions about ``request.statement``.

    Returns 503 if the model is not loaded, 422 if the prompt exceeds
    MAX_PROMPT_TOKENS, and 500 for any other failure. If the model yields too
    few questions, generic fallback questions are appended.
    """
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    try:
        logger.info(f"Generating {request.num_questions} questions for statement: {request.statement[:100]}...")

        prompt = create_question_prompt(request.statement, request.num_questions, request.difficulty_level)

        # The prompt already contains <|begin_of_text|>; do not add a second BOS.
        inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
        prompt_tokens = inputs["input_ids"].shape[-1]
        if prompt_tokens > MAX_PROMPT_TOKENS:
            raise HTTPException(
                status_code=422,
                detail=f"Statement too long: prompt is {prompt_tokens} tokens, limit is {MAX_PROMPT_TOKENS}",
            )

        max_new_tokens = min(MAX_NEW_TOKENS_CAP, request.max_length // 4)

        # Run the blocking generate call off the event loop so /health etc. stay responsive.
        loop = asyncio.get_running_loop()
        generated_text = await loop.run_in_executor(
            None, _generate_text, inputs, max_new_tokens, request.temperature
        )
        logger.info(f"Generated text length: {len(generated_text)}")

        questions = extract_questions(generated_text)
        extracted_count = len(questions)
        if extracted_count < request.num_questions:
            logger.warning(f"Only extracted {extracted_count} questions, requested {request.num_questions}")

        questions = questions[:request.num_questions]
        model_question_count = len(questions)

        # Pad with a generic question if the model did not produce enough.
        snippet = request.statement[:100] + ("..." if len(request.statement) > 100 else "")
        while len(questions) < request.num_questions:
            questions.append(f"What is the main point of this statement: '{snippet}'?")

        metadata = {
            "model": MODEL_NAME,
            "temperature": request.temperature,
            "difficulty_level": request.difficulty_level,
            "generated_text_length": len(generated_text),
            # Counted before fallback padding, so it reflects what the model produced.
            "questions_extracted": model_question_count,
            "fallback_questions_added": len(questions) - model_question_count,
        }

        logger.info(f"Successfully generated {len(questions)} questions")

        return QuestionGenerationResponse(
            questions=questions,
            statement=request.statement,
            metadata=metadata,
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"Error generating questions: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error generating questions: {str(e)}") from e


@app.get("/")
async def root():
    """Root endpoint with basic info."""
    return {
        "message": "Question Generation API",
        "model": MODEL_NAME,
        "endpoints": {
            "health": "/health",
            "generate": "/generate-questions",
            "docs": "/docs",
        },
    }


if __name__ == "__main__":
    uvicorn.run(
        "app:app",
        host="0.0.0.0",
        port=7860,
        reload=False,
    )