# NOTE: the three lines below are HuggingFace upload-page residue, not code;
# kept as comments so the module remains valid Python.
# kofdai's picture
# Upload folder using huggingface_hub
# 594ed40 verified
"""
NullAI Database Enrichment API - Rebuilt Version
知識ベースの自動拡充機能を提供するAPI
Complete rebuild to avoid MLX dependency issues while preserving all functionality
"""
from fastapi import APIRouter, HTTPException, BackgroundTasks
from pydantic import BaseModel, Field
from typing import Dict, Any, Optional, List
import logging
import os
import asyncio
from datetime import datetime
import uuid
import json
from null_ai.coordinate_estimator import CoordinateEstimator
from null_ai.iath_writer import IathWriter, create_tile_from_ai_output
from backend.app.config import app_model_router
# Router mounted by the app and the module-level logger for this API.
router = APIRouter()
logger = logging.getLogger(__name__)
# ===== Pydantic Models =====
class AIEnrichmentRequest(BaseModel):
    """Request body for starting an AI enrichment run (or a question-only preview)."""
    domain_id: str = Field(..., description="ドメインID")  # target knowledge-domain ID
    num_questions: int = Field(10, ge=1, le=100, description="生成する質問数")  # how many questions to generate (1-100)
    focus_areas: Optional[List[str]] = Field(None, description="重点領域")  # optional topical focus areas
    prompt_model_id: Optional[str] = Field(None, description="プロンプト生成用モデルID")  # model for question/coordinate generation
    answer_model_id: Optional[str] = Field(None, description="回答生成用モデルID")  # model for answer generation (defaults to prompt model)
class WebEnrichmentRequest(BaseModel):
    """Request body for one-shot web enrichment via Brave Search."""
    query: str = Field(..., description="検索クエリ")  # search query string
    domain_id: str = Field(..., description="ドメインID")  # target knowledge-domain ID
    max_results: int = Field(5, ge=1, le=20, description="最大検索結果数")  # max search results to turn into tiles (1-20)
class EnrichmentStatusResponse(BaseModel):
    """Snapshot of the shared enrichment state, as produced by EnrichmentState.to_dict()."""
    is_running: bool            # True while a background run is active
    progress: float             # 0.0-100.0 percentage
    current_question: int       # 1-based index of the question being processed
    total_questions: int        # total questions planned for the run
    generated_tiles: int        # tiles successfully persisted so far
    start_time: Optional[str]   # ISO-8601 timestamp, None when idle
    domain_id: Optional[str]    # domain of the current run, None when idle
class EnrichmentResultResponse(BaseModel):
    """Outcome of starting or completing an enrichment operation."""
    success: bool
    questions_generated: Optional[int] = None  # not populated by the endpoints visible here
    tiles_created: Optional[int] = None        # set by web enrichment when tiles were written
    domain_id: Optional[str] = None
    error: Optional[str] = None                # populated on failure
# ===== Global State Management =====
class EnrichmentState:
    """Global enrichment state tracker shared across API requests.

    Tracks the progress of a single background enrichment run. Only one
    run is expected at a time (the endpoints guard on ``is_running``).
    """

    def __init__(self):
        # Delegate to reset() so the initial field values are defined
        # exactly once instead of being duplicated here.
        self.reset()

    def reset(self):
        """Reset all fields to their initial (idle) values."""
        self.is_running = False       # True while a background task is active
        self.progress = 0.0           # 0.0-100.0 percentage
        self.current_question = 0     # 1-based index of question in progress
        self.total_questions = 0
        self.generated_tiles = 0      # tiles successfully persisted so far
        self.start_time = None        # ISO-8601 string, set when a run starts
        self.domain_id = None
        self.stop_requested = False   # cooperative cancellation flag

    def to_dict(self) -> Dict[str, Any]:
        """Convert state to a dictionary for API responses.

        Note: ``stop_requested`` is internal and intentionally not exposed.
        """
        return {
            "is_running": self.is_running,
            "progress": self.progress,
            "current_question": self.current_question,
            "total_questions": self.total_questions,
            "generated_tiles": self.generated_tiles,
            "start_time": self.start_time,
            "domain_id": self.domain_id
        }
