from fastapi import FastAPI, HTTPException
from fastapi.responses import FileResponse
import json
from dotenv import load_dotenv
import time
import uuid
from typing import List, Dict, Optional
from datetime import datetime
from huggingface_hub import HfApi  # For file persistence in Spaces
import os
import threading
import glob
import random
from langchain_google_genai import GoogleGenerativeAI

# Load environment variables from .env file
load_dotenv()

app = FastAPI()
# Global state tracking generation status (guarded by generation_lock)
generation_status = {
    "is_running": False,
    "start_time": None,
    "processed_chunks": 0,
    "total_chunks": 0,
    "questions_generated": 0,
    "completed": False,
    "result_file": None,
    "progress_file": None,       # Track the downloadable progress file
    "error": None,
    "current_api_key_index": 0,  # Track the current API key in rotation
    "failed_chunks": [],         # Track failed chunks for retry
    "partial_results": [],       # Store partial results as they accumulate
}
generation_lock = threading.Lock()
def get_api_keys() -> List[str]:
    """
    Get all available Google API keys from environment variables.
    Supports GOOGLE_API_KEY, GOOGLE_API_KEY_1, GOOGLE_API_KEY_2, etc.
    """
    api_keys = []
    # Check for the primary key
    primary_key = os.getenv("GOOGLE_API_KEY")
    if primary_key:
        api_keys.append(primary_key)
    # Check for numbered keys until the sequence breaks
    i = 1
    while True:
        key = os.getenv(f"GOOGLE_API_KEY_{i}")
        if key:
            api_keys.append(key)
            i += 1
        else:
            break
    if not api_keys:
        raise ValueError("No Google API keys found in environment variables")
    return api_keys
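
# Example (hypothetical) .env contents enabling key rotation. The variable
# names follow the scheme get_api_keys() scans for; the values here are
# placeholders, not real keys:
#
#   GOOGLE_API_KEY=your-primary-key
#   GOOGLE_API_KEY_1=your-second-key
#   GOOGLE_API_KEY_2=your-third-key
#
# Note that the numbered keys must be contiguous: the scan stops at the first
# missing index, so GOOGLE_API_KEY_3 is ignored if GOOGLE_API_KEY_2 is absent.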
def get_next_api_key() -> tuple[str, int]:
    """
    Get the next API key in rotation and update the current index.
    Returns a tuple of (api_key, key_index).
    """
    global generation_status
    api_keys = get_api_keys()
    with generation_lock:
        current_index = generation_status["current_api_key_index"]
        next_index = (current_index + 1) % len(api_keys)
        generation_status["current_api_key_index"] = next_index
        return api_keys[next_index], next_index
def save_progress_file():
    """
    Save current progress to a file that can be downloaded at any time.
    """
    global generation_status
    with generation_lock:
        progress_data = {
            "generation_info": {
                "status": "in_progress" if generation_status["is_running"] else "completed",
                "start_time": generation_status["start_time"],
                "processed_chunks": generation_status["processed_chunks"],
                "total_chunks": generation_status["total_chunks"],
                "questions_generated": generation_status["questions_generated"],
                "completed": generation_status["completed"],
                "current_time": datetime.utcnow().isoformat(),
                "failed_chunks": generation_status["failed_chunks"].copy(),
                "error": generation_status["error"],
            },
            "partial_dataset": {
                "dataset_info": {
                    "title": "Vaccine Guide Question-Answer Dataset (Partial)",
                    "description": "Partial dataset of question-answer pairs generated from a vaccine guide.",
                    "version": "1.1.0",
                    "created_date": generation_status["start_time"],
                    "source": "Guide-pratique-de-mise-en-oeuvre-du-calendrier-national-de-vaccination-2023.pdf",
                    "generated_by": "Gemini API",
                    "total_questions": len(generation_status["partial_results"]),
                    "intended_use": "Fine-tuning medical language models for knowledge recall and reasoning",
                    "note": "This is a partial dataset. Generation may still be in progress.",
                },
                "questions": generation_status["partial_results"].copy(),
            },
        }
        progress_filename = f"vaccine_questions_progress_{int(time.time())}.json"
        generation_status["progress_file"] = progress_filename
    # Write the snapshot outside the lock to keep the critical section short
    try:
        with open(f"./{progress_filename}", 'w', encoding='utf-8') as f:
            json.dump(progress_data, f, indent=4, ensure_ascii=False)
        print(f"Progress saved to {progress_filename}")
    except Exception as e:
        print(f"Error saving progress file: {e}")
def estimate_difficulty(question: str, q_type: str) -> str:
    """
    Estimate question difficulty based on type and content.

    Args:
        question (str): The question text.
        q_type (str): Question type (factual, conceptual, applied).

    Returns:
        str: Difficulty level (easy, medium, hard).
    """
    if q_type == "factual":
        return "easy"
    elif q_type == "conceptual":
        return "medium"
    return "hard"  # applied
def generate_questions_for_chunk(chunk: str, chunk_id: int, model="gemini-2.0-flash", max_retries=3) -> List[Dict]:
    """
    Generate French questions for a given document chunk using the Gemini API.
    Includes retry logic that rotates through the available API keys.
    """
    prompt = f"""
À partir du texte suivant d'un guide sur les vaccins en français, générez 3 questions variées (factual, conceptual, applied) qui couvrent le contenu de manière exhaustive.
Fournissez uniquement les questions, sans réponses, en français. Retournez le résultat au format JSON, entouré de ```json\n...\n```.
Texte : {chunk}
Exemple de sortie :
```json
[
    {{
        "question": "Combien de structures sanitaires de proximité sont impliquées dans le suivi de la vaccination ?",
        "type": "factual"
    }},
    {{
        "question": "Quel est l'impact de la réglementation de la vaccination sur la couverture vaccinale ?",
        "type": "conceptual"
    }},
    {{
        "question": "Quelles seraient les conséquences si les établissements privés ne suivaient plus la réglementation vaccinale ?",
        "type": "applied"
    }}
]
```
"""
    last_error = None
    for attempt in range(max_retries):
        try:
            # Get the next API key for this attempt
            api_key, key_index = get_next_api_key()
            print(f"Chunk {chunk_id}, attempt {attempt + 1}: using API key index {key_index}")
            llm = GoogleGenerativeAI(
                model=model,
                google_api_key=api_key
            )
            response = llm.invoke(prompt)
            # Normalize outer whitespace so the fence checks below match
            # reliably even when the model appends a trailing newline
            questions_text = str(response).strip()
            # Strip Markdown code fences
            if questions_text.startswith("```json\n") and questions_text.endswith("\n```"):
                questions_text = questions_text[7:-4].strip()
            elif questions_text.startswith("```") and questions_text.endswith("```"):
                questions_text = questions_text[3:-3].strip()
            if not questions_text:
                raise ValueError(f"Empty response for chunk {chunk_id}")
            questions = json.loads(questions_text)
            formatted_questions = []
            for q in questions:
                question_id = str(uuid.uuid4())
                difficulty = estimate_difficulty(q["question"], q["type"])
                formatted_questions.append({
                    "question_id": question_id,
                    "chunk_id": chunk_id,
                    "chunk_text": chunk,
                    "question": q["question"],
                    "type": q["type"],
                    "difficulty": difficulty,
                    "training_purpose": "Knowledge Recall" if q["type"] == "factual" else "Reasoning",
                    "validated": False,
                    "api_key_used": key_index,  # Track which key was used
                    "generation_attempt": attempt + 1
                })
            # Update the global status and add to partial results
            with generation_lock:
                generation_status["questions_generated"] += len(formatted_questions)
                generation_status["partial_results"].extend(formatted_questions)
            # Save progress after each successful chunk
            save_progress_file()
            print(f"Successfully generated {len(formatted_questions)} questions for chunk {chunk_id}")
            return formatted_questions
        except Exception as e:
            last_error = e
            print(f"Attempt {attempt + 1} failed for chunk {chunk_id}: {e}")
            # If this is not the last attempt, wait before retrying
            if attempt < max_retries - 1:
                wait_time = (attempt + 1) * 5  # Linear backoff: 5s, 10s, 15s, ...
                print(f"Waiting {wait_time} seconds before retry...")
                time.sleep(wait_time)
                continue
    # All attempts failed
    print(f"All {max_retries} attempts failed for chunk {chunk_id}. Last error: {last_error}")
    # Record the chunk so it can be retried later
    with generation_lock:
        generation_status["failed_chunks"].append({
            "chunk_id": chunk_id,
            "error": str(last_error),
            "attempts": max_retries
        })
    return []
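
# Example (illustrative): a typical raw model reply and what the parsing
# above extracts from it. The reply text is invented for demonstration:
#
#   raw = '```json\n[{"question": "Quel vaccin ... ?", "type": "factual"}]\n```'
#
# After the outer strip and fence removal, json.loads() receives
# '[{"question": "Quel vaccin ... ?", "type": "factual"}]' and yields the
# list of {"question", "type"} dicts that are then enriched with IDs,
# difficulty, and metadata.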
def generate_questions_in_background(chunks: List[str]):
    """
    Generate questions in a background thread and update status.
    Enhanced with better error handling and progress tracking.
    """
    global generation_status
    try:
        all_questions = []
        with generation_lock:
            generation_status["total_chunks"] = len(chunks)
            generation_status["processed_chunks"] = 0
            generation_status["questions_generated"] = 0
            generation_status["partial_results"] = []
            generation_status["failed_chunks"] = []
        # Save the initial progress file
        save_progress_file()
        for i, chunk in enumerate(chunks):
            print(f"Processing chunk {i+1}/{len(chunks)}...")
            questions = generate_questions_for_chunk(chunk, i)
            if questions:  # Only add if generation was successful
                all_questions.extend(questions)
            with generation_lock:
                generation_status["processed_chunks"] = i + 1
            # Rate limiting - slightly randomized to avoid hitting API limits
            sleep_time = random.uniform(8, 11)  # Random 8-11 second pause
            time.sleep(sleep_time)
        # Create the final dataset
        dataset = {
            "dataset_info": {
                "title": "Vaccine Guide Question-Answer Dataset",
                "description": "A dataset of question-answer pairs generated from a vaccine guide for AI language model training.",
                "version": "1.1.0",
                "created_date": datetime.utcnow().isoformat(),
                "source": "Guide-pratique-de-mise-en-oeuvre-du-calendrier-national-de-vaccination-2023.pdf",
                "generated_by": "Gemini API",
                "total_questions": len(all_questions),
                "intended_use": "Fine-tuning medical language models for knowledge recall and reasoning",
                "total_chunks_processed": len(chunks),
                "successful_chunks": len(chunks) - len(generation_status["failed_chunks"]),
                "failed_chunks": len(generation_status["failed_chunks"]),
                "failed_chunk_details": generation_status["failed_chunks"].copy()
            },
            "questions": all_questions
        }
        # Save the final dataset
        filename = f"vaccine_questions_final_{int(time.time())}.json"
        with open(f"./{filename}", 'w', encoding='utf-8') as f:
            json.dump(dataset, f, indent=4, ensure_ascii=False)
        # Update status to completed
        with generation_lock:
            generation_status["completed"] = True
            generation_status["is_running"] = False
            generation_status["result_file"] = filename
        # Save the final progress file
        save_progress_file()
        success_rate = (len(chunks) - len(generation_status["failed_chunks"])) / len(chunks) * 100
        print(f"Generation completed! Success rate: {success_rate:.1f}% ({len(all_questions)} questions generated)")
    except Exception as e:
        print(f"Error in background generation: {e}")
        with generation_lock:
            generation_status["error"] = str(e)
            generation_status["is_running"] = False
        # Save progress even if there was an error
        save_progress_file()
def save_dataset_to_space(dataset: Dict, filename: str):
    """
    Save a dataset to a file in the Space's persistent storage.
    """
    persistent_path = f"./{filename}"
    with open(persistent_path, 'w', encoding='utf-8') as f:
        json.dump(dataset, f, indent=4, ensure_ascii=False)
    print(f"Dataset saved to {persistent_path}")
# Route decorators were absent from the extracted source; the paths and HTTP
# methods used below are assumptions and may need to match what index.html calls.
@app.post("/generate")  # assumed route
async def generate_questions():
    """
    Endpoint to generate questions from all JSON files in the chunk folder.
    Enhanced with multi-key support validation.
    """
    global generation_status
    # Check if generation is already running
    with generation_lock:
        if generation_status["is_running"]:
            return {
                "status": "running",
                "message": "Generation already in progress",
                "current_status": generation_status
            }
    try:
        # Validate API keys before starting
        api_keys = get_api_keys()
        print(f"Found {len(api_keys)} API keys for rotation")
        # Reset status
        with generation_lock:
            generation_status["is_running"] = True
            generation_status["start_time"] = datetime.utcnow().isoformat()
            generation_status["processed_chunks"] = 0
            generation_status["questions_generated"] = 0
            generation_status["completed"] = False
            generation_status["result_file"] = None
            generation_status["progress_file"] = None
            generation_status["error"] = None
            generation_status["current_api_key_index"] = 0
            generation_status["failed_chunks"] = []
            generation_status["partial_results"] = []
        # Load all JSON files from the chunk folder
        json_files = glob.glob("./chunk/*.json")
        if not json_files:
            raise HTTPException(status_code=404, detail="No JSON files found in chunk folder")
        all_chunks = []
        for json_file in json_files:
            with open(json_file, "r", encoding="utf-8") as f:
                chunks_data = json.load(f)
            if isinstance(chunks_data, list):
                # A list of chunks: accept dicts with a "text" field or bare strings
                for chunk in chunks_data:
                    if isinstance(chunk, dict) and "text" in chunk:
                        all_chunks.append(chunk["text"])
                    elif isinstance(chunk, str):
                        all_chunks.append(chunk)
            elif isinstance(chunks_data, dict):
                # A single dict: try to extract its text content
                if "text" in chunks_data:
                    all_chunks.append(chunks_data["text"])
                elif "content" in chunks_data:
                    all_chunks.append(chunks_data["content"])
        if not all_chunks:
            raise HTTPException(status_code=404, detail="No text content found in JSON files")
        # Start generation in a background thread
        thread = threading.Thread(target=generate_questions_in_background, args=(all_chunks,))
        thread.daemon = True
        thread.start()
        return {
            "status": "started",
            "message": f"Question generation started for {len(json_files)} JSON files with {len(all_chunks)} chunks",
            "api_keys_available": len(api_keys),
            "current_status": generation_status
        }
    except HTTPException:
        # Re-raise deliberate HTTP errors (e.g. the 404s above) unchanged
        # instead of letting the generic handler turn them into 500s
        with generation_lock:
            generation_status["is_running"] = False
        raise
    except Exception as e:
        with generation_lock:
            generation_status["is_running"] = False
            generation_status["error"] = str(e)
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/status")  # assumed route
async def get_generation_status():
    """
    Endpoint to check the current status of generation.
    Enhanced with more detailed status information.
    """
    with generation_lock:
        status_copy = generation_status.copy()
    # Calculate additional metrics
    if status_copy["total_chunks"] > 0:
        progress_percentage = (status_copy["processed_chunks"] / status_copy["total_chunks"]) * 100
        status_copy["progress_percentage"] = round(progress_percentage, 2)
    else:
        status_copy["progress_percentage"] = 0
    # Add estimated time remaining if generation is running
    if status_copy["is_running"] and status_copy["start_time"] and status_copy["processed_chunks"] > 0:
        start_time = datetime.fromisoformat(status_copy["start_time"].replace('Z', '+00:00'))
        elapsed_time = (datetime.utcnow() - start_time.replace(tzinfo=None)).total_seconds()
        chunks_per_second = status_copy["processed_chunks"] / elapsed_time if elapsed_time > 0 else 0
        if chunks_per_second > 0:
            remaining_chunks = status_copy["total_chunks"] - status_copy["processed_chunks"]
            estimated_remaining_seconds = remaining_chunks / chunks_per_second
            status_copy["estimated_remaining_minutes"] = round(estimated_remaining_seconds / 60, 2)
        else:
            status_copy["estimated_remaining_minutes"] = None
    return status_copy
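
# Example (illustrative) of the response shape mid-run; all values invented:
#
#   {
#       "is_running": true,
#       "processed_chunks": 12,
#       "total_chunks": 40,
#       "questions_generated": 36,
#       "progress_percentage": 30.0,
#       "estimated_remaining_minutes": 8.4,
#       ...
#   }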
@app.get("/download-progress")  # assumed route
async def download_progress():
    """
    Endpoint to download the current progress file at any time.
    """
    global generation_status
    # Force a save of the current progress
    save_progress_file()
    with generation_lock:
        progress_file = generation_status["progress_file"]
    if progress_file and os.path.exists(f"./{progress_file}"):
        return FileResponse(f"./{progress_file}", media_type="application/json", filename=progress_file)
    raise HTTPException(status_code=404, detail="No progress file available")
@app.get("/download/{filename}")  # assumed route
async def download_file(filename: str):
    """
    Endpoint to download generated files.
    Enhanced with better error handling.
    """
    # Restrict to bare filenames so the path cannot escape the working directory
    safe_name = os.path.basename(filename)
    file_path = f"./{safe_name}"
    if os.path.exists(file_path):
        return FileResponse(file_path, media_type="application/json", filename=safe_name)
    raise HTTPException(status_code=404, detail=f"File {safe_name} not found")
@app.post("/retry-failed")  # assumed route
async def retry_failed_chunks():
    """
    Endpoint to retry only the failed chunks.
    """
    global generation_status
    with generation_lock:
        if generation_status["is_running"]:
            return {
                "status": "error",
                "message": "Cannot retry while generation is running"
            }
        failed_chunks = generation_status["failed_chunks"].copy()
    if not failed_chunks:
        return {
            "status": "success",
            "message": "No failed chunks to retry"
        }
    # The actual retry is not implemented yet; report the failures for now
    return {
        "status": "info",
        "message": f"Found {len(failed_chunks)} failed chunks",
        "failed_chunks": failed_chunks,
        "note": "Retry functionality can be implemented based on requirements"
    }
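
# A minimal sketch of how the retry could work, assuming the chunk texts can
# be re-read from ./chunk/*.json in the same order as the original run (the
# failed-chunk records only store chunk IDs, not the text itself):
#
#   def retry_in_background(failed: List[Dict], chunks: List[str]):
#       for item in failed:
#           cid = item["chunk_id"]
#           if 0 <= cid < len(chunks):
#               generate_questions_for_chunk(chunks[cid], cid)
#       save_progress_file()
#
# Records in generation_status["failed_chunks"] would also need to be cleared
# for any chunk that succeeds on retry.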
@app.get("/api-keys-status")  # assumed route
async def get_api_keys_status():
    """
    Endpoint to check the status of the configured API keys.
    """
    try:
        api_keys = get_api_keys()
        with generation_lock:
            current_index = generation_status["current_api_key_index"]
        return {
            "status": "success",
            "total_keys": len(api_keys),
            "current_key_index": current_index,
            "message": f"{len(api_keys)} API keys configured for rotation"
        }
    except Exception as e:
        return {
            "status": "error",
            "message": str(e)
        }
@app.get("/")
async def root():
    """
    Root endpoint that serves the HTML UI from the index.html file.
    """
    print("Serving index.html")  # Debug log to confirm serving
    return FileResponse("./index.html", media_type="text/html")
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
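
# Example (hypothetical) end-to-end usage from a client, assuming the route
# paths chosen above and a server on http://localhost:7860. Requires the
# `requests` package; run it as a separate script, not in this module:
#
#   import requests, time
#
#   base = "http://localhost:7860"
#   print(requests.post(f"{base}/generate").json())   # start the run
#   while True:
#       status = requests.get(f"{base}/status").json()
#       print(status.get("progress_percentage"), "%")
#       if status.get("completed") or status.get("error"):
#           break
#       time.sleep(30)
#   # Grab the progress file (works mid-run as well)
#   with open("progress.json", "wb") as f:
#       f.write(requests.get(f"{base}/download-progress").content)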