import asyncio
import json
import os
import uuid
from datetime import datetime
from enum import Enum
from typing import Dict, List, Optional

import google.generativeai as genai
from fastapi import BackgroundTasks, FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

app = FastAPI(title="Vaccine Question Generator API")
# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allow all origins
    allow_credentials=True,
    allow_methods=["*"],  # Allow all methods
    allow_headers=["*"],  # Allow all headers
)

# Global variables to track generation state
generation_status = {
    "is_running": False,
    "total_chunks": 0,
    "processed_chunks": 0,
    "current_chunk_id": None,
    "start_time": None,
    "end_time": None,
    "errors": [],
    "result_file": None,
}

# Chunks file path (will be configurable via API)
CHUNKS_PATH = "Data/Processed_Data/chunks.json"
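# The chunks file is expected to be a JSON array of objects that each carry a
# "text" field (see the extraction in background_generation_task below), e.g.:
#   [{"text": "Premier extrait du guide..."}, {"text": "Deuxième extrait..."}]
# (Illustrative values; the real file comes from the preprocessing step.)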
# API Key (will be set via environment variable or API)
GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY", "")
# Model type options
class ModelType(str, Enum):
    GEMINI_FLASH = "gemini-2.0-flash"
    GEMINI_PRO = "gemini-1.5-pro"

# Request schema for starting generation
class GenerationRequest(BaseModel):
    chunks_path: Optional[str] = None
    api_key: Optional[str] = None
    model: ModelType = ModelType.GEMINI_FLASH
    output_file: str = "vaccine_questions_dataset.json"

# Response schema for status updates
class GenerationStatus(BaseModel):
    is_running: bool
    total_chunks: int
    processed_chunks: int
    current_chunk_id: Optional[int]
    progress_percentage: float
    start_time: Optional[str]
    end_time: Optional[str]
    estimated_time_remaining: Optional[str]
    errors: List[str]
    result_file: Optional[str]

def estimate_difficulty(question: str, q_type: str) -> str:
    """
    Estimate question difficulty from its type.

    The question text is accepted for future content-based heuristics but is
    currently unused: difficulty is derived from the type alone.

    Args:
        question (str): The question text.
        q_type (str): Question type (factual, conceptual, applied).

    Returns:
        str: Difficulty level (easy, medium, hard).
    """
    if q_type == "factual":
        return "easy"
    elif q_type == "conceptual":
        return "medium"
    return "hard"  # applied
async def generate_questions_for_chunk(chunk: str, chunk_id: int, client, model: str) -> List[Dict]:
    """
    Generate French questions for a given document chunk using the Gemini API.

    Args:
        chunk (str): A chunk of text from the vaccine guide (in French).
        chunk_id (int): Chunk identifier.
        client: The google.generativeai module (configured with an API key).
        model (str): Model name for the Gemini API.

    Returns:
        List[Dict]: List of questions with metadata.
    """
    # The prompt is deliberately written in French: it asks Gemini for three
    # varied questions (factual, conceptual, applied) about the French-language
    # vaccine guide, returned without answers as a fenced JSON array.
    prompt = f"""
À partir du texte suivant d'un guide sur les vaccins en français, générez 3 questions variées (factual, conceptual, applied) qui couvrent le contenu de manière exhaustive.
Fournissez uniquement les questions, sans réponses, en français. Retournez le résultat au format JSON, entouré de ```json\n...\n```.

Texte : {chunk}

Exemple de sortie :
```json
[
  {{
    "question": "Combien de structures sanitaires de proximité sont impliquées dans le suivi de la vaccination ?",
    "type": "factual"
  }},
  {{
    "question": "Quel est l'impact de la réglementation de la vaccination sur la couverture vaccinale ?",
    "type": "conceptual"
  }},
  {{
    "question": "Quelles seraient les conséquences si les établissements privés ne suivaient plus la réglementation vaccinale ?",
    "type": "applied"
  }}
]
```
"""
    try:
        # Update global state
        generation_status["current_chunk_id"] = chunk_id

        # google.generativeai exposes generation through GenerativeModel
        # instances rather than a module-level generate_content() call.
        response = client.GenerativeModel(model).generate_content(prompt)

        # Parse the response
        questions_text = response.text if hasattr(response, "text") else ""

        # Strip Markdown code fences ("```json\n" is 8 characters long)
        if questions_text.startswith("```json\n") and questions_text.endswith("\n```"):
            questions_text = questions_text[8:-4].strip()
        elif questions_text.startswith("```") and questions_text.endswith("```"):
            questions_text = questions_text[3:-3].strip()

        # Parse JSON
        if not questions_text:
            error_msg = f"Error: empty response for chunk {chunk_id}"
            generation_status["errors"].append(error_msg)
            return []
        questions = json.loads(questions_text)

        formatted_questions = []
        for q in questions:
            question_id = str(uuid.uuid4())
            difficulty = estimate_difficulty(q["question"], q["type"])
            formatted_questions.append({
                "question_id": question_id,
                "chunk_id": chunk_id,
                "chunk_text": chunk,
                "question": q["question"],
                "type": q["type"],
                "difficulty": difficulty,
                "training_purpose": "Knowledge Recall" if q["type"] == "factual" else "Reasoning",
                "validated": False,  # Flag for expert review
            })

        # Update count of processed chunks
        generation_status["processed_chunks"] += 1
        return formatted_questions
    except Exception as e:
        error_msg = f"Error generating questions for chunk {chunk_id}: {str(e)}"
        generation_status["errors"].append(error_msg)
        return []
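
# Each formatted record looks like this (illustrative values):
#   {"question_id": "<uuid4>", "chunk_id": 12, "chunk_text": "...",
#    "question": "Combien de doses sont recommandées... ?", "type": "factual",
#    "difficulty": "easy", "training_purpose": "Knowledge Recall",
#    "validated": False}
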
async def generate_questions_for_document(chunks: List[str], model: str, output_file: str, client) -> Dict:
    """
    Generate questions for all document chunks and structure them as a scientific dataset.

    Args:
        chunks (List[str]): List of document chunks.
        model (str): Model name for the Gemini API.
        output_file (str): File to save the results.
        client: The google.generativeai module (configured with an API key).

    Returns:
        Dict: Dataset with header and questions.
    """
    all_questions = []

    # Reset/initialize the global state
    generation_status["is_running"] = True
    generation_status["total_chunks"] = len(chunks)
    generation_status["processed_chunks"] = 0
    generation_status["start_time"] = datetime.utcnow().isoformat()
    generation_status["errors"] = []
    generation_status["current_chunk_id"] = None
    generation_status["end_time"] = None
    generation_status["result_file"] = None

    try:
        for i, chunk in enumerate(chunks):
            # Process each chunk
            questions = await generate_questions_for_chunk(chunk, i, client, model)
            all_questions.extend(questions)
            # Rate limiting between Gemini calls
            await asyncio.sleep(9)

        # Create dataset with scientific structure
        dataset = {
            "dataset_info": {
                "title": "Vaccine Guide Question-Answer Dataset",
                "description": "A dataset of question-answer pairs generated from a vaccine guide for AI language model training.",
                "version": "1.1.0",
                "created_date": datetime.utcnow().isoformat(),
                "source": "Guide-pratique-de-mise-en-oeuvre-du-calendrier-national-de-vaccination-2023.pdf",
                "generated_by": f"Gemini API ({model})",
                "total_questions": len(all_questions),
                "intended_use": "Fine-tuning medical language models for knowledge recall and reasoning",
            },
            "questions": all_questions,
        }

        # Save the dataset
        with open(output_file, "w", encoding="utf-8") as f:
            json.dump(dataset, f, indent=4, ensure_ascii=False)

        # Update final state
        generation_status["end_time"] = datetime.utcnow().isoformat()
        generation_status["result_file"] = output_file
        return dataset
    except Exception as e:
        generation_status["errors"].append(f"Error in document generation: {str(e)}")
        raise
    finally:
        generation_status["is_running"] = False
async def background_generation_task(chunks_path: str, model: str, output_file: str, api_key: Optional[str] = None):
    """Background task that configures Gemini, loads the chunks, and runs generation."""
    try:
        # Configure the client
        if api_key:
            genai.configure(api_key=api_key)
        elif GOOGLE_API_KEY:
            genai.configure(api_key=GOOGLE_API_KEY)
        else:
            raise ValueError("No API key provided for Gemini")

        # Load chunks
        with open(chunks_path, "r", encoding="utf-8") as f:
            chunks_data = json.load(f)

        # Extract texts from chunks
        chunks = [chunk["text"] for chunk in chunks_data]

        # Start generation process
        await generate_questions_for_document(chunks, model, output_file, genai)
    except Exception as e:
        generation_status["errors"].append(f"Background task error: {str(e)}")
        generation_status["is_running"] = False
@app.post("/generate", response_model=GenerationStatus)
async def start_generation(request: GenerationRequest, background_tasks: BackgroundTasks):
"""Start the question generation process"""
# Check if generation is already running
if generation_status["is_running"]:
raise HTTPException(status_code=400, detail="Generation process is already running")
# Set up paths and configurations
chunks_path = request.chunks_path or CHUNKS_PATH
api_key = request.api_key or GOOGLE_API_KEY
model = request.model
output_file = request.output_file
# Validate that chunks file exists
if not os.path.exists(chunks_path):
raise HTTPException(status_code=404, detail=f"Chunks file not found at {chunks_path}")
# Validate API key is available
if not api_key:
raise HTTPException(status_code=400, detail="No API key provided")
# Start background generation task
background_tasks.add_task(
background_generation_task,
chunks_path,
model,
output_file,
api_key
)
# Return initial status
return get_generation_status()
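
# Example request (illustrative; assumes the server runs locally on port 7860):
#   curl -X POST http://localhost:7860/generate \
#        -H "Content-Type: application/json" \
#        -d '{"model": "gemini-2.0-flash", "output_file": "vaccine_questions_dataset.json"}'
# chunks_path and api_key are optional and fall back to CHUNKS_PATH / GOOGLE_API_KEY.
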
@app.get("/status", response_model=GenerationStatus)
async def get_generation_status():
"""Get the current status of the question generation process"""
# Calculate progress percentage
total = generation_status["total_chunks"]
processed = generation_status["processed_chunks"]
progress_percentage = (processed / total * 100) if total > 0 else 0
# Calculate estimated time remaining
etr = None
if (generation_status["is_running"] and
generation_status["start_time"] and
processed > 0):
start_time = datetime.fromisoformat(generation_status["start_time"])
time_elapsed = (datetime.utcnow() - start_time).total_seconds()
time_per_chunk = time_elapsed / processed
remaining_chunks = total - processed
etr_seconds = time_per_chunk * remaining_chunks
etr = f"{int(etr_seconds // 60)}m {int(etr_seconds % 60)}s"
# Return formatted status
return GenerationStatus(
is_running=generation_status["is_running"],
total_chunks=total,
processed_chunks=processed,
current_chunk_id=generation_status["current_chunk_id"],
progress_percentage=round(progress_percentage, 2),
start_time=generation_status["start_time"],
end_time=generation_status["end_time"],
estimated_time_remaining=etr,
errors=generation_status["errors"],
result_file=generation_status["result_file"]
)
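
# Example: polling the status endpoint during a run (illustrative values):
#   curl http://localhost:7860/status
#   -> {"is_running": true, "total_chunks": 120, "processed_chunks": 30,
#       "current_chunk_id": 30, "progress_percentage": 25.0, ...}
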
@app.get("/")
async def root():
"""Root endpoint with API information"""
return {
"name": "Vaccine Question Generator API",
"description": "API for generating question-answer pairs from vaccine guide chunks",
"endpoints": [
{"path": "/", "method": "GET", "description": "This information page"},
{"path": "/generate", "method": "POST", "description": "Start question generation process"},
{"path": "/status", "method": "GET", "description": "Get current generation status"}
]
}
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)
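
# To run locally (a minimal sketch; "app.py" stands in for this file's name):
#   export GOOGLE_API_KEY=...
#   python app.py                                    # serves on http://0.0.0.0:7860
#   # or: uvicorn app:app --host 0.0.0.0 --port 7860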