# Module-level singleton: every endpoint and the background task share it.
enrichment_state = EnrichmentState()
# ===== Safe Model Selection =====
def get_safe_model(model_id: Optional[str] = None):
    """
    Return a model config that is safe to use (won't cause MLX errors).

    An explicitly requested ``model_id`` is honoured when found; otherwise
    providers are tried in preference order gguf -> ollama -> huggingface,
    and as a last resort any non-MLX model. Returns ``None`` when no
    suitable model exists or on unexpected errors.
    """
    try:
        # Honour an explicit request first.
        if model_id:
            model = app_model_router.config_manager.get_model_config(model_id)
            if model:
                logger.info(f"Selected specified model: {model.model_id} (provider: {model.provider})")
                return model
            logger.warning(f"Requested model {model_id} not found, falling back to safe default")

        models = app_model_router.config_manager.models.values()

        # Preference order: GGUF (most reliable, no external dependencies),
        # then Ollama (reliable, local), then HuggingFace (may require a
        # download but generally works).
        for provider, label in (("gguf", "GGUF"), ("ollama", "Ollama"), ("huggingface", "HuggingFace")):
            for model in models:
                if model.provider == provider:
                    logger.info(f"Auto-selected {label} model: {model.model_id}")
                    return model

        # Last resort: anything at all, except MLX which is explicitly avoided.
        for model in models:
            if model.provider != "mlx":
                logger.warning(f"Using last resort model: {model.model_id} (provider: {model.provider})")
                return model

        logger.error("No safe model available")
        return None
    except Exception as e:
        logger.error(f"Error in get_safe_model: {e}", exc_info=True)
        return None
# ===== Safe LLM Inference =====
async def safe_llm_inference(model, prompt: str, temperature: float = 0.7) -> Dict[str, Any]:
    """
    Run LLM inference through the shared model router without raising.

    Always returns a dict: the router's result on success, otherwise an
    empty "response" paired with either an "error" message or a zero
    "confidence" marker.
    """
    try:
        # No model at all: nothing to run.
        if not model:
            logger.error("No model provided for inference")
            return {"response": "", "error": "No model available"}

        logger.debug(f"Performing inference with model: {model.model_id} (provider: {model.provider})")
        outcome = await app_model_router._perform_llm_inference(model, prompt, temperature=temperature)

        # Normalise malformed payloads to an empty, zero-confidence reply.
        if not outcome or "response" not in outcome:
            logger.warning(f"Empty or invalid response from model {model.model_id}")
            return {"response": "", "confidence": 0.0}
        return outcome
    except Exception as e:
        logger.error(f"Error during LLM inference: {e}", exc_info=True)
        return {"response": "", "error": str(e)}
# ===== Question Generation =====
async def generate_enrichment_questions(
    domain_id: str,
    num_questions: int,
    focus_areas: Optional[List[str]],
    model
) -> List[str]:
    """
    Ask the given model to produce enrichment questions for a domain.

    Returns at most ``num_questions`` unique questions in generation
    order; an empty list is returned on any failure.
    """
    try:
        logger.info(f"Generating {num_questions} questions for domain '{domain_id}'")

        # Optional emphasis appended to the prompt guidelines.
        focus_text = ""
        if focus_areas:
            focus_text = f"\nFocus on these specific areas: {', '.join(focus_areas)}"

        prompt = f"""You are an expert knowledge curator in the domain of {domain_id}.
Your task is to generate {num_questions} diverse, high-quality questions that would enrich a knowledge base in this domain.
Guidelines:
1. Cover a wide range of topics within the domain
2. Include questions at different complexity levels (basic, intermediate, advanced)
3. Focus on practical applications and edge cases
4. Avoid overly broad or trivial questions
5. Each question should elicit detailed, informative answers{focus_text}
Output format:
Return ONLY a JSON array of questions, like this:
["Question 1?", "Question 2?", ...]
Domain: {domain_id}
Number of questions: {num_questions}
Generate the questions now:"""

        llm_result = await safe_llm_inference(model, prompt, temperature=0.7)
        raw_text = llm_result.get("response", "")
        if not raw_text:
            logger.error("Empty response from LLM for question generation")
            return []
        logger.debug(f"Received response (first 200 chars): {raw_text[:200]}")

        # Parse, then drop duplicates while keeping first occurrences.
        unique: List[str] = []
        seen = set()
        for q in extract_questions_from_response(raw_text):
            if q not in seen:
                seen.add(q)
                unique.append(q)

        logger.info(f"Successfully generated {len(unique)} unique questions")
        return unique[:num_questions]
    except Exception as e:
        logger.error(f"Failed to generate questions: {e}", exc_info=True)
        return []
