# NOTE(review): the lines below were HuggingFace Spaces page-banner residue
# ("Spaces: Sleeping") from the scrape, not code — kept as a comment so the
# file parses as Python.
import asyncio
import gc
import logging
import os
import re
import time
from contextlib import asynccontextmanager
from typing import Any, Dict, List, Optional

import torch
import uvicorn
from fastapi import BackgroundTasks, FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    BitsAndBytesConfig,
)
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Global variables for model and tokenizer.
# All three are populated by load_model() during app startup (see lifespan).
model = None      # transformers model instance, or None until loaded
tokenizer = None  # matching tokenizer, or None until loaded
device = None     # "cuda:0" or "cpu", decided in load_model()
class QuestionGenerationRequest(BaseModel):
    """Request payload for the question-generation endpoint.

    Note: ``difficulty_level`` is not validated against the four allowed
    values here — create_question_prompt() must tolerate unexpected strings.
    """
    statement: str = Field(..., description="The input statement to generate questions from")
    num_questions: int = Field(default=5, ge=1, le=10, description="Number of questions to generate (1-10)")
    temperature: float = Field(default=0.8, ge=0.1, le=2.0, description="Temperature for generation (0.1-2.0)")
    max_length: int = Field(default=2048, ge=100, le=4096, description="Maximum length of generated text")
    difficulty_level: str = Field(default="mixed", description="Difficulty level: easy, medium, hard, or mixed")
class QuestionGenerationResponse(BaseModel):
    """Response payload for the question-generation endpoint."""
    questions: List[str]      # generated questions (padded with a fallback if the model under-delivers)
    statement: str            # the original input statement, echoed back
    metadata: Dict[str, Any]  # model name, generation parameters, extraction stats
class HealthResponse(BaseModel):
    """Response payload for the health endpoint."""
    # pydantic v2 reserves the "model_" field-name prefix by default; clearing
    # protected_namespaces allows the `model_loaded` field below.
    model_config = {"protected_namespaces": ()}
    status: str                     # "healthy" or "unhealthy"
    model_loaded: bool              # whether the global model is loaded
    device: str                     # "cuda:0", "cpu", or "unknown"
    memory_usage: Dict[str, float]  # CUDA memory stats in GB; empty dict on CPU
async def load_model_with_retry(model_name: str, hf_token: str, max_retries: int = 3, delay: float = 5.0):
    """Load tokenizer and model, retrying on transient (e.g. network) failures.

    Args:
        model_name: HuggingFace model id.
        hf_token: HuggingFace access token (None is fine for public models).
        max_retries: Number of attempts before giving up.
        delay: Initial wait between attempts; grows geometrically (x1.5).

    Returns:
        Tuple of (tokenizer, model).

    Raises:
        The last loading exception if every attempt fails.
    """
    # BUG FIX: the original assigned `device` inside this function without a
    # `global` statement, which made `device` function-local and turned every
    # earlier read of it (the dtype/device_map expressions) into an
    # UnboundLocalError.
    global device
    for attempt in range(max_retries):
        try:
            logger.info(f"Loading model attempt {attempt + 1}/{max_retries}: {model_name}")
            tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                use_fast=True,
                trust_remote_code=True,
                token=hf_token
            )
            # BUG FIX: load_model() sets device to "cuda:0", so the original
            # `device == "cuda"` comparisons were always False and silently
            # loaded float32 weights on the GPU. Test the prefix instead.
            on_gpu = device is not None and str(device).startswith("cuda")
            # Use a Seq2Seq model for T5-family models, CausalLM for others.
            # ("flan-t5" contains "t5", so one substring test suffices.)
            if "t5" in model_name.lower():
                model = AutoModelForSeq2SeqLM.from_pretrained(
                    model_name,
                    torch_dtype=torch.float16 if on_gpu else torch.float32,
                    device_map="auto" if on_gpu else None,
                    trust_remote_code=True,
                    low_cpu_mem_usage=True,
                    token=hf_token
                )
            else:
                # Force the model to load on cuda:0 specifically
                if on_gpu:
                    torch.cuda.set_device(0)
                    device = "cuda:0"
                model = AutoModelForCausalLM.from_pretrained(
                    model_name,
                    torch_dtype=torch.float16 if on_gpu else torch.float32,
                    device_map={"": 0} if on_gpu else None,  # all parameters on GPU 0
                    trust_remote_code=True,
                    low_cpu_mem_usage=True,
                    use_safetensors=True,  # force safetensors to avoid CVE-2025-32434
                    token=hf_token,
                    attn_implementation="eager"  # eager attention (broadly compatible)
                )
            return tokenizer, model
        except Exception as e:
            logger.warning(f"Attempt {attempt + 1} failed: {str(e)}")
            if attempt < max_retries - 1:
                logger.info(f"Retrying in {delay} seconds...")
                await asyncio.sleep(delay)
                delay *= 1.5  # geometric backoff between attempts
            else:
                # Bare raise preserves the original traceback
                raise
