|
|
""" |
|
|
NullAI Database Enrichment API - Rebuilt Version |
|
|
|
|
|
知識ベースの自動拡充機能を提供するAPI |
|
|
Complete rebuild to avoid MLX dependency issues while preserving all functionality |
|
|
""" |
|
|
from fastapi import APIRouter, HTTPException, BackgroundTasks |
|
|
from pydantic import BaseModel, Field |
|
|
from typing import Dict, Any, Optional, List |
|
|
import logging |
|
|
import os |
|
|
import asyncio |
|
|
from datetime import datetime |
|
|
import uuid |
|
|
import json |
|
|
|
|
|
from null_ai.coordinate_estimator import CoordinateEstimator |
|
|
from null_ai.iath_writer import IathWriter, create_tile_from_ai_output |
|
|
from backend.app.config import app_model_router |
|
|
|
|
|
router = APIRouter() |
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
|
|
|
class AIEnrichmentRequest(BaseModel):
    """Request body for starting (or dry-running) AI-driven enrichment.

    Field descriptions are user-facing Japanese strings and are left unchanged:
    domain_id = "domain ID", num_questions = "number of questions to generate",
    focus_areas = "focus areas", *_model_id = "model ID for prompt/answer generation".
    """
    # Target knowledge domain; required.
    domain_id: str = Field(..., description="ドメインID")
    # How many questions to generate; bounded 1-100, defaults to 10.
    num_questions: int = Field(10, ge=1, le=100, description="生成する質問数")
    # Optional topical areas to emphasise in question generation.
    focus_areas: Optional[List[str]] = Field(None, description="重点領域")
    # Optional explicit model choices; when None a "safe" model is auto-selected.
    prompt_model_id: Optional[str] = Field(None, description="プロンプト生成用モデルID")
    answer_model_id: Optional[str] = Field(None, description="回答生成用モデルID")
|
|
|
|
|
class WebEnrichmentRequest(BaseModel):
    """Request body for web-search-based enrichment (Brave Search backend)."""
    # Search query sent verbatim to the search API; required.
    query: str = Field(..., description="検索クエリ")
    # Domain under which resulting knowledge tiles are stored; required.
    domain_id: str = Field(..., description="ドメインID")
    # Cap on search results to convert into tiles; bounded 1-20, defaults to 5.
    max_results: int = Field(5, ge=1, le=20, description="最大検索結果数")
|
|
|
|
|
class EnrichmentStatusResponse(BaseModel):
    """Snapshot of the global enrichment run, mirroring EnrichmentState.to_dict()."""
    # True while a background enrichment task is active.
    is_running: bool
    # Completion percentage in [0.0, 100.0].
    progress: float
    # 1-based index of the question currently being answered (0 when idle).
    current_question: int
    # Total questions in the current run (0 when idle).
    total_questions: int
    # Number of knowledge tiles successfully persisted so far.
    generated_tiles: int
    # ISO-8601 start timestamp, or None when no run has started.
    start_time: Optional[str]
    # Domain of the current/last run, or None when idle.
    domain_id: Optional[str]
|
|
|
|
|
class EnrichmentResultResponse(BaseModel):
    """Outcome of an enrichment request; optional fields are populated per endpoint."""
    # Whether the request was accepted / completed successfully.
    success: bool
    # Set by question-generating flows; None otherwise.
    questions_generated: Optional[int] = None
    # Number of tiles written (web enrichment reports this synchronously).
    tiles_created: Optional[int] = None
    # Echo of the requested domain.
    domain_id: Optional[str] = None
    # Human-readable error description when success is False.
    error: Optional[str] = None