def extract_questions_from_response(response: str) -> List[str]:
    """
    Extract a list of questions from an LLM response.

    Tries, in order: a ```json fenced block, any ``` fenced block, the
    outermost [...] span, then the raw text. Falls back to line-based
    parsing when JSON decoding fails. Strings of 5 characters or fewer
    are discarded.

    Fix: an unterminated code fence (opening ``` with no closing ```)
    previously produced find() == -1, so the slice ``[start:-1]`` silently
    dropped the final character and usually corrupted the JSON; such
    fences now read through to the end of the response.
    """
    if not response or not response.strip():
        logger.error("Empty response provided")
        return []
    try:
        # Try to find JSON array
        json_str = ""
        if "```json" in response:
            json_start = response.find("```json") + 7
            json_end = response.find("```", json_start)
            if json_end == -1:  # unterminated fence: take the rest
                json_end = len(response)
            json_str = response[json_start:json_end].strip()
        elif "```" in response:
            json_start = response.find("```") + 3
            json_end = response.find("```", json_start)
            if json_end == -1:  # unterminated fence: take the rest
                json_end = len(response)
            json_str = response[json_start:json_end].strip()
        elif "[" in response and "]" in response:
            json_start = response.find("[")
            json_end = response.rfind("]") + 1
            json_str = response[json_start:json_end]
        else:
            json_str = response.strip()

        if not json_str:
            logger.warning("Could not extract JSON, trying fallback parsing")
            return fallback_parse_questions(response)

        # Parse JSON; keep only plausible question strings.
        questions = json.loads(json_str)
        if isinstance(questions, list):
            return [q for q in questions if isinstance(q, str) and len(q) > 5]
        else:
            logger.warning(f"Expected list, got {type(questions)}")
            return fallback_parse_questions(response)
    except json.JSONDecodeError as e:
        logger.warning(f"JSON parse error: {e}, trying fallback parsing")
        return fallback_parse_questions(response)
def fallback_parse_questions(response: str) -> List[str]:
    """
    Last-resort parser: pull question-like lines out of plain text.

    A line qualifies when it is longer than 5 characters and contains a
    '?'. Leading list markers (digits, '.', '-', '*', spaces) are stripped
    before the line is kept.
    """
    logger.info("Using fallback text parsing for questions")
    questions: List[str] = []
    for raw_line in response.split('\n'):
        candidate = raw_line.strip()
        # Skip lines too short to be a question or without a '?' at all.
        if len(candidate) <= 5 or '?' not in candidate:
            continue
        cleaned = candidate.lstrip('0123456789.-* ').strip()
        if cleaned:
            questions.append(cleaned)
    logger.info(f"Fallback parsing extracted {len(questions)} questions")
    return questions