async def load_model():
    """Select the compute device and load the model/tokenizer into globals.

    Populates the module globals ``model``, ``tokenizer`` and ``device``.
    Re-raises any loading error so app startup aborts instead of serving
    requests without a model.
    """
    global model, tokenizer, device
    try:
        logger.info("Starting model loading...")
        # Check if CUDA is available and pin everything to cuda:0
        if torch.cuda.is_available():
            torch.cuda.set_device(0)
            device = "cuda:0"
        else:
            device = "cpu"
        logger.info(f"Using device: {device}")
        if device == "cuda:0":
            logger.info(f"GPU: {torch.cuda.get_device_name()}")
            logger.info(f"VRAM Available: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
        # FIX: removed dead assignments of an unused GGUF model name/file —
        # they were never read and misleadingly suggested a llama.cpp path.
        # Get HF token from environment
        hf_token = os.getenv("HF_TOKEN")
        try:
            logger.info("Loading model with transformers...")
            # Llama 3.1 8B Instruct - large context window, strong reasoning
            base_model_name = "meta-llama/Llama-3.1-8B-Instruct"
            tokenizer, model = await load_model_with_retry(base_model_name, hf_token)
            if device == "cuda:0":
                model = model.to(device)
            logger.info("Model loaded successfully with transformers!")
        except Exception as e:
            logger.error(f"Error loading model with transformers: {str(e)}")
            raise  # Re-raise to stop startup if the primary model fails
    except Exception as e:
        logger.error(f"Error loading model: {str(e)}")
        raise
async def unload_model():
    """Release the model and tokenizer and reclaim GPU/CPU memory.

    BUG FIX: the original used ``del model`` / ``del tokenizer``, which removes
    the global *bindings* entirely — any later access (e.g. the health check's
    ``model is not None``) would raise NameError. Rebinding to None keeps the
    globals defined while still dropping the only references.
    """
    global model, tokenizer
    try:
        model = None
        tokenizer = None
        # Collect Python-side garbage first so CUDA can actually free the blocks
        gc.collect()
        # Clear the CUDA allocator cache if a GPU is present
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        logger.info("Model unloaded successfully")
    except Exception as e:
        # Best-effort cleanup: log and continue shutdown rather than crash
        logger.error(f"Error unloading model: {str(e)}")
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the model on startup and release it on shutdown.

    FIX: applied the ``@asynccontextmanager`` decorator — it is imported at the
    top of the file but was never used; FastAPI's ``lifespan=`` parameter
    expects an async context manager (passing a bare async generator function
    is deprecated in Starlette).
    """
    # Startup
    logger.info("Starting up...")
    await load_model()
    yield
    # Shutdown
    logger.info("Shutting down...")
    await unload_model()
# Create FastAPI app; lifespan handles model load/unload around the server's life
app = FastAPI(
    title="Question Generation API",
    description="API for generating questions from statements using DeepHermes reasoning model",
    version="1.0.0",
    lifespan=lifespan
)

# Add CORS middleware.
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# disallowed by the CORS spec (browsers reject wildcard origins with
# credentials) — confirm whether credentialed cross-origin calls are needed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
def create_question_prompt(statement: str, num_questions: int, difficulty_level: str) -> str:
    """Build a Llama-3 chat-format prompt requesting questions about a statement.

    Args:
        statement: Source text the questions must be based on.
        num_questions: Exact number of questions to request.
        difficulty_level: One of "easy", "medium", "hard", "mixed". Any other
            value falls back to the "mixed" wording (FIX: the original indexed
            the dict directly, so an unexpected level — the request model does
            not validate the string — crashed the endpoint with a KeyError).

    Returns:
        The complete prompt string, ending with the assistant header so the
        model continues directly with the question list.
    """
    difficulty_instruction = {
        "easy": "simple, straightforward questions that test basic understanding",
        "medium": "questions that require some analysis and comprehension",
        "hard": "complex questions that require deep thinking and reasoning",
        "mixed": "a mix of easy, medium, and hard questions"
    }
    # Unknown levels degrade gracefully to the "mixed" instruction
    instruction = difficulty_instruction.get(difficulty_level, difficulty_instruction["mixed"])
    # Llama models work better with chat-style prompts
    prompt = f"""<|begin_of_text|><|start_header_id|>user<|end_header_id|>
Please generate exactly {num_questions} {instruction} based on this statement:
"{statement}"
Requirements:
- Create clear, well-formed questions
- Vary question types (what, how, why, when, where)
- Number each question (1., 2., 3., etc.)
- End each question with a question mark
- Focus only on the content of the statement
<|eot_id|><|start_header_id|>assistant<|end_header_id|>
Here are {num_questions} questions based on the statement:
"""
    return prompt
def extract_questions(generated_text: str) -> List[str]:
    """Extract question strings from raw model output.

    Primary pass: lines that start with a digit or 'Q' have a numbering prefix
    ("1.", "12)", "Q3:", "Question 4:") stripped via regex and are kept if they
    end with '?'. (FIX: the original matched only the literal prefixes
    "1."–"10." and "Q1:"–"Q5:", so e.g. "11. Why?" was returned with its
    numbering still attached.)
    Fallback pass: if no numbered questions were found, keep any line ending
    in '?' longer than 10 characters.

    Returns:
        List of cleaned question strings (possibly empty).
    """
    prefix_re = re.compile(r'^(?:Question\s+\d+\s*:|Q\d+\s*:|\d+\s*[.)])\s*')
    questions: List[str] = []
    lines = generated_text.split('\n')
    for line in lines:
        line = line.strip()
        # Look for numbered questions
        if line and (line[0].isdigit() or line.startswith('Q')):
            question = prefix_re.sub('', line, count=1).strip()
            if question and question.endswith('?'):
                questions.append(question)
    # If no numbered questions were found, fall back to any plausible question line
    if not questions:
        for line in lines:
            line = line.strip()
            if line.endswith('?') and len(line) > 10:
                questions.append(line)
    return questions
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Health check endpoint: model status, device, and CUDA memory usage.

    FIX: registered the route — the handler was defined but had no decorator,
    so the "/health" path advertised by the root endpoint returned 404.
    """
    global model
    memory_usage = {}
    if torch.cuda.is_available():
        memory_usage = {
            "allocated_gb": torch.cuda.memory_allocated() / 1024**3,
            "reserved_gb": torch.cuda.memory_reserved() / 1024**3,
            "total_gb": torch.cuda.get_device_properties(0).total_memory / 1024**3
        }
    return HealthResponse(
        status="healthy" if model is not None else "unhealthy",
        model_loaded=model is not None,
        device=device if device else "unknown",
        memory_usage=memory_usage
    )
@app.post("/generate-questions", response_model=QuestionGenerationResponse)
async def generate_questions(request: QuestionGenerationRequest):
    """Generate questions from a statement.

    FIX 1: registered the route — the handler had no decorator, so the
    "/generate-questions" path advertised by the root endpoint returned 404.

    Raises:
        HTTPException 503 if the model is not loaded, 500 on generation errors.
    """
    global model
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    try:
        logger.info(f"Generating {request.num_questions} questions for statement: {request.statement[:100]}...")
        # Create prompt
        prompt = create_question_prompt(
            request.statement,
            request.num_questions,
            request.difficulty_level
        )
        # Tokenize; the prompt is capped at 512 tokens to bound latency
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
        prompt_token_count = inputs["input_ids"].shape[1]
        # Force all inputs onto the same device as the model
        if device == "cuda:0":
            model_device = next(model.parameters()).device
            logger.info(f"Model is on device: {model_device}")
            inputs = {k: v.to(model_device) for k, v in inputs.items()}
        with torch.no_grad():
            # Generation parameters tuned for speed
            outputs = model.generate(
                **inputs,
                max_new_tokens=min(256, request.max_length // 4),  # capped for speed
                temperature=request.temperature,
                top_p=0.9,
                do_sample=True,
                num_beams=1,  # single-sequence sampling (fastest)
                pad_token_id=tokenizer.eos_token_id,
                # FIX: dropped early_stopping=True — it only applies to beam
                # search and emits a warning with num_beams=1.
                use_cache=True,  # enable KV caching for speed
                repetition_penalty=1.1
            )
        # FIX 2: decode only the newly generated tokens. The original decoded
        # the full sequence with skip_special_tokens=True and then sliced off
        # len(prompt) characters — but the decoded text no longer contains the
        # special tokens present in the prompt string, so that slice cut into
        # the generated text itself.
        generated_text = tokenizer.decode(
            outputs[0][prompt_token_count:], skip_special_tokens=True
        ).strip()
        logger.info(f"Generated text length: {len(generated_text)}")
        # Extract questions from the generated text
        questions = extract_questions(generated_text)
        if len(questions) < request.num_questions:
            logger.warning(f"Only extracted {len(questions)} questions, requested {request.num_questions}")
        # Limit to the requested number
        questions = questions[:request.num_questions]
        # Pad with a generic fallback if the model under-delivered
        while len(questions) < request.num_questions:
            questions.append(f"What is the main point of this statement: '{request.statement[:100]}...'?")
        metadata = {
            "model": "meta-llama/Llama-3.1-8B-Instruct",
            "temperature": request.temperature,
            "difficulty_level": request.difficulty_level,
            "generated_text_length": len(generated_text),
            "questions_extracted": len(questions)
        }
        logger.info(f"Successfully generated {len(questions)} questions")
        return QuestionGenerationResponse(
            questions=questions,
            statement=request.statement,
            metadata=metadata
        )
    except Exception as e:
        logger.error(f"Error generating questions: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error generating questions: {str(e)}")
@app.get("/")
async def root():
    """Root endpoint with basic API info.

    FIX 1: registered the route (handler had no decorator).
    FIX 2: corrected the advertised model name — the service loads
    meta-llama/Llama-3.1-8B-Instruct (see load_model and the generation
    metadata), not google/flan-t5-large.
    """
    return {
        "message": "Question Generation API",
        "model": "meta-llama/Llama-3.1-8B-Instruct",
        "endpoints": {
            "health": "/health",
            "generate": "/generate-questions",
            "docs": "/docs"
        }
    }
if __name__ == "__main__":
    # Run the API server directly; port 7860 is the HuggingFace Spaces default.
    # NOTE(review): "app:app" assumes this file is named app.py — confirm,
    # since the actual filename is not visible here.
    uvicorn.run(
        "app:app",
        host="0.0.0.0",
        port=7860,
        reload=False
    )