DocUA committed on
Commit
a09e8cb
·
1 Parent(s): 5c25abd

Enhance provider management: normalize response keys, add availability checks, and improve Gradio interface for AI providers

Browse files
Files changed (4) hide show
  1. app.py +25 -10
  2. interface.py +35 -4
  3. main.py +35 -0
  4. test_providers.py +45 -0
app.py CHANGED
@@ -12,8 +12,12 @@ os.environ['GRADIO_SERVER_PORT'] = '7860'
12
  # Avoid uvloop shutdown warnings on HF Spaces
13
  os.environ.setdefault('UVICORN_LOOP', 'asyncio')
14
 
15
- import nest_asyncio
16
- nest_asyncio.apply()
 
 
 
 
17
 
18
  # Add project root to Python path
19
  project_root = Path(__file__).parent
@@ -93,14 +97,25 @@ if __name__ == "__main__":
93
  # Run diagnostics only when executed directly
94
  run_network_diagnostics()
95
 
96
- print("🚀 Starting Legal Position AI Analyzer on Hugging Face Spaces...")
 
 
 
97
 
98
  # Must call launch() explicitly — Gradio 6 does not auto-launch.
99
  # ssr_mode=False avoids the "shareable link" error on HF Spaces containers.
100
- demo.launch(
101
- server_name="0.0.0.0",
102
- server_port=7860,
103
- share=False,
104
- show_error=True,
105
- ssr_mode=False,
106
- )
 
 
 
 
 
 
 
 
 
12
  # Avoid uvloop shutdown warnings on HF Spaces
13
  os.environ.setdefault('UVICORN_LOOP', 'asyncio')
14
 
15
+ # Apply nest_asyncio only if needed (some Python versions have conflicts)
16
+ # try:
17
+ # import nest_asyncio
18
+ # nest_asyncio.apply()
19
+ # except Exception as e:
20
+ # print(f"[WARNING] Could not apply nest_asyncio: {e}")
21
 
22
  # Add project root to Python path
23
  project_root = Path(__file__).parent
 
97
  # Run diagnostics only when executed directly
98
  run_network_diagnostics()
99
 
100
+ print("🚀 Starting Legal Position AI Analyzer...")
101
+
102
+ # Detect if running on HF Spaces or locally
103
+ is_hf_space = os.environ.get('SPACE_ID') is not None
104
 
105
  # Must call launch() explicitly — Gradio 6 does not auto-launch.
106
  # ssr_mode=False avoids the "shareable link" error on HF Spaces containers.
107
+ if is_hf_space:
108
+ # On HF Spaces, use fixed port 7860
109
+ demo.launch(
110
+ server_name="0.0.0.0",
111
+ server_port=7860,
112
+ share=False,
113
+ show_error=True,
114
+ ssr_mode=False,
115
+ )
116
+ else:
117
+ # Locally, let Gradio find an available port
118
+ demo.launch(
119
+ share=False,
120
+ show_error=True,
121
+ )
interface.py CHANGED
@@ -14,7 +14,8 @@ from main import (
14
  generate_legal_position,
15
  search_with_ai_action,
16
  analyze_action,
17
- search_with_raw_text
 
18
  )
19
  from prompts import SYSTEM_PROMPT, LEGAL_POSITION_PROMPT, PRECEDENT_ANALYSIS_TEMPLATE
20
  from src.session.manager import get_session_manager
@@ -32,6 +33,12 @@ def load_help_content() -> str:
32
  return f"Помилка завантаження довідки: {str(e)}"
33
 
34
 
 
 
 
 
 
 
35
  def update_generation_model_choices(provider: str) -> gr.Dropdown:
36
  """Update generation model choices based on provider selection."""
37
  if provider == ModelProvider.OPENAI.value:
@@ -466,6 +473,18 @@ def create_gradio_interface() -> gr.Blocks:
466
  except Exception:
467
  _default_provider = "anthropic"
468
 
 
 
 
 
 
 
 
 
 
 
 
 
469
  # Get default generation model for the provider
470
  _gen_models = get_generation_models_by_provider(_default_provider)
471
  if DEFAULT_GENERATION_MODEL and DEFAULT_GENERATION_MODEL.value in _gen_models:
@@ -538,6 +557,18 @@ def create_gradio_interface() -> gr.Blocks:
538
  </div>