# ===== Web Enrichment =====
async def web_search_brave(api_key: str, query: str, max_results: int) -> List[Dict[str, str]]:
    """
    Query the Brave Search API and return simplified web results.

    Each result dict carries "title", "url" and "snippet" keys. Returns
    an empty list on any error (non-200 status, network failure, missing
    aiohttp).

    Fix: pass ``aiohttp.ClientTimeout(total=30)`` instead of a bare int;
    numeric per-request timeouts are deprecated in recent aiohttp releases.
    """
    try:
        import aiohttp  # local import keeps the dependency optional
        url = "https://api.search.brave.com/res/v1/web/search"
        headers = {
            "Accept": "application/json",
            "X-Subscription-Token": api_key
        }
        params = {
            "q": query,
            "count": max_results
        }
        timeout = aiohttp.ClientTimeout(total=30)
        async with aiohttp.ClientSession() as session:
            async with session.get(url, headers=headers, params=params, timeout=timeout) as response:
                if response.status == 200:
                    data = await response.json()
                    # Flatten Brave's nested payload into simple dicts.
                    results = []
                    for item in data.get("web", {}).get("results", []):
                        results.append({
                            "title": item.get("title", ""),
                            "url": item.get("url", ""),
                            "snippet": item.get("description", "")
                        })
                    return results
                else:
                    logger.error(f"Brave Search API error: {response.status}")
                    return []
    except Exception as e:
        logger.error(f"Brave Search failed: {e}", exc_info=True)
        return []
# ===== API Endpoints =====
@router.post("/ai/start", response_model=EnrichmentResultResponse)
async def start_ai_enrichment(request: AIEnrichmentRequest, background_tasks: BackgroundTasks):
    """
    Kick off an AI-driven enrichment run in the background.

    Responds 409 when a run is already active and 400 when no safe
    (non-MLX) model can be found. The actual work is delegated to
    ``run_ai_enrichment_task`` via FastAPI background tasks, so this
    returns immediately with success=True.
    """
    try:
        # Only one enrichment run at a time.
        if enrichment_state.is_running:
            raise HTTPException(status_code=409, detail="Enrichment is already running")

        prompt_model = get_safe_model(request.prompt_model_id)
        if not prompt_model:
            raise HTTPException(status_code=400, detail="No safe model available for prompt generation")

        # Reuse the prompt model for answers when none was resolvable.
        answer_model = get_safe_model(request.answer_model_id)
        if not answer_model:
            answer_model = prompt_model

        logger.info(f"Starting enrichment with prompt_model={prompt_model.model_id}, answer_model={answer_model.model_id}")

        background_tasks.add_task(
            run_ai_enrichment_task,
            request.domain_id,
            request.num_questions,
            request.focus_areas,
            prompt_model,
            answer_model
        )
        return EnrichmentResultResponse(success=True, domain_id=request.domain_id)
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Failed to start AI enrichment: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
async def run_ai_enrichment_task(
    domain_id: str,
    num_questions: int,
    focus_areas: Optional[List[str]],
    prompt_model,
    answer_model
):
    """
    Background task for AI enrichment.

    Pipeline:
      1. Generate up to ``num_questions`` questions with ``prompt_model``.
      2. For each question: generate an answer with ``answer_model``, ask
         ``prompt_model`` for 6D coordinates, then build and persist a tile
         via the shared IATH writer.

    Progress is published through the module-level ``enrichment_state``;
    the loop checks ``stop_requested`` between questions for cooperative
    cancellation. Per-question failures are logged and skipped so one bad
    question never aborts the whole run.
    """
    try:
        # Initialize shared state for this run.
        # NOTE(review): reset() runs here in the task, not in the /ai/start
        # endpoint — two near-simultaneous starts could both pass the
        # is_running check before this line executes; confirm if that matters.
        enrichment_state.reset()
        enrichment_state.is_running = True
        enrichment_state.domain_id = domain_id
        enrichment_state.total_questions = num_questions
        enrichment_state.start_time = datetime.utcnow().isoformat()  # NOTE: utcnow() is deprecated in Python 3.12+

        # Step 1: Generate questions
        logger.info("Step 1: Generating questions...")
        questions = await generate_enrichment_questions(
            domain_id, num_questions, focus_areas, prompt_model
        )
        if not questions:
            logger.error("No questions generated")
            enrichment_state.is_running = False
            return
        # The generator may return fewer questions than requested.
        enrichment_state.total_questions = len(questions)

        # Step 2: Generate answers and save tiles
        logger.info(f"Step 2: Generating answers for {len(questions)} questions...")
        for i, question in enumerate(questions):
            # Check if stop requested (cooperative cancellation point)
            if enrichment_state.stop_requested:
                logger.info("Enrichment stop requested, aborting")
                break
            logger.info(f"Processing question {i+1}/{len(questions)}: {question[:50]}...")
            enrichment_state.current_question = i + 1
            enrichment_state.progress = (i / len(questions)) * 100  # percent of questions completed so far
            try:
                # Generate answer (lower temperature than question generation)
                answer_result = await safe_llm_inference(answer_model, question, temperature=0.5)
                answer = answer_result.get("response", "")
                if not answer:
                    logger.warning(f"Empty answer for question: {question}")
                    continue
                # 0.7 is the neutral default when the model reports no confidence.
                confidence = answer_result.get("confidence", 0.7)

                # Estimate coordinates; answer is truncated to keep the prompt small.
                coord_prompt = f"""Given this knowledge:
Question: {question}
Answer: {answer[:500]}
Estimate 6D coordinates [x, y, z, c, g, v] where:
- x, y, z: domain-specific 3D space (0.0-1.0 each)
- c: certainty (0.0-1.0)
- g: granularity (0.0-1.0)
- v: verification level (0.0-1.0)
Return only the coordinates as JSON array: [x, y, z, c, g, v]"""
                coord_result = await safe_llm_inference(prompt_model, coord_prompt, temperature=0.3)
                coord_response = coord_result.get("response", "")

                # Parse coordinates; fall back to a neutral default on any failure.
                coordinates = [0.5, 0.5, 0.5, 0.7, 0.6, 0.8]  # Default
                if coord_response:
                    try:
                        # Extract JSON array from response
                        if "[" in coord_response and "]" in coord_response:
                            start = coord_response.find("[")
                            end = coord_response.rfind("]") + 1
                            coord_json = coord_response[start:end]
                            parsed_coords = json.loads(coord_json)
                            if isinstance(parsed_coords, list) and len(parsed_coords) == 6:
                                coordinates = parsed_coords
                    except Exception as e:
                        logger.warning(f"Could not parse coordinates: {e}")

                # Create and save tile
                tile = create_tile_from_ai_output(
                    knowledge_id=f"ai_enrich_{uuid.uuid4().hex}",
                    topic=question[:100],
                    prompt=question,
                    response=answer,
                    coordinates=coordinates,
                    confidence=confidence,
                    domain_id=domain_id,
                    source="ai_enrichment"
                )
                success = app_model_router.iath_writer.append_tile(tile)
                if success:
                    enrichment_state.generated_tiles += 1
                    logger.info(f"Saved tile {enrichment_state.generated_tiles}: {question[:50]}...")
                else:
                    logger.warning(f"Failed to save tile for: {question}")
            except Exception as e:
                logger.error(f"Error processing question: {e}", exc_info=True)
                continue

        # Complete
        enrichment_state.progress = 100.0
        enrichment_state.current_question = len(questions)
        logger.info(f"Enrichment completed: {enrichment_state.generated_tiles} tiles created")
    except Exception as e:
        logger.error(f"AI enrichment task failed: {e}", exc_info=True)
    finally:
        # Always clear the running flag so future runs aren't blocked.
        enrichment_state.is_running = False
@router.post("/web/start", response_model=EnrichmentResultResponse)
async def start_web_enrichment(request: WebEnrichmentRequest):
    """
    Run a one-shot web enrichment: search Brave, store results as tiles.

    Responds 400 when BRAVE_SEARCH_API_KEY is unset, 404 when the search
    yields nothing, and 500 on unexpected failures. Tiles that fail to
    persist are logged and skipped rather than aborting the batch.
    """
    try:
        brave_api_key = os.getenv("BRAVE_SEARCH_API_KEY")
        if not brave_api_key:
            raise HTTPException(status_code=400, detail="Brave Search API key not configured")

        search_hits = await web_search_brave(brave_api_key, request.query, request.max_results)
        if not search_hits:
            raise HTTPException(status_code=404, detail="No search results found")

        tiles_created = 0
        for hit in search_hits:
            try:
                tile_content = f"""Title: {hit['title']}
Source: {hit['url']}
Summary:
{hit['snippet']}"""
                # Fixed, middle-of-the-road coordinates for web-sourced knowledge.
                coordinates = [0.5, 0.5, 0.5, 0.6, 0.5, 0.7]
                tile = create_tile_from_ai_output(
                    knowledge_id=f"web_{uuid.uuid4().hex}",
                    topic=hit['title'],
                    prompt=request.query,
                    response=tile_content,
                    coordinates=coordinates,
                    confidence=0.6,
                    domain_id=request.domain_id,
                    source="web_search"
                )
                if app_model_router.iath_writer.append_tile(tile):
                    tiles_created += 1
            except Exception as e:
                logger.error(f"Error creating tile from search result: {e}")
                continue

        return EnrichmentResultResponse(
            success=True,
            tiles_created=tiles_created,
            domain_id=request.domain_id
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Web enrichment failed: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
@router.get("/status", response_model=EnrichmentStatusResponse)
async def get_enrichment_status():
    """
    Return a snapshot of the current enrichment progress.

    Never fails: if serialising the shared state errors out for any
    reason, an all-idle default snapshot is returned instead.
    """
    try:
        snapshot = enrichment_state.to_dict()
        return EnrichmentStatusResponse(**snapshot)
    except Exception as e:
        logger.error(f"Failed to get status: {e}", exc_info=True)
        # Fall back to a safe, idle-looking status.
        return EnrichmentStatusResponse(
            is_running=False,
            progress=0.0,
            current_question=0,
            total_questions=0,
            generated_tiles=0,
            start_time=None,
            domain_id=None
        )
@router.post("/stop")
async def stop_enrichment():
    """
    Request cooperative cancellation of the active enrichment run.

    The background task checks the flag between questions, so the stop
    takes effect at the next loop iteration rather than immediately.
    Responds 400 when nothing is running.
    """
    # Nothing to cancel.
    if not enrichment_state.is_running:
        raise HTTPException(status_code=400, detail="No enrichment is currently running")

    enrichment_state.stop_requested = True
    logger.info("Enrichment stop requested")
    return {"success": True, "message": "Enrichment stop requested"}
@router.post("/ai/generate-questions", response_model=Dict[str, Any])
async def generate_questions_only(request: AIEnrichmentRequest):
    """
    Preview endpoint: generate enrichment questions without answering them.

    Uses the same safe-model selection and question generator as the full
    enrichment run, but returns the questions directly instead of storing
    anything.

    Fix: the 400 raised for a missing model was previously caught by the
    blanket ``except Exception`` and re-raised as a 500; HTTPException is
    now re-raised untouched, matching the other endpoints in this module.
    """
    try:
        model = get_safe_model(request.prompt_model_id)
        if not model:
            raise HTTPException(status_code=400, detail="No safe model available")
        questions = await generate_enrichment_questions(
            request.domain_id,
            request.num_questions,
            request.focus_areas,
            model
        )
        return {
            "questions": questions,
            "domain_id": request.domain_id,
            "count": len(questions)
        }
    except HTTPException:
        # Preserve deliberate HTTP errors (e.g. the 400 above).
        raise
    except Exception as e:
        logger.error(f"Failed to generate questions: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
@router.get("/config")
async def get_enrichment_config():
    """
    Report which enrichment features are currently available.

    Lists the usable (non-MLX) models for AI enrichment and whether the
    Brave Search key is configured for web enrichment.
    """
    brave_api_key = os.getenv("BRAVE_SEARCH_API_KEY")

    # MLX models are excluded: avoiding them is the whole point of this
    # rebuilt module.
    safe_models = [
        {
            "id": model.model_id,
            "name": model.model_name,
            "provider": model.provider
        }
        for model in app_model_router.config_manager.models.values()
        if model.provider != "mlx"
    ]

    master = app_model_router.master_model
    master_model_id = master.model_id if master else None

    return {
        "ai_enrichment": {
            "available": len(safe_models) > 0,
            "available_models": safe_models,
            "master_model_id": master_model_id
        },
        "web_enrichment": {
            "available": bool(brave_api_key),
            "search_engine": "brave",
            "api_key_configured": bool(brave_api_key)
        }
    }