cheenchan committed on
Commit
5ad097b
·
1 Parent(s): 0f0e529

Optimize pipeline for fast responses - disable RL overhead for instant character extraction

Browse files
src/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # Make src a Python package
src/app.py CHANGED
@@ -21,10 +21,10 @@ try:
21
  warnings.filterwarnings("ignore", message=".*protobuf.*")
22
  warnings.filterwarnings("ignore", message=".*MessageFactory.*")
23
 
24
- from character_pipeline import create_pipeline
25
- from pipeline import CharacterAttributes
26
- from pipeline.input_loader import DatasetItem
27
- from rl_trainer import train_rl_pipeline
28
  PIPELINE_AVAILABLE = True
29
  except (ImportError, AttributeError) as e:
30
  logging.warning(f"Pipeline dependencies not available: {e}")
@@ -76,10 +76,14 @@ class UnifiedCharacterExtractionApp:
76
  try:
77
  if PIPELINE_AVAILABLE:
78
  self.pipeline = create_pipeline({
79
- 'use_rl_primary': True,
80
- 'rl_model_path': 'decision_transformer.pth' if Path('decision_transformer.pth').exists() else None
 
 
 
 
81
  })
82
- logger.info("RL Pipeline initialized successfully")
83
  else:
84
  self.pipeline = None
85
  logger.info("Running in fallback mode - dependencies loading...")
@@ -436,8 +440,10 @@ def main():
436
 
437
  port = int(os.environ.get("PORT", 7860))
438
 
 
 
439
  interface.launch(
440
- server_name="0.0.0.0",
441
  server_port=port,
442
  share=False,
443
  show_error=True
 
21
  warnings.filterwarnings("ignore", message=".*protobuf.*")
22
  warnings.filterwarnings("ignore", message=".*MessageFactory.*")
23
 
24
+ from src.character_pipeline import create_pipeline
25
+ from src.pipeline import CharacterAttributes
26
+ from src.pipeline.input_loader import DatasetItem
27
+ from src.rl_trainer import train_rl_pipeline
28
  PIPELINE_AVAILABLE = True
29
  except (ImportError, AttributeError) as e:
30
  logging.warning(f"Pipeline dependencies not available: {e}")
 
76
  try:
77
  if PIPELINE_AVAILABLE:
78
  self.pipeline = create_pipeline({
79
+ 'use_rl_primary': False,
80
+ 'rl_model_path': None,
81
+ 'enable_caching': True,
82
+ 'batch_size': 1,
83
+ 'fast_mode': True,
84
+ 'disable_ray': True
85
  })
86
+ logger.info("Fast Pipeline initialized successfully")
87
  else:
88
  self.pipeline = None
89
  logger.info("Running in fallback mode - dependencies loading...")
 
440
 
441
  port = int(os.environ.get("PORT", 7860))
442
 
443
+ interface.queue() # Enable queue for Gradio 3.50.0
444
+
445
  interface.launch(
446
+ server_name="127.0.0.1",
447
  server_port=port,
448
  share=False,
449
  show_error=True
src/app_simple_working.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import gradio as gr
import json
import time
import os
from pathlib import Path
from PIL import Image
from typing import Dict, List, Tuple, Any
import logging
import sys

# Make the repository root importable so that `src.*` package imports resolve.
# NOTE: appending only this file's own directory (src/) would expose
# `character_pipeline`, not `src.character_pipeline` — the parent dir is needed.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# Optional pipeline import: the app degrades to a fallback mode when the
# heavy RL dependencies are missing instead of crashing at import time.
try:
    from src.character_pipeline import create_pipeline
    PIPELINE_AVAILABLE = True
    print("βœ… RL Pipeline loaded successfully!")
except Exception as e:
    print(f"⚠️ Pipeline not available: {e}")
    PIPELINE_AVAILABLE = False

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
25
+
26
class SimpleCharacterApp:
    """Gradio front-end wrapper around the optional RL extraction pipeline.

    Falls back to canned placeholder attributes when the pipeline (or its
    dependencies) could not be loaded at import time.
    """

    def __init__(self):
        # Pipeline is optional; stays None in fallback mode.
        self.pipeline = None
        if PIPELINE_AVAILABLE:
            try:
                self.pipeline = create_pipeline({
                    'use_rl_primary': True,
                    'rl_model_path': None
                })
                logger.info("βœ… RL Pipeline initialized successfully")
            except Exception as e:
                logger.error(f"❌ Pipeline initialization failed: {e}")
                self.pipeline = None

    @staticmethod
    def _format_attrs(attr_dict, header):
        """Render an attribute dict as Markdown lines under *header*.

        Values under "Confidence"/score-like keys are shown with 3 decimals,
        but only when they are actually numeric — guards against a pipeline
        `to_dict()` that returns strings for those keys (the old code would
        raise on `f"{value:.3f}"` in that case).
        """
        lines = [header, ""]
        for key, value in attr_dict.items():
            score_like = key == "Confidence" or "Score" in key
            if score_like and isinstance(value, (int, float)):
                lines.append(f"**{key}:** {value:.3f}")
            else:
                lines.append(f"**{key}:** {value}")
        return "\n".join(lines) + "\n"

    def extract_attributes(self, image):
        """Extract character attributes from an uploaded image.

        Args:
            image: PIL image from the Gradio widget, or None.

        Returns:
            A 3-tuple: (markdown_text, json_string, stats_text).
            Never raises — errors are reported through the returned strings.
        """
        if image is None:
            return "Please upload an image first.", "{}", "No image provided"

        try:
            start_time = time.time()

            if self.pipeline and PIPELINE_AVAILABLE:
                # Real pipeline path.
                attributes = self.pipeline.extract_from_image(image)
                processing_time = time.time() - start_time

                # Prefer the pipeline's own dict form; otherwise pull the
                # known attribute fields defensively via getattr.
                attr_dict = attributes.to_dict() if hasattr(attributes, 'to_dict') else {
                    "Age": getattr(attributes, 'age', 'Unknown'),
                    "Gender": getattr(attributes, 'gender', 'Unknown'),
                    "Hair Color": getattr(attributes, 'hair_color', 'Unknown'),
                    "Eye Color": getattr(attributes, 'eye_color', 'Unknown'),
                    "Confidence": getattr(attributes, 'confidence_score', 0.0)
                }
                formatted_output = self._format_attrs(
                    attr_dict, "**🎭 Character Attributes Extracted:**")
                stats = f"⚑ Processing Time: {processing_time:.2f}s\nπŸ€– Mode: RL Pipeline\nβœ… Status: Success"
            else:
                # Fallback: canned placeholder attributes.
                processing_time = time.time() - start_time
                attr_dict = {
                    "Age": "Young Adult",
                    "Gender": "Unknown",
                    "Hair Color": "Unknown",
                    "Eye Color": "Unknown",
                    "Confidence": 0.5
                }
                formatted_output = self._format_attrs(
                    attr_dict, "**🎭 Character Attributes (Fallback Mode):**")
                stats = f"⚑ Processing Time: {processing_time:.2f}s\nπŸ”„ Mode: Fallback\n⚠️ Status: Limited functionality"

            # Serialization is identical for both branches — hoisted out.
            json_output = json.dumps(attr_dict, indent=2)
            return formatted_output, json_output, stats

        except Exception as e:
            # Top-level UI boundary: surface the failure instead of raising.
            error_msg = f"❌ Error processing image: {str(e)}"
            logger.error(error_msg)

            error_dict = {
                "error": str(e),
                "status": "error"
            }
            return error_msg, json.dumps(error_dict, indent=2), "❌ Processing failed"
105
+
106
def create_interface():
    """Build and return the Gradio Blocks UI wired to a SimpleCharacterApp."""
    extractor = SimpleCharacterApp()

    with gr.Blocks(title="RL Character Extraction") as interface:
        gr.Markdown("""
        # 🎭 RL-Enhanced Character Attribute Extraction

        Upload a character image to extract detailed attributes using our RL-powered pipeline.
        """)

        with gr.Row():
            # Left column: input image and trigger button.
            with gr.Column():
                image_input = gr.Image(type="pil", label="πŸ“Έ Upload Character Image")
                extract_btn = gr.Button("πŸš€ Extract Attributes", variant="primary")

            # Right column: rendered results.
            with gr.Column():
                formatted_output = gr.Markdown(
                    label="πŸ“‹ Extracted Attributes",
                    value="Upload an image and click 'Extract Attributes' to see results."
                )
                stats_output = gr.Textbox(label="πŸ“Š Processing Stats", lines=3)
                json_output = gr.Code(label="πŸ“„ JSON Output", language="json")

        # Wire the click event; output order matches extract_attributes' tuple
        # mapped to (markdown, json, stats) widgets.
        extract_btn.click(
            fn=extractor.extract_attributes,
            inputs=[image_input],
            outputs=[formatted_output, json_output, stats_output],
        )

    return interface
151
+
152
def main():
    """Entry point: build the UI and serve it on $PORT (default 7860)."""
    logger.info("πŸš€ Starting Simple Character Attribute Extraction Interface...")

    ui = create_interface()
    # NOTE(review): 127.0.0.1 makes the server reachable from this host only;
    # containerized deployments usually require 0.0.0.0 — confirm intent.
    ui.launch(
        server_name="127.0.0.1",
        server_port=int(os.environ.get("PORT", 7860)),
        share=False,
        show_error=True
    )


if __name__ == "__main__":
    main()
src/character_pipeline.py CHANGED
@@ -7,7 +7,7 @@ from typing import Any, Dict, List, Optional, Union
7
  from PIL import Image
8
  import asyncio
9
 
10
- from pipeline import (
11
  Pipeline,
12
  PipelineStage,
13
  CharacterAttributes,
@@ -25,7 +25,7 @@ from pipeline import (
25
  DistributedProcessor,
26
  AdvancedCacheManager,
27
  )
28
- from rl_pipeline_integration import create_rl_enhanced_pipeline, ProductionRLPipeline
29
 
30
  logging.basicConfig(level=logging.INFO)
31
  logger = logging.getLogger(__name__)
 
7
  from PIL import Image
8
  import asyncio
9
 
10
+ from src.pipeline import (
11
  Pipeline,
12
  PipelineStage,
13
  CharacterAttributes,
 
25
  DistributedProcessor,
26
  AdvancedCacheManager,
27
  )
28
+ from src.rl_pipeline_integration import create_rl_enhanced_pipeline, ProductionRLPipeline
29
 
30
  logging.basicConfig(level=logging.INFO)
31
  logger = logging.getLogger(__name__)
src/rl_pipeline_integration.py CHANGED
@@ -6,11 +6,17 @@ import torch
6
  from pathlib import Path
7
  import json
8
  import time
9
- from rl_orchestrator import RLOrchestrator, StateVector
10
- from rl_trainer import train_rl_pipeline
11
- from pipeline.base import CharacterAttributes, ProcessingResult
12
- from pipeline.input_loader import DatasetItem
13
- import ray
 
 
 
 
 
 
14
 
15
  class ProductionRLPipeline:
16
  def __init__(self, model_path: Optional[str] = None, enable_training: bool = False):
@@ -24,6 +30,20 @@ class ProductionRLPipeline:
24
  "avg_cost": 0.0,
25
  "success_rate": 0.0
26
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
  async def extract_attributes_rl(self, image: Union[str, Path, Image.Image],
29
  tags: Optional[str] = None,
 
6
  from pathlib import Path
7
  import json
8
  import time
9
+ from src.rl_orchestrator import RLOrchestrator, StateVector
10
+ from src.rl_trainer import train_rl_pipeline
11
+ from src.pipeline.base import CharacterAttributes, ProcessingResult
12
+ from src.pipeline.input_loader import DatasetItem
13
+ # import ray # Disabled to avoid GPU resource conflicts
14
+ try:
15
+ import ray
16
+ RAY_AVAILABLE = True
17
+ except ImportError:
18
+ RAY_AVAILABLE = False
19
+ ray = None
20
 
21
  class ProductionRLPipeline:
22
  def __init__(self, model_path: Optional[str] = None, enable_training: bool = False):
 
30
  "avg_cost": 0.0,
31
  "success_rate": 0.0
32
  }
33
+
34
+ # Initialize Ray for distributed processing (if available)
35
+ if RAY_AVAILABLE and not ray.is_initialized():
36
+ try:
37
+ ray.init(ignore_reinit_error=True, log_to_driver=False, num_cpus=2, num_gpus=0)
38
+ self.use_ray = True
39
+ except Exception as e:
40
+ print(f"Ray initialization failed: {e}. Running without distributed processing.")
41
+ self.use_ray = False
42
+ elif not RAY_AVAILABLE:
43
+ print("Ray not available. Running without distributed processing.")
44
+ self.use_ray = False
45
+ else:
46
+ self.use_ray = True
47
 
48
  async def extract_attributes_rl(self, image: Union[str, Path, Image.Image],
49
  tags: Optional[str] = None,