|
|
|
|
|
|
|
|
|
|
|
class EnrichmentState:
    """Tracks progress of the single, module-global enrichment run.

    Endpoints mutate this object directly; ``to_dict()`` produces the public
    status snapshot (``stop_requested`` is intentionally internal-only).
    """

    def __init__(self):
        # Delegate to reset() so construction and re-initialization cannot
        # drift apart (previously the field list was duplicated verbatim).
        self.reset()

    def reset(self):
        """Return every tracked field to its idle default."""
        self.is_running = False        # a background task is active
        self.progress = 0.0            # percentage complete, 0-100
        self.current_question = 0      # 1-based index of question in flight
        self.total_questions = 0       # size of the current question batch
        self.generated_tiles = 0       # tiles successfully persisted
        self.start_time = None         # ISO-8601 string once a run starts
        self.domain_id = None          # domain of the current run
        self.stop_requested = False    # cooperative cancellation flag

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a dictionary for API responses.

        Note: ``stop_requested`` is deliberately excluded from the public view.
        """
        return {
            "is_running": self.is_running,
            "progress": self.progress,
            "current_question": self.current_question,
            "total_questions": self.total_questions,
            "generated_tiles": self.generated_tiles,
            "start_time": self.start_time,
            "domain_id": self.domain_id
        }
|
|
|
|
|
|
|
|
# Module-level singleton: only one enrichment run is allowed at a time, and
# the start/status/stop endpoints below all share this instance.
enrichment_state = EnrichmentState()
|
|
|
|
|
|
|
|
|
|
|
def get_safe_model(model_id: Optional[str] = None):
    """Return a model config that is safe to run (i.e. not the MLX provider).

    Selection order:
      1. The explicitly requested ``model_id`` if it exists (any provider).
      2. The first GGUF model, then the first Ollama model, then the first
         HuggingFace model found in the registry.
      3. Any remaining non-MLX model as a last resort.

    Args:
        model_id: Optional explicit model identifier to prefer.

    Returns:
        A model config object, or ``None`` when nothing safe is available
        or the lookup itself fails.
    """
    try:
        # Honor an explicit request first and only fall back when unknown.
        # NOTE(review): an explicitly requested MLX model is returned as-is —
        # presumably intentional ("caller knows best"); confirm.
        if model_id:
            model = app_model_router.config_manager.get_model_config(model_id)
            if model:
                logger.info(f"Selected specified model: {model.model_id} (provider: {model.provider})")
                return model
            logger.warning(f"Requested model {model_id} not found, falling back to safe default")

        # Auto-select by provider preference. This replaces three copy-pasted
        # loops; order and log text are preserved exactly.
        for provider, label in (("gguf", "GGUF"), ("ollama", "Ollama"), ("huggingface", "HuggingFace")):
            for model in app_model_router.config_manager.models.values():
                if model.provider == provider:
                    logger.info(f"Auto-selected {label} model: {model.model_id}")
                    return model

        # Last resort: anything that is not MLX.
        for model in app_model_router.config_manager.models.values():
            if model.provider != "mlx":
                logger.warning(f"Using last resort model: {model.model_id} (provider: {model.provider})")
                return model

        logger.error("No safe model available")
        return None

    except Exception as e:
        logger.error(f"Error in get_safe_model: {e}", exc_info=True)
        return None
|
|
|
|
|
|
|
|
|
|
|
async def safe_llm_inference(model, prompt: str, temperature: float = 0.7) -> Dict[str, Any]:
    """Run LLM inference without ever raising.

    Errors are reported inside the returned dict: a missing model or an
    inference failure yields ``{"response": "", "error": ...}``; an empty
    model reply yields ``{"response": "", "confidence": 0.0}``.
    """
    # Guard clause: nothing to do without a model.
    if not model:
        logger.error("No model provided for inference")
        return {"response": "", "error": "No model available"}

    try:
        logger.debug(f"Performing inference with model: {model.model_id} (provider: {model.provider})")

        result = await app_model_router._perform_llm_inference(model, prompt, temperature=temperature)

        # A usable result must be truthy and carry a "response" key.
        if result and "response" in result:
            return result

        logger.warning(f"Empty or invalid response from model {model.model_id}")
        return {"response": "", "confidence": 0.0}

    except Exception as e:
        logger.error(f"Error during LLM inference: {e}", exc_info=True)
        return {"response": "", "error": str(e)}
|
|
|
|
|
|
|
|
|
|
|
async def generate_enrichment_questions(
    domain_id: str,
    num_questions: int,
    focus_areas: Optional[List[str]],
    model
) -> List[str]:
    """Ask *model* to produce up to ``num_questions`` enrichment questions.

    Returns a deduplicated list of question strings (possibly fewer than
    requested), or an empty list on any failure.
    """
    try:
        logger.info(f"Generating {num_questions} questions for domain '{domain_id}'")

        # Optional emphasis clause appended to the guidelines.
        focus_text = (
            f"\nFocus on these specific areas: {', '.join(focus_areas)}"
            if focus_areas
            else ""
        )

        prompt = f"""You are an expert knowledge curator in the domain of {domain_id}.

Your task is to generate {num_questions} diverse, high-quality questions that would enrich a knowledge base in this domain.

Guidelines:
1. Cover a wide range of topics within the domain
2. Include questions at different complexity levels (basic, intermediate, advanced)
3. Focus on practical applications and edge cases
4. Avoid overly broad or trivial questions
5. Each question should elicit detailed, informative answers{focus_text}

Output format:
Return ONLY a JSON array of questions, like this:
["Question 1?", "Question 2?", ...]

Domain: {domain_id}
Number of questions: {num_questions}

Generate the questions now:"""

        llm_result = await safe_llm_inference(model, prompt, temperature=0.7)
        response = llm_result.get("response", "")

        if not response:
            logger.error("Empty response from LLM for question generation")
            return []

        logger.debug(f"Received response (first 200 chars): {response[:200]}")

        # Parse, then deduplicate while preserving first-seen order.
        questions = extract_questions_from_response(response)
        questions = list(dict.fromkeys(questions))

        logger.info(f"Successfully generated {len(questions)} unique questions")
        return questions[:num_questions]

    except Exception as e:
        logger.error(f"Failed to generate questions: {e}", exc_info=True)
        return []
|
|
|
|
|
def extract_questions_from_response(response: str) -> List[str]:
    """Extract a list of question strings from an LLM response.

    Tries, in order: a ```json fenced block, any ``` fenced block, the
    outermost ``[...]`` span, then the raw text. Falls back to line-based
    text parsing when JSON decoding fails. Questions shorter than 6
    characters are dropped.

    Returns:
        List of question strings (possibly empty).
    """
    if not response or not response.strip():
        logger.error("Empty response provided")
        return []

    try:
        json_str = ""

        if "```json" in response:
            json_start = response.find("```json") + 7
            json_end = response.find("```", json_start)
            # BUGFIX: an unclosed fence made find() return -1, which silently
            # truncated the last character of the slice. Treat it as EOF.
            if json_end == -1:
                json_end = len(response)
            json_str = response[json_start:json_end].strip()
        elif "```" in response:
            json_start = response.find("```") + 3
            json_end = response.find("```", json_start)
            if json_end == -1:  # same unclosed-fence fix as above
                json_end = len(response)
            json_str = response[json_start:json_end].strip()
        elif "[" in response and "]" in response:
            # No fences: take the outermost bracketed span.
            json_start = response.find("[")
            json_end = response.rfind("]") + 1
            json_str = response[json_start:json_end]
        else:
            json_str = response.strip()

        if not json_str:
            logger.warning("Could not extract JSON, trying fallback parsing")
            return fallback_parse_questions(response)

        questions = json.loads(json_str)

        if isinstance(questions, list):
            # Keep only plausible question strings.
            return [q for q in questions if isinstance(q, str) and len(q) > 5]
        else:
            logger.warning(f"Expected list, got {type(questions)}")
            return fallback_parse_questions(response)

    except json.JSONDecodeError as e:
        logger.warning(f"JSON parse error: {e}, trying fallback parsing")
        return fallback_parse_questions(response)
|
|
|
|
|
def fallback_parse_questions(response: str) -> List[str]:
    """Line-by-line fallback parser: keep lines that look like questions.

    A line qualifies when it is longer than 5 characters and contains a
    question mark; list markers and numbering are stripped from the front.
    """
    logger.info("Using fallback text parsing for questions")
    questions = []

    for raw_line in response.split('\n'):
        candidate = raw_line.strip()
        # Skip anything too short or without a question mark.
        if len(candidate) <= 5 or '?' not in candidate:
            continue
        # Drop leading enumeration/bullet characters ("1.", "-", "*", ...).
        cleaned = candidate.lstrip('0123456789.-* ').strip()
        if cleaned:
            questions.append(cleaned)

    logger.info(f"Fallback parsing extracted {len(questions)} questions")
    return questions
|
|
|
|
|
|
|
|
|
|
|
async def web_search_brave(api_key: str, query: str, max_results: int) -> List[Dict[str, str]]:
    """Query the Brave Search web endpoint.

    Returns a list of ``{"title", "url", "snippet"}`` dicts, or an empty
    list on any HTTP or network failure (errors are logged, not raised).
    """
    try:
        # aiohttp is imported lazily so the module loads without it installed.
        import aiohttp

        endpoint = "https://api.search.brave.com/res/v1/web/search"
        request_headers = {
            "Accept": "application/json",
            "X-Subscription-Token": api_key
        }
        query_params = {
            "q": query,
            "count": max_results
        }

        async with aiohttp.ClientSession() as session:
            async with session.get(endpoint, headers=request_headers, params=query_params, timeout=30) as response:
                if response.status != 200:
                    logger.error(f"Brave Search API error: {response.status}")
                    return []

                data = await response.json()
                # Map Brave's schema onto our minimal result shape.
                return [
                    {
                        "title": item.get("title", ""),
                        "url": item.get("url", ""),
                        "snippet": item.get("description", "")
                    }
                    for item in data.get("web", {}).get("results", [])
                ]

    except Exception as e:
        logger.error(f"Brave Search failed: {e}", exc_info=True)
        return []
|
|
|
|
|
|
|
|
|
|
|
@router.post("/ai/start", response_model=EnrichmentResultResponse)
async def start_ai_enrichment(request: AIEnrichmentRequest, background_tasks: BackgroundTasks):
    """Kick off the AI-driven enrichment pipeline as a background task.

    Raises 409 when a run is already active and 400 when no safe model
    can be selected; the actual work happens in run_ai_enrichment_task.
    """
    try:
        # Only one enrichment run may be active at a time.
        if enrichment_state.is_running:
            raise HTTPException(status_code=409, detail="Enrichment is already running")

        prompt_model = get_safe_model(request.prompt_model_id)
        if not prompt_model:
            raise HTTPException(status_code=400, detail="No safe model available for prompt generation")

        # Reuse the prompt model when no dedicated answer model is available.
        answer_model = get_safe_model(request.answer_model_id) or prompt_model

        logger.info(f"Starting enrichment with prompt_model={prompt_model.model_id}, answer_model={answer_model.model_id}")

        background_tasks.add_task(
            run_ai_enrichment_task,
            request.domain_id,
            request.num_questions,
            request.focus_areas,
            prompt_model,
            answer_model
        )

        # Accepted: progress is reported via the /status endpoint.
        return EnrichmentResultResponse(success=True, domain_id=request.domain_id)

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Failed to start AI enrichment: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
def _parse_coordinates(coord_response: str) -> List[float]:
    """Extract a 6-element [x, y, z, c, g, v] array from an LLM reply.

    Falls back to neutral defaults when the reply is empty, contains no
    bracketed span, or cannot be parsed as a 6-element JSON list.
    """
    # Neutral defaults: mid-space position with moderate certainty.
    coordinates = [0.5, 0.5, 0.5, 0.7, 0.6, 0.8]
    if coord_response:
        try:
            if "[" in coord_response and "]" in coord_response:
                start = coord_response.find("[")
                end = coord_response.rfind("]") + 1
                parsed_coords = json.loads(coord_response[start:end])
                if isinstance(parsed_coords, list) and len(parsed_coords) == 6:
                    coordinates = parsed_coords
        except Exception as e:
            logger.warning(f"Could not parse coordinates: {e}")
    return coordinates


async def run_ai_enrichment_task(
    domain_id: str,
    num_questions: int,
    focus_areas: Optional[List[str]],
    prompt_model,
    answer_model
):
    """Background task for AI enrichment.

    Pipeline: generate questions with *prompt_model*, answer each with
    *answer_model*, estimate 6D coordinates, and persist a knowledge tile
    per answered question. Progress is published through the module-level
    ``enrichment_state``; the loop honours its ``stop_requested`` flag.
    Never raises — failures are logged and the running flag is cleared.
    """
    try:
        enrichment_state.reset()
        enrichment_state.is_running = True
        enrichment_state.domain_id = domain_id
        enrichment_state.total_questions = num_questions
        enrichment_state.start_time = datetime.utcnow().isoformat()

        logger.info("Step 1: Generating questions...")
        questions = await generate_enrichment_questions(
            domain_id, num_questions, focus_areas, prompt_model
        )

        if not questions:
            logger.error("No questions generated")
            enrichment_state.is_running = False
            return

        # The model may have produced fewer questions than requested.
        enrichment_state.total_questions = len(questions)

        logger.info(f"Step 2: Generating answers for {len(questions)} questions...")

        for i, question in enumerate(questions):
            # Cooperative cancellation via the /stop endpoint.
            if enrichment_state.stop_requested:
                logger.info("Enrichment stop requested, aborting")
                break

            logger.info(f"Processing question {i+1}/{len(questions)}: {question[:50]}...")

            enrichment_state.current_question = i + 1
            enrichment_state.progress = (i / len(questions)) * 100

            try:
                # Answer the question with the (cheaper) answer model.
                answer_result = await safe_llm_inference(answer_model, question, temperature=0.5)
                answer = answer_result.get("response", "")

                if not answer:
                    logger.warning(f"Empty answer for question: {question}")
                    continue

                confidence = answer_result.get("confidence", 0.7)

                # Ask the prompt model for a coordinate estimate; the answer
                # is truncated to 500 chars to keep the prompt small.
                coord_prompt = f"""Given this knowledge:
Question: {question}
Answer: {answer[:500]}

Estimate 6D coordinates [x, y, z, c, g, v] where:
- x, y, z: domain-specific 3D space (0.0-1.0 each)
- c: certainty (0.0-1.0)
- g: granularity (0.0-1.0)
- v: verification level (0.0-1.0)

Return only the coordinates as JSON array: [x, y, z, c, g, v]"""

                coord_result = await safe_llm_inference(prompt_model, coord_prompt, temperature=0.3)
                coordinates = _parse_coordinates(coord_result.get("response", ""))

                tile = create_tile_from_ai_output(
                    knowledge_id=f"ai_enrich_{uuid.uuid4().hex}",
                    topic=question[:100],
                    prompt=question,
                    response=answer,
                    coordinates=coordinates,
                    confidence=confidence,
                    domain_id=domain_id,
                    source="ai_enrichment"
                )

                success = app_model_router.iath_writer.append_tile(tile)

                if success:
                    enrichment_state.generated_tiles += 1
                    logger.info(f"Saved tile {enrichment_state.generated_tiles}: {question[:50]}...")
                else:
                    logger.warning(f"Failed to save tile for: {question}")

            except Exception as e:
                # One bad question must not abort the whole run.
                logger.error(f"Error processing question: {e}", exc_info=True)
                continue

        enrichment_state.progress = 100.0
        enrichment_state.current_question = len(questions)
        logger.info(f"Enrichment completed: {enrichment_state.generated_tiles} tiles created")

    except Exception as e:
        logger.error(f"AI enrichment task failed: {e}", exc_info=True)
    finally:
        # Always release the "running" lock, even on failure.
        enrichment_state.is_running = False
|
|
|
|
|
@router.post("/web/start", response_model=EnrichmentResultResponse)
async def start_web_enrichment(request: WebEnrichmentRequest):
    """Run web-based enrichment synchronously using Brave Search.

    Each search hit becomes one knowledge tile. Raises 400 when the API
    key is missing, 404 when the search returns nothing.
    """
    try:
        brave_api_key = os.getenv("BRAVE_SEARCH_API_KEY")
        if not brave_api_key:
            raise HTTPException(status_code=400, detail="Brave Search API key not configured")

        search_hits = await web_search_brave(brave_api_key, request.query, request.max_results)
        if not search_hits:
            raise HTTPException(status_code=404, detail="No search results found")

        tiles_created = 0

        for hit in search_hits:
            try:
                tile_content = f"""Title: {hit['title']}

Source: {hit['url']}

Summary:
{hit['snippet']}"""

                tile = create_tile_from_ai_output(
                    knowledge_id=f"web_{uuid.uuid4().hex}",
                    topic=hit['title'],
                    prompt=request.query,
                    response=tile_content,
                    # Fixed neutral coordinates for web-sourced knowledge.
                    coordinates=[0.5, 0.5, 0.5, 0.6, 0.5, 0.7],
                    confidence=0.6,
                    domain_id=request.domain_id,
                    source="web_search"
                )

                if app_model_router.iath_writer.append_tile(tile):
                    tiles_created += 1

            except Exception as e:
                # A single bad result must not abort the batch.
                logger.error(f"Error creating tile from search result: {e}")
                continue

        return EnrichmentResultResponse(
            success=True,
            tiles_created=tiles_created,
            domain_id=request.domain_id
        )

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Web enrichment failed: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
@router.get("/status", response_model=EnrichmentStatusResponse)
async def get_enrichment_status():
    """Report the current enrichment progress.

    On any internal failure an idle/zeroed status is returned rather than
    an error, so status polling never breaks the UI.
    """
    try:
        return EnrichmentStatusResponse(**enrichment_state.to_dict())
    except Exception as e:
        logger.error(f"Failed to get status: {e}", exc_info=True)

        # Degrade gracefully to an explicit idle snapshot.
        idle_status = {
            "is_running": False,
            "progress": 0.0,
            "current_question": 0,
            "total_questions": 0,
            "generated_tiles": 0,
            "start_time": None,
            "domain_id": None
        }
        return EnrichmentStatusResponse(**idle_status)
|
|
|
|
|
@router.post("/stop")
async def stop_enrichment():
    """Request cooperative cancellation of the running enrichment task.

    The background loop checks the flag between questions, so the stop
    takes effect after the current question finishes. Raises 400 when
    nothing is running.
    """
    if enrichment_state.is_running:
        enrichment_state.stop_requested = True
        logger.info("Enrichment stop requested")
        return {"success": True, "message": "Enrichment stop requested"}

    raise HTTPException(status_code=400, detail="No enrichment is currently running")
|
|
|
|
|
@router.post("/ai/generate-questions", response_model=Dict[str, Any])
async def generate_questions_only(request: AIEnrichmentRequest):
    """Generate enrichment questions without producing answers.

    Returns the question list plus its count and the echoed domain.
    Raises 400 when no safe model is available, 500 on unexpected errors.
    """
    try:
        model = get_safe_model(request.prompt_model_id)
        if not model:
            raise HTTPException(status_code=400, detail="No safe model available")

        questions = await generate_enrichment_questions(
            request.domain_id,
            request.num_questions,
            request.focus_areas,
            model
        )

        return {
            "questions": questions,
            "domain_id": request.domain_id,
            "count": len(questions)
        }

    except HTTPException:
        # BUGFIX: previously the 400 above fell into the generic handler and
        # was re-raised as a 500; pass HTTP errors through unchanged, matching
        # start_ai_enrichment's handling.
        raise
    except Exception as e:
        logger.error(f"Failed to generate questions: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
@router.get("/config")
async def get_enrichment_config():
    """Report which enrichment modes are available and with what resources.

    AI enrichment availability is derived from the presence of non-MLX
    models; web enrichment from the Brave Search API key env var.
    """
    brave_api_key = os.getenv("BRAVE_SEARCH_API_KEY")

    # Every registered model except the MLX provider is considered safe.
    safe_models = [
        {
            "id": model.model_id,
            "name": model.model_name,
            "provider": model.provider
        }
        for model in app_model_router.config_manager.models.values()
        if model.provider != "mlx"
    ]

    master = app_model_router.master_model
    master_model_id = master.model_id if master else None

    return {
        "ai_enrichment": {
            "available": len(safe_models) > 0,
            "available_models": safe_models,
            "master_model_id": master_model_id
        },
        "web_enrichment": {
            "available": bool(brave_api_key),
            "search_engine": "brave",
            "api_key_configured": bool(brave_api_key)
        }
    }
|
|
|