AlsuGibadullina committed on
Commit
5297a41
·
verified ·
1 Parent(s): aeda505

Update src/orchestrator.py

Browse files
Files changed (1) hide show
  1. src/orchestrator.py +45 -13
src/orchestrator.py CHANGED
@@ -1,16 +1,22 @@
1
  from dataclasses import dataclass, asdict
2
  from typing import Dict, Any, Optional
3
 
4
- from .config import RunConfig
5
- from .agents import AnalyzerAgent, RefactorAgent, CriticAgent, AgentResult
6
  from .tasks import TaskContext, build_analyzer_prompt, build_refactor_prompt, build_critic_prompt
 
 
 
 
 
 
7
 
8
 
9
  @dataclass
10
  class RunTrace:
11
  task: str
12
  input_requirements: str
13
- image_attached: bool
14
  analyzer: AgentResult
15
  refactor: AgentResult
16
  critic: AgentResult
@@ -19,7 +25,7 @@ class RunTrace:
19
  return {
20
  "task": self.task,
21
  "input_requirements": self.input_requirements,
22
- "image_attached": self.image_attached,
23
  "analyzer": asdict(self.analyzer),
24
  "refactor": asdict(self.refactor),
25
  "critic": asdict(self.critic),
@@ -35,22 +41,48 @@ class Orchestrator:
35
  self.refactor = RefactorAgent(cfg.refactor)
36
  self.critic = CriticAgent(cfg.critic)
37
 
38
- def run(self, requirements_text: str, image_path: Optional[str] = None) -> RunTrace:
39
- has_image = bool(image_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
- p1 = build_analyzer_prompt(requirements_text, has_image, self.ctx)
42
- r1 = self.analyzer.run(p1, image_path=image_path)
43
 
44
- p2 = build_refactor_prompt(requirements_text, r1.output, has_image, self.ctx)
45
- r2 = self.refactor.run(p2, image_path=image_path)
46
 
47
- p3 = build_critic_prompt(requirements_text, r2.output, has_image, self.ctx)
48
- r3 = self.critic.run(p3, image_path=image_path)
49
 
50
  return RunTrace(
51
  task=self.ctx.name,
52
  input_requirements=requirements_text,
53
- image_attached=has_image,
54
  analyzer=r1,
55
  refactor=r2,
56
  critic=r3,
 
1
  from dataclasses import dataclass, asdict
2
  from typing import Dict, Any, Optional
3
 
4
+ from .config import RunConfig, ModelSpec
5
+ from .agents import AnalyzerAgent, RefactorAgent, CriticAgent, AgentResult, BaseAgent
6
  from .tasks import TaskContext, build_analyzer_prompt, build_refactor_prompt, build_critic_prompt
7
+ from .backends import HFInferenceAPIBackend
8
+
9
+ try:
10
+ from PIL import Image
11
+ except Exception:
12
+ Image = None
13
 
14
 
15
  @dataclass
16
  class RunTrace:
17
  task: str
18
  input_requirements: str
19
+ diagram_context: str
20
  analyzer: AgentResult
21
  refactor: AgentResult
22
  critic: AgentResult
 
25
  return {
26
  "task": self.task,
27
  "input_requirements": self.input_requirements,
28
+ "diagram_context": self.diagram_context,
29
  "analyzer": asdict(self.analyzer),
30
  "refactor": asdict(self.refactor),
31
  "critic": asdict(self.critic),
 
41
  self.refactor = RefactorAgent(cfg.refactor)
42
  self.critic = CriticAgent(cfg.critic)
43
 
44
+ # Dedicated models for diagram extraction (free/open on HF Inference API)
45
+ # OCR: TrOCR, Caption: BLIP
46
+ self.ocr_backend = HFInferenceAPIBackend("microsoft/trocr-base-printed")
47
+ self.caption_backend = HFInferenceAPIBackend("Salesforce/blip-image-captioning-large")
48
+
49
+ def _extract_diagram_context(self, image: Optional["Image.Image"]) -> str:
50
+ if image is None:
51
+ return ""
52
+
53
+ parts = []
54
+ try:
55
+ ocr = self.ocr_backend.image_to_text(image)
56
+ if ocr and ocr.strip():
57
+ parts.append("OCR (текст на изображении):\n" + ocr.strip())
58
+ except Exception as e:
59
+ parts.append(f"OCR: ошибка ({type(e).__name__})")
60
+
61
+ try:
62
+ cap = self.caption_backend.image_to_text(image)
63
+ if cap and cap.strip():
64
+ parts.append("Описание изображения:\n" + cap.strip())
65
+ except Exception as e:
66
+ parts.append(f"Caption: ошибка ({type(e).__name__})")
67
+
68
+ return "\n\n".join(parts).strip()
69
+
70
+ def run(self, requirements_text: str, image: Optional["Image.Image"] = None) -> RunTrace:
71
+ diagram_context = self._extract_diagram_context(image)
72
 
73
+ p1 = build_analyzer_prompt(requirements_text, diagram_context, self.ctx)
74
+ r1 = self.analyzer.run(p1)
75
 
76
+ p2 = build_refactor_prompt(requirements_text, r1.output, diagram_context, self.ctx)
77
+ r2 = self.refactor.run(p2)
78
 
79
+ p3 = build_critic_prompt(requirements_text, r2.output, diagram_context, self.ctx)
80
+ r3 = self.critic.run(p3)
81
 
82
  return RunTrace(
83
  task=self.ctx.name,
84
  input_requirements=requirements_text,
85
+ diagram_context=diagram_context,
86
  analyzer=r1,
87
  refactor=r2,
88
  critic=r3,