Enhance provider management: normalize response keys, add availability checks, and improve Gradio interface for AI providers
Browse files- app.py +25 -10
- interface.py +35 -4
- main.py +35 -0
- test_providers.py +45 -0
app.py
CHANGED
|
@@ -12,8 +12,12 @@ os.environ['GRADIO_SERVER_PORT'] = '7860'
|
|
| 12 |
# Avoid uvloop shutdown warnings on HF Spaces
|
| 13 |
os.environ.setdefault('UVICORN_LOOP', 'asyncio')
|
| 14 |
|
| 15 |
-
|
| 16 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
# Add project root to Python path
|
| 19 |
project_root = Path(__file__).parent
|
|
@@ -93,14 +97,25 @@ if __name__ == "__main__":
|
|
| 93 |
# Run diagnostics only when executed directly
|
| 94 |
run_network_diagnostics()
|
| 95 |
|
| 96 |
-
print("🚀 Starting Legal Position AI Analyzer
|
|
|
|
|
|
|
|
|
|
| 97 |
|
| 98 |
# Must call launch() explicitly — Gradio 6 does not auto-launch.
|
| 99 |
# ssr_mode=False avoids the "shareable link" error on HF Spaces containers.
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
# Avoid uvloop shutdown warnings on HF Spaces
|
| 13 |
os.environ.setdefault('UVICORN_LOOP', 'asyncio')
|
| 14 |
|
| 15 |
+
# Apply nest_asyncio only if needed (some Python versions have conflicts)
|
| 16 |
+
# try:
|
| 17 |
+
# import nest_asyncio
|
| 18 |
+
# nest_asyncio.apply()
|
| 19 |
+
# except Exception as e:
|
| 20 |
+
# print(f"[WARNING] Could not apply nest_asyncio: {e}")
|
| 21 |
|
| 22 |
# Add project root to Python path
|
| 23 |
project_root = Path(__file__).parent
|
|
|
|
| 97 |
# Run diagnostics only when executed directly
|
| 98 |
run_network_diagnostics()
|
| 99 |
|
| 100 |
+
print("🚀 Starting Legal Position AI Analyzer...")
|
| 101 |
+
|
| 102 |
+
# Detect if running on HF Spaces or locally
|
| 103 |
+
is_hf_space = os.environ.get('SPACE_ID') is not None
|
| 104 |
|
| 105 |
# Must call launch() explicitly — Gradio 6 does not auto-launch.
|
| 106 |
# ssr_mode=False avoids the "shareable link" error on HF Spaces containers.
|
| 107 |
+
if is_hf_space:
|
| 108 |
+
# On HF Spaces, use fixed port 7860
|
| 109 |
+
demo.launch(
|
| 110 |
+
server_name="0.0.0.0",
|
| 111 |
+
server_port=7860,
|
| 112 |
+
share=False,
|
| 113 |
+
show_error=True,
|
| 114 |
+
ssr_mode=False,
|
| 115 |
+
)
|
| 116 |
+
else:
|
| 117 |
+
# Locally, let Gradio find an available port
|
| 118 |
+
demo.launch(
|
| 119 |
+
share=False,
|
| 120 |
+
show_error=True,
|
| 121 |
+
)
|
interface.py
CHANGED
|
@@ -14,7 +14,8 @@ from main import (
|
|
| 14 |
generate_legal_position,
|
| 15 |
search_with_ai_action,
|
| 16 |
analyze_action,
|
| 17 |
-
search_with_raw_text
|
|
|
|
| 18 |
)
|
| 19 |
from prompts import SYSTEM_PROMPT, LEGAL_POSITION_PROMPT, PRECEDENT_ANALYSIS_TEMPLATE
|
| 20 |
from src.session.manager import get_session_manager
|
|
@@ -32,6 +33,12 @@ def load_help_content() -> str:
|
|
| 32 |
return f"Помилка завантаження довідки: {str(e)}"
|
| 33 |
|
| 34 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
def update_generation_model_choices(provider: str) -> gr.Dropdown:
|
| 36 |
"""Update generation model choices based on provider selection."""
|
| 37 |
if provider == ModelProvider.OPENAI.value:
|
|
@@ -466,6 +473,18 @@ def create_gradio_interface() -> gr.Blocks:
|
|
| 466 |
except Exception:
|
| 467 |
_default_provider = "anthropic"
|
| 468 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 469 |
# Get default generation model for the provider
|
| 470 |
_gen_models = get_generation_models_by_provider(_default_provider)
|
| 471 |
if DEFAULT_GENERATION_MODEL and DEFAULT_GENERATION_MODEL.value in _gen_models:
|
|
@@ -538,6 +557,18 @@ def create_gradio_interface() -> gr.Blocks:
|
|
| 538 |
</div>
|
| 539 |
"""
|
| 540 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 541 |
|
| 542 |
# Session state - generates unique ID for each browser session
|
| 543 |
session_id_state = gr.State(value=generate_session_id)
|
|
@@ -563,7 +594,7 @@ def create_gradio_interface() -> gr.Blocks:
|
|
| 563 |
gr.Markdown("### 🤖 Налаштування моделі")
|
| 564 |
with gr.Row():
|
| 565 |
generation_provider_dropdown = gr.Dropdown(
|
| 566 |
-
choices=
|
| 567 |
value=_default_provider,
|
| 568 |
label="Провайдер AI",
|
| 569 |
container=False,
|
|
@@ -680,7 +711,7 @@ def create_gradio_interface() -> gr.Blocks:
|
|
| 680 |
|
| 681 |
with gr.Row():
|
| 682 |
analysis_provider_dropdown = gr.Dropdown(
|
| 683 |
-
choices=
|
| 684 |
value=_default_provider,
|
| 685 |
label="Провайдер AI",
|
| 686 |
scale=1
|
|
@@ -781,7 +812,7 @@ def create_gradio_interface() -> gr.Blocks:
|
|
| 781 |
|
| 782 |
with gr.Row():
|
| 783 |
batch_provider_dropdown = gr.Dropdown(
|
| 784 |
-
choices=
|
| 785 |
value=_default_provider,
|
| 786 |
label="Провайдер AI",
|
| 787 |
scale=1
|
|
|
|
| 14 |
generate_legal_position,
|
| 15 |
search_with_ai_action,
|
| 16 |
analyze_action,
|
| 17 |
+
search_with_raw_text,
|
| 18 |
+
get_available_providers
|
| 19 |
)
|
| 20 |
from prompts import SYSTEM_PROMPT, LEGAL_POSITION_PROMPT, PRECEDENT_ANALYSIS_TEMPLATE
|
| 21 |
from src.session.manager import get_session_manager
|
|
|
|
| 33 |
return f"Помилка завантаження довідки: {str(e)}"
|
| 34 |
|
| 35 |
|
| 36 |
+
def get_available_provider_choices() -> list:
|
| 37 |
+
"""Get list of available AI providers based on API key availability."""
|
| 38 |
+
available = get_available_providers()
|
| 39 |
+
return [p.value for p in ModelProvider if available.get(p.value, False)]
|
| 40 |
+
|
| 41 |
+
|
| 42 |
def update_generation_model_choices(provider: str) -> gr.Dropdown:
|
| 43 |
"""Update generation model choices based on provider selection."""
|
| 44 |
if provider == ModelProvider.OPENAI.value:
|
|
|
|
| 473 |
except Exception:
|
| 474 |
_default_provider = "anthropic"
|
| 475 |
|
| 476 |
+
# Get available providers based on API key availability
|
| 477 |
+
_available_providers = get_available_provider_choices()
|
| 478 |
+
|
| 479 |
+
# If default provider is not available, use first available one
|
| 480 |
+
if _default_provider not in _available_providers:
|
| 481 |
+
if _available_providers:
|
| 482 |
+
_default_provider = _available_providers[0]
|
| 483 |
+
print(f"[WARNING] Default provider not available, using: {_default_provider}")
|
| 484 |
+
else:
|
| 485 |
+
print("[ERROR] No AI providers available! Please set at least one API key.")
|
| 486 |
+
_default_provider = "anthropic" # Fallback for UI rendering
|
| 487 |
+
|
| 488 |
# Get default generation model for the provider
|
| 489 |
_gen_models = get_generation_models_by_provider(_default_provider)
|
| 490 |
if DEFAULT_GENERATION_MODEL and DEFAULT_GENERATION_MODEL.value in _gen_models:
|
|
|
|
| 557 |
</div>
|
| 558 |
"""
|
| 559 |
)
|
| 560 |
+
|
| 561 |
+
# Show provider availability status
|
| 562 |
+
_all_providers = {p.value for p in ModelProvider}
|
| 563 |
+
_unavailable = _all_providers - set(_available_providers)
|
| 564 |
+
if _unavailable:
|
| 565 |
+
unavailable_list = ", ".join(sorted(_unavailable))
|
| 566 |
+
gr.Info(
|
| 567 |
+
f"⚠️ Недоступні провайдери (відсутні API ключі): {unavailable_list}\n"
|
| 568 |
+
f"Додайте відповідні API ключі в налаштуваннях HF Space для активації.",
|
| 569 |
+
title="Інформація про провайдери",
|
| 570 |
+
duration=10
|
| 571 |
+
)
|
| 572 |
|
| 573 |
# Session state - generates unique ID for each browser session
|
| 574 |
session_id_state = gr.State(value=generate_session_id)
|
|
|
|
| 594 |
gr.Markdown("### 🤖 Налаштування моделі")
|
| 595 |
with gr.Row():
|
| 596 |
generation_provider_dropdown = gr.Dropdown(
|
| 597 |
+
choices=_available_providers,
|
| 598 |
value=_default_provider,
|
| 599 |
label="Провайдер AI",
|
| 600 |
container=False,
|
|
|
|
| 711 |
|
| 712 |
with gr.Row():
|
| 713 |
analysis_provider_dropdown = gr.Dropdown(
|
| 714 |
+
choices=_available_providers,
|
| 715 |
value=_default_provider,
|
| 716 |
label="Провайдер AI",
|
| 717 |
scale=1
|
|
|
|
| 812 |
|
| 813 |
with gr.Row():
|
| 814 |
batch_provider_dropdown = gr.Dropdown(
|
| 815 |
+
choices=_available_providers,
|
| 816 |
value=_default_provider,
|
| 817 |
label="Провайдер AI",
|
| 818 |
scale=1
|
main.py
CHANGED
|
@@ -236,6 +236,31 @@ def check_provider_available(provider: str) -> Tuple[bool, str]:
|
|
| 236 |
return True, ""
|
| 237 |
|
| 238 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 239 |
class RetrieverEvent(Event):
|
| 240 |
"""Event class for retriever operations."""
|
| 241 |
nodes: list[NodeWithScore]
|
|
@@ -716,6 +741,9 @@ def generate_legal_position(
|
|
| 716 |
print(f"[DEBUG] OpenAI response length: {len(response_text) if response_text else 0}")
|
| 717 |
|
| 718 |
json_response = extract_json_from_text(response_text)
|
|
|
|
|
|
|
|
|
|
| 719 |
if json_response and all(key in json_response for key in ["title", "text", "proceeding", "category"]):
|
| 720 |
return json_response
|
| 721 |
else:
|
|
@@ -784,6 +812,9 @@ def generate_legal_position(
|
|
| 784 |
print(f"[DEBUG] DeepSeek response length: {len(response_text) if response_text else 0}")
|
| 785 |
|
| 786 |
json_response = extract_json_from_text(response_text)
|
|
|
|
|
|
|
|
|
|
| 787 |
if json_response and all(key in json_response for key in ["title", "text", "proceeding", "category"]):
|
| 788 |
return json_response
|
| 789 |
else:
|
|
@@ -872,6 +903,8 @@ def generate_legal_position(
|
|
| 872 |
json_response = extract_json_from_text(response_text)
|
| 873 |
|
| 874 |
if json_response:
|
|
|
|
|
|
|
| 875 |
# Validate required fields
|
| 876 |
required = ["title", "text", "proceeding", "category"]
|
| 877 |
missing = [f for f in required if f not in json_response]
|
|
@@ -959,6 +992,8 @@ def generate_legal_position(
|
|
| 959 |
json_response = extract_json_from_text(response_text)
|
| 960 |
|
| 961 |
if json_response:
|
|
|
|
|
|
|
| 962 |
# Перевіряємо наявність всіх необхідних полів
|
| 963 |
required_fields = ["title", "text", "proceeding", "category"]
|
| 964 |
if all(field in json_response for field in required_fields):
|
|
|
|
| 236 |
return True, ""
|
| 237 |
|
| 238 |
|
| 239 |
+
def normalize_response_keys(response_dict: Dict[str, Any]) -> Dict[str, Any]:
|
| 240 |
+
"""
|
| 241 |
+
Normalize keys in the response dictionary to match the expected format.
|
| 242 |
+
Handles variations like 'text_lp' -> 'text' and 'proceeding_type' -> 'proceeding'.
|
| 243 |
+
"""
|
| 244 |
+
if not response_dict:
|
| 245 |
+
return response_dict
|
| 246 |
+
|
| 247 |
+
# Map common variations to standard keys
|
| 248 |
+
key_mapping = {
|
| 249 |
+
"text_lp": "text",
|
| 250 |
+
"legal_position_text": "text",
|
| 251 |
+
"lp_text": "text",
|
| 252 |
+
"proceeding_type": "proceeding",
|
| 253 |
+
"type_of_proceeding": "proceeding"
|
| 254 |
+
}
|
| 255 |
+
|
| 256 |
+
normalized = response_dict.copy()
|
| 257 |
+
for variant, standard in key_mapping.items():
|
| 258 |
+
if variant in normalized and standard not in normalized:
|
| 259 |
+
normalized[standard] = normalized.pop(variant)
|
| 260 |
+
|
| 261 |
+
return normalized
|
| 262 |
+
|
| 263 |
+
|
| 264 |
class RetrieverEvent(Event):
|
| 265 |
"""Event class for retriever operations."""
|
| 266 |
nodes: list[NodeWithScore]
|
|
|
|
| 741 |
print(f"[DEBUG] OpenAI response length: {len(response_text) if response_text else 0}")
|
| 742 |
|
| 743 |
json_response = extract_json_from_text(response_text)
|
| 744 |
+
if json_response:
|
| 745 |
+
json_response = normalize_response_keys(json_response)
|
| 746 |
+
|
| 747 |
if json_response and all(key in json_response for key in ["title", "text", "proceeding", "category"]):
|
| 748 |
return json_response
|
| 749 |
else:
|
|
|
|
| 812 |
print(f"[DEBUG] DeepSeek response length: {len(response_text) if response_text else 0}")
|
| 813 |
|
| 814 |
json_response = extract_json_from_text(response_text)
|
| 815 |
+
if json_response:
|
| 816 |
+
json_response = normalize_response_keys(json_response)
|
| 817 |
+
|
| 818 |
if json_response and all(key in json_response for key in ["title", "text", "proceeding", "category"]):
|
| 819 |
return json_response
|
| 820 |
else:
|
|
|
|
| 903 |
json_response = extract_json_from_text(response_text)
|
| 904 |
|
| 905 |
if json_response:
|
| 906 |
+
json_response = normalize_response_keys(json_response)
|
| 907 |
+
|
| 908 |
# Validate required fields
|
| 909 |
required = ["title", "text", "proceeding", "category"]
|
| 910 |
missing = [f for f in required if f not in json_response]
|
|
|
|
| 992 |
json_response = extract_json_from_text(response_text)
|
| 993 |
|
| 994 |
if json_response:
|
| 995 |
+
json_response = normalize_response_keys(json_response)
|
| 996 |
+
|
| 997 |
# Перевіряємо наявність всіх необхідних полів
|
| 998 |
required_fields = ["title", "text", "proceeding", "category"]
|
| 999 |
if all(field in json_response for field in required_fields):
|
test_providers.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Test script to check available AI providers
|
| 4 |
+
"""
|
| 5 |
+
import os
|
| 6 |
+
from dotenv import load_dotenv
|
| 7 |
+
|
| 8 |
+
# Load environment variables
|
| 9 |
+
load_dotenv()
|
| 10 |
+
|
| 11 |
+
def get_available_providers():
|
| 12 |
+
"""Get status of all AI providers."""
|
| 13 |
+
return {
|
| 14 |
+
"openai": bool(os.getenv("OPENAI_API_KEY")),
|
| 15 |
+
"anthropic": bool(os.getenv("ANTHROPIC_API_KEY")),
|
| 16 |
+
"gemini": bool(os.getenv("GEMINI_API_KEY")),
|
| 17 |
+
"deepseek": bool(os.getenv("DEEPSEEK_API_KEY"))
|
| 18 |
+
}
|
| 19 |
+
|
| 20 |
+
if __name__ == "__main__":
|
| 21 |
+
providers = get_available_providers()
|
| 22 |
+
|
| 23 |
+
print("=" * 50)
|
| 24 |
+
print("🔑 API Keys Status")
|
| 25 |
+
print("=" * 50)
|
| 26 |
+
|
| 27 |
+
for provider, available in providers.items():
|
| 28 |
+
status = "✅ Available" if available else "❌ Missing"
|
| 29 |
+
print(f"{provider.upper():12} : {status}")
|
| 30 |
+
|
| 31 |
+
print("=" * 50)
|
| 32 |
+
|
| 33 |
+
available_list = [p for p, avail in providers.items() if avail]
|
| 34 |
+
unavailable_list = [p for p, avail in providers.items() if not avail]
|
| 35 |
+
|
| 36 |
+
print(f"\n✅ Available providers: {', '.join(available_list) if available_list else 'None'}")
|
| 37 |
+
print(f"❌ Unavailable providers: {', '.join(unavailable_list) if unavailable_list else 'None'}")
|
| 38 |
+
|
| 39 |
+
if not available_list:
|
| 40 |
+
print("\n⚠️ WARNING: No AI providers available!")
|
| 41 |
+
print("Please set at least one API key in your .env file:")
|
| 42 |
+
print(" - OPENAI_API_KEY")
|
| 43 |
+
print(" - ANTHROPIC_API_KEY")
|
| 44 |
+
print(" - GEMINI_API_KEY")
|
| 45 |
+
print(" - DEEPSEEK_API_KEY")
|