539
  """
540
  )
 
 
 
 
 
 
 
 
 
 
 
 
541
 
542
  # Session state - generates unique ID for each browser session
543
  session_id_state = gr.State(value=generate_session_id)
@@ -563,7 +594,7 @@ def create_gradio_interface() -> gr.Blocks:
563
  gr.Markdown("### 🤖 Налаштування моделі")
564
  with gr.Row():
565
  generation_provider_dropdown = gr.Dropdown(
566
- choices=[p.value for p in ModelProvider],
567
  value=_default_provider,
568
  label="Провайдер AI",
569
  container=False,
@@ -680,7 +711,7 @@ def create_gradio_interface() -> gr.Blocks:
680
 
681
  with gr.Row():
682
  analysis_provider_dropdown = gr.Dropdown(
683
- choices=[p.value for p in ModelProvider],
684
  value=_default_provider,
685
  label="Провайдер AI",
686
  scale=1
@@ -781,7 +812,7 @@ def create_gradio_interface() -> gr.Blocks:
781
 
782
  with gr.Row():
783
  batch_provider_dropdown = gr.Dropdown(
784
- choices=[p.value for p in ModelProvider],
785
  value=_default_provider,
786
  label="Провайдер AI",
787
  scale=1
 
14
  generate_legal_position,
15
  search_with_ai_action,
16
  analyze_action,
17
+ search_with_raw_text,
18
+ get_available_providers
19
  )
20
  from prompts import SYSTEM_PROMPT, LEGAL_POSITION_PROMPT, PRECEDENT_ANALYSIS_TEMPLATE
21
  from src.session.manager import get_session_manager
 
33
  return f"Помилка завантаження довідки: {str(e)}"
34
 
35
 
36
+ def get_available_provider_choices() -> list:
37
+ """Get list of available AI providers based on API key availability."""
38
+ available = get_available_providers()
39
+ return [p.value for p in ModelProvider if available.get(p.value, False)]
40
+
41
+
42
  def update_generation_model_choices(provider: str) -> gr.Dropdown:
43
  """Update generation model choices based on provider selection."""
44
  if provider == ModelProvider.OPENAI.value:
 
473
  except Exception:
474
  _default_provider = "anthropic"
475
 
476
+ # Get available providers based on API key availability
477
+ _available_providers = get_available_provider_choices()
478
+
479
+ # If default provider is not available, use first available one
480
+ if _default_provider not in _available_providers:
481
+ if _available_providers:
482
+ _default_provider = _available_providers[0]
483
+ print(f"[WARNING] Default provider not available, using: {_default_provider}")
484
+ else:
485
+ print("[ERROR] No AI providers available! Please set at least one API key.")
486
+ _default_provider = "anthropic" # Fallback for UI rendering
487
+
488
  # Get default generation model for the provider
489
  _gen_models = get_generation_models_by_provider(_default_provider)
490
  if DEFAULT_GENERATION_MODEL and DEFAULT_GENERATION_MODEL.value in _gen_models:
 
557
  </div>
558
  """
559
  )
560
+
561
+ # Show provider availability status
562
+ _all_providers = {p.value for p in ModelProvider}
563
+ _unavailable = _all_providers - set(_available_providers)
564
+ if _unavailable:
565
+ unavailable_list = ", ".join(sorted(_unavailable))
566
+ gr.Info(
567
+ f"⚠️ Недоступні провайдери (відсутні API ключі): {unavailable_list}\n"
568
+ f"Додайте відповідні API ключі в налаштуваннях HF Space для активації.",
569
+ title="Інформація про провайдери",
570
+ duration=10
571
+ )
572
 
573
  # Session state - generates unique ID for each browser session
574
  session_id_state = gr.State(value=generate_session_id)
 
594
  gr.Markdown("### 🤖 Налаштування моделі")
595
  with gr.Row():
596
  generation_provider_dropdown = gr.Dropdown(
597
+ choices=_available_providers,
598
  value=_default_provider,
599
  label="Провайдер AI",
600
  container=False,
 
711
 
712
  with gr.Row():
713
  analysis_provider_dropdown = gr.Dropdown(
714
+ choices=_available_providers,
715
  value=_default_provider,
716
  label="Провайдер AI",
717
  scale=1
 
812
 
813
  with gr.Row():
814
  batch_provider_dropdown = gr.Dropdown(
815
+ choices=_available_providers,
816
  value=_default_provider,
817
  label="Провайдер AI",
818
  scale=1
main.py CHANGED
@@ -236,6 +236,31 @@ def check_provider_available(provider: str) -> Tuple[bool, str]:
236
  return True, ""
237
 
238
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
239
  class RetrieverEvent(Event):
240
  """Event class for retriever operations."""
241
  nodes: list[NodeWithScore]
@@ -716,6 +741,9 @@ def generate_legal_position(
716
  print(f"[DEBUG] OpenAI response length: {len(response_text) if response_text else 0}")
717
 
718
  json_response = extract_json_from_text(response_text)
 
 
 
719
  if json_response and all(key in json_response for key in ["title", "text", "proceeding", "category"]):
720
  return json_response
721
  else:
@@ -784,6 +812,9 @@ def generate_legal_position(
784
  print(f"[DEBUG] DeepSeek response length: {len(response_text) if response_text else 0}")
785
 
786
  json_response = extract_json_from_text(response_text)
 
 
 
787
  if json_response and all(key in json_response for key in ["title", "text", "proceeding", "category"]):
788
  return json_response
789
  else:
@@ -872,6 +903,8 @@ def generate_legal_position(
872
  json_response = extract_json_from_text(response_text)
873
 
874
  if json_response:
 
 
875
  # Validate required fields
876
  required = ["title", "text", "proceeding", "category"]
877
  missing = [f for f in required if f not in json_response]
@@ -959,6 +992,8 @@ def generate_legal_position(
959
  json_response = extract_json_from_text(response_text)
960
 
961
  if json_response:
 
 
962
  # Перевіряємо наявність всіх необхідних полів
963
  required_fields = ["title", "text", "proceeding", "category"]
964
  if all(field in json_response for field in required_fields):
 
236
  return True, ""
237
 
238
 
239
+ def normalize_response_keys(response_dict: Dict[str, Any]) -> Dict[str, Any]:
240
+ """
241
+ Normalize keys in the response dictionary to match the expected format.
242
+ Handles variations like 'text_lp' -> 'text' and 'proceeding_type' -> 'proceeding'.
243
+ """
244
+ if not response_dict:
245
+ return response_dict
246
+
247
+ # Map common variations to standard keys
248
+ key_mapping = {
249
+ "text_lp": "text",
250
+ "legal_position_text": "text",
251
+ "lp_text": "text",
252
+ "proceeding_type": "proceeding",
253
+ "type_of_proceeding": "proceeding"
254
+ }
255
+
256
+ normalized = response_dict.copy()
257
+ for variant, standard in key_mapping.items():
258
+ if variant in normalized and standard not in normalized:
259
+ normalized[standard] = normalized.pop(variant)
260
+
261
+ return normalized
262
+
263
+
264
  class RetrieverEvent(Event):
265
  """Event class for retriever operations."""
266
  nodes: list[NodeWithScore]
 
741
  print(f"[DEBUG] OpenAI response length: {len(response_text) if response_text else 0}")
742
 
743
  json_response = extract_json_from_text(response_text)
744
+ if json_response:
745
+ json_response = normalize_response_keys(json_response)
746
+
747
  if json_response and all(key in json_response for key in ["title", "text", "proceeding", "category"]):
748
  return json_response
749
  else:
 
812
  print(f"[DEBUG] DeepSeek response length: {len(response_text) if response_text else 0}")
813
 
814
  json_response = extract_json_from_text(response_text)
815
+ if json_response:
816
+ json_response = normalize_response_keys(json_response)
817
+
818
  if json_response and all(key in json_response for key in ["title", "text", "proceeding", "category"]):
819
  return json_response
820
  else:
 
903
  json_response = extract_json_from_text(response_text)
904
 
905
  if json_response:
906
+ json_response = normalize_response_keys(json_response)
907
+
908
  # Validate required fields
909
  required = ["title", "text", "proceeding", "category"]
910
  missing = [f for f in required if f not in json_response]
 
992
  json_response = extract_json_from_text(response_text)
993
 
994
  if json_response:
995
+ json_response = normalize_response_keys(json_response)
996
+
997
  # Перевіряємо наявність всіх необхідних полів
998
  required_fields = ["title", "text", "proceeding", "category"]
999
  if all(field in json_response for field in required_fields):
test_providers.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Test script to check available AI providers
4
+ """
5
+ import os
6
+ from dotenv import load_dotenv
7
+
8
+ # Load environment variables
9
+ load_dotenv()
10
+
11
+ def get_available_providers():
12
+ """Get status of all AI providers."""
13
+ return {
14
+ "openai": bool(os.getenv("OPENAI_API_KEY")),
15
+ "anthropic": bool(os.getenv("ANTHROPIC_API_KEY")),
16
+ "gemini": bool(os.getenv("GEMINI_API_KEY")),
17
+ "deepseek": bool(os.getenv("DEEPSEEK_API_KEY"))
18
+ }
19
+
20
+ if __name__ == "__main__":
21
+ providers = get_available_providers()
22
+
23
+ print("=" * 50)
24
+ print("🔑 API Keys Status")
25
+ print("=" * 50)
26
+
27
+ for provider, available in providers.items():
28
+ status = "✅ Available" if available else "❌ Missing"
29
+ print(f"{provider.upper():12} : {status}")
30
+
31
+ print("=" * 50)
32
+
33
+ available_list = [p for p, avail in providers.items() if avail]
34
+ unavailable_list = [p for p, avail in providers.items() if not avail]
35
+
36
+ print(f"\n✅ Available providers: {', '.join(available_list) if available_list else 'None'}")
37
+ print(f"❌ Unavailable providers: {', '.join(unavailable_list) if unavailable_list else 'None'}")
38
+
39
+ if not available_list:
40
+ print("\n⚠️ WARNING: No AI providers available!")
41
+ print("Please set at least one API key in your .env file:")
42
+ print(" - OPENAI_API_KEY")
43
+ print(" - ANTHROPIC_API_KEY")
44
+ print(" - GEMINI_API_KEY")
45
+ print(" - DEEPSEEK_API_KEY")