sababishraq commited on
Commit
dbf9d04
·
verified ·
1 Parent(s): 6a70fee

Upload 5 files

Browse files
Dockerfile (1) ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Use official Python base image with CUDA support
FROM nvidia/cuda:12.1.0-runtime-ubuntu22.04

# Set environment variables
# NOTE(review): HOME/HF_HOME/STREAMLIT_*/TRITON_CACHE_DIR all point inside /app so
# every runtime cache lands in a writable directory (HF Spaces runs as non-root).
ENV PYTHONUNBUFFERED=1 \
    DEBIAN_FRONTEND=noninteractive \
    PIP_NO_CACHE_DIR=1 \
    PIP_DISABLE_PIP_VERSION_CHECK=1 \
    HF_HOME=/app/.cache/huggingface \
    STREAMLIT_SERVER_HEADLESS=true \
    STREAMLIT_SERVER_FILE_WATCHER_TYPE=none \
    STREAMLIT_BROWSER_GATHER_USAGE_STATS=false \
    STREAMLIT_CONFIG_DIR=/app/.streamlit \
    TRITON_CACHE_DIR=/app/.triton \
    HOME=/app

# Install system dependencies including Python 3.10
# (curl is needed by the HEALTHCHECK below; git/wget for model/repo fetches)
RUN apt-get update && apt-get install -y \
    python3.10 \
    python3-pip \
    git \
    wget \
    curl \
    && rm -rf /var/lib/apt/lists/*

# Create symlink for python3 if needed
RUN ln -sf /usr/bin/python3.10 /usr/bin/python3

# Set working directory
WORKDIR /app

# Copy requirements first for better caching
COPY requirements.txt .

# Install Python dependencies
RUN pip3 install --no-cache-dir -r requirements.txt

# Install playwright browsers
# (--with-deps also installs the chromium system libraries)
RUN playwright install --with-deps chromium

# Copy agent code and Streamlit app
COPY agent_azure_vlm_tools.py .
COPY streamlit_app.py .
COPY .streamlit/config.toml /app/.streamlit/
# Note: .env not copied - use HF Spaces secrets instead

# Create cache and output directories with proper permissions
# (777 because the runtime UID is not known at build time on HF Spaces)
RUN mkdir -p Outputs /app/.cache/huggingface/hub /app/.streamlit /app/.triton && \
    chmod -R 777 /app/.cache /app/Outputs /app/.streamlit /app/.triton

# Expose port (Streamlit default)
EXPOSE 7860

# Health check
# Streamlit exposes a dedicated health endpoint at /_stcore/health
HEALTHCHECK --interval=30s --timeout=10s --start-period=60s --retries=3 \
    CMD curl -f http://localhost:7860/_stcore/health || exit 1

# Run Streamlit
CMD ["streamlit", "run", "streamlit_app.py", "--server.port=7860", "--server.address=0.0.0.0", "--server.headless=true"]
agent_azure_vlm_tools (2).py ADDED
@@ -0,0 +1,1381 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import json
4
+ import base64
5
+ import mimetypes
6
+ import argparse
7
+ import pathlib
8
+ import io
9
+ import requests
10
+ from typing import List, Dict, Any, Optional, TypedDict, Literal, Tuple
11
+ from dataclasses import dataclass, field
12
+ from io import BytesIO
13
+ from PIL import Image
14
+ from uuid import uuid4
15
+
16
+ # --- LangGraph / LangChain ---
17
+ from langgraph.graph import StateGraph, END
18
+ from langgraph.checkpoint.memory import MemorySaver
19
+
20
+ # --- OpenAI / Azure ---
21
+ from openai import AzureOpenAI, OpenAI
22
+ from dotenv import load_dotenv
23
+
24
+ # --- HF Transformers & Diffusers (Local VLM) ---
25
+ import torch
26
+ from transformers import AutoModelForVision2Seq, AutoProcessor, AutoTokenizer, pipeline, BitsAndBytesConfig
27
# Compatibility shim: some diffusers versions expect FLAX_WEIGHTS_NAME on
# transformers.utils; restore it before importing diffusers if it was removed.
import transformers.utils as _hf_utils
if not hasattr(_hf_utils, "FLAX_WEIGHTS_NAME"):
    _hf_utils.FLAX_WEIGHTS_NAME = "flax_model.msgpack"
from diffusers import DiffusionPipeline
# NOTE(review): some transformers builds apparently ship the video-processor
# registry as None, which breaks auto-class resolution — normalize it to an
# empty mapping, best-effort only.
try:
    import transformers.models.auto.video_processing_auto as _video_auto
    if getattr(_video_auto, "VIDEO_PROCESSOR_MAPPING_NAMES", None) is None:
        _video_auto.VIDEO_PROCESSOR_MAPPING_NAMES = {}
except Exception as _video_err:
    print(f"Warning: unable to patch video processor registry: {_video_err}")
37
+
38
+ # --- Playwright (for screenshots) ---
39
+ from playwright.sync_api import sync_playwright
40
+
41
+ # --- Load Environment Variables ---
42
+ load_dotenv()
43
+
44
+ # ----------------------------------------------------------------------
45
+ # ----------------------------------------------------------------------
46
+ # ## SECTION 1: UNIFIED MODEL MANAGER (Singleton)
47
+ #
48
+ # Manages loading all models (Azure, Local VLM, Generator)
49
+ # to ensure they are only loaded into memory once.
50
+ # ----------------------------------------------------------------------
51
+ # ----------------------------------------------------------------------
52
+
53
# --- Configs from all files ---
# Model ids are env-overridable so deployments can swap checkpoints without code changes.
QWEN_VL_MODEL_NAME = os.getenv("QWEN_VL_MODEL", "Qwen/Qwen2.5-VL-7B-Instruct")
SD_GENERATOR_MODEL = os.getenv("SD_GENERATOR_MODEL", "segmind/tiny-sd")
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# NOTE(review): bfloat16 on GPU, float32 on CPU — presumably because bf16 CPU
# support varies; confirm before changing.
DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32
58
+
59
class ModelManager:
    """Manages loading all models and API clients at startup."""
    # Process-wide singleton instance (the module creates one as `models` below).
    _instance = None

    def __new__(cls, *args, **kwargs):
        # Classic singleton: every ModelManager() call returns the same object.
        if not cls._instance:
            cls._instance = super(ModelManager, cls).__new__(cls)
        return cls._instance

    def __init__(self):
        # __init__ still runs on every ModelManager() call, so guard on an
        # attribute that only exists after the first full load completed.
        if not hasattr(self, 'vlm_model'):  # Initialize only once
            print("Initializing and loading all models and clients...")

            # 1. Configure Azure Client (optional: only if both env vars are set)
            self.AZURE_ENDPOINT = os.getenv("ENDPOINT_URL", "")
            self.AZURE_API_KEY = os.getenv("AZURE_OPENAI_API_KEY", "")
            if not self.AZURE_API_KEY or not self.AZURE_ENDPOINT:
                print("Warning: AZURE_OPENAI_API_KEY or ENDPOINT_URL not set.")
            else:
                self.azure_client = AzureOpenAI(
                    azure_endpoint=self.AZURE_ENDPOINT,
                    api_key=self.AZURE_API_KEY,
                    api_version="2024-10-21"
                )
                print("AzureOpenAI client loaded.")

            # 2. Configure OpenAI Client (for edit_node_tool)
            try:
                self.OPENAI_API_KEY = os.environ["OPENAI_API_KEY"]
                self.openai_client = OpenAI(api_key=self.OPENAI_API_KEY)
                print("OpenAI client loaded.")
            except KeyError:
                print("Warning: OPENAI_API_KEY not set. GPT editor tool will not work.")
                self.openai_client = None

            # 3. Configure and load the Local VLM (Qwen), 4-bit quantized to fit VRAM
            print(f"Loading local VLM: {QWEN_VL_MODEL_NAME}...")
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_quant_type="nf4"
            )
            self.vlm_processor = AutoProcessor.from_pretrained(QWEN_VL_MODEL_NAME, trust_remote_code=True)
            self.vlm_model = AutoModelForVision2Seq.from_pretrained(
                QWEN_VL_MODEL_NAME,
                torch_dtype=DTYPE,
                device_map="auto",
                quantization_config=quantization_config,
                trust_remote_code=True
            )
            print("Local VLM (Qwen) loaded.")

            # 4. Configure and load the Generator (diffusion pipeline)
            print(f"Loading image generator: {SD_GENERATOR_MODEL}...")
            self.generator_pipe = DiffusionPipeline.from_pretrained(
                SD_GENERATOR_MODEL, torch_dtype=DTYPE
            )
            # Page generator weights to GPU on demand so the VLM keeps its VRAM.
            self.generator_pipe.enable_model_cpu_offload()
            print("Generator loaded.")

            print("All models and clients loaded and ready.")

    def get_azure_client(self) -> AzureOpenAI:
        """Return the Azure client; raise if its env vars were missing at init."""
        if not hasattr(self, 'azure_client'):
            raise RuntimeError("Azure client not initialized. Set AZURE_OPENAI_API_KEY and ENDPOINT_URL.")
        return self.azure_client

    def get_openai_client(self) -> OpenAI:
        """Return the OpenAI client; raise if OPENAI_API_KEY was missing at init."""
        if not hasattr(self, 'openai_client') or self.openai_client is None:
            raise RuntimeError("OpenAI client not initialized. Set OPENAI_API_KEY.")
        return self.openai_client

    # --- VLM Chat (in asset tool) ---
    def chat_vlm(self, messages, temperature=0.2, max_new_tokens=2048):
        """Run one chat turn through the local Qwen VLM; return the decoded reply text."""
        # temperature == 0 means greedy decoding (do_sample=False, no temperature kwarg).
        gen_kwargs = {"do_sample": temperature > 0, "max_new_tokens": max_new_tokens}
        if temperature > 0:
            gen_kwargs["temperature"] = temperature

        inputs = self.vlm_processor.apply_chat_template(
            messages, add_generation_prompt=True, tokenize=True, return_tensors="pt", return_dict=True
        ).to(self.vlm_model.device)

        with torch.no_grad():
            output_ids = self.vlm_model.generate(**inputs, **gen_kwargs)

        # Slice off the prompt tokens so only newly generated text is decoded.
        gen_only = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs["input_ids"], output_ids)]
        return self.vlm_processor.batch_decode(gen_only, skip_special_tokens=True, clean_up_tokenization_spaces=True)[0]

    def chat_llm(self, prompt: str):
        """Text-only convenience wrapper around chat_vlm (low temperature, short budget)."""
        messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
        return self.chat_vlm(messages, temperature=0.1, max_new_tokens=1024)

    # --- Generator (from asset_tool) ---
    def generate_image(self, prompt: str) -> Image.Image:
        """Generate one image from *prompt* with the local diffusion pipeline."""
        # NOTE(review): 4 steps / guidance 0.0 matches distilled-SD style sampling —
        # confirm these settings suit whichever SD_GENERATOR_MODEL is configured.
        print(f"Generating image with prompt: '{prompt}'")
        return self.generator_pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0]

    # --- Azure Chat (from agent_azure_vlm) ---
    def chat_complete_azure(self, deployment: str, messages: List[Dict[str, Any]],
                            temperature: float, max_tokens: int) -> str:
        """Call an Azure OpenAI chat deployment and return the stripped reply text."""
        client = self.get_azure_client()
        resp = client.chat.completions.create(
            model=deployment,
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        # Guard against a None content (e.g. filtered responses).
        return (resp.choices[0].message.content or "").strip()
165
+
166
+ # --- Initialize models ONCE ---
167
+ models = ModelManager()
168
+
169
+ # ----------------------------------------------------------------------
170
+ # ----------------------------------------------------------------------
171
+ # ## SECTION 2: ASSET-FINDING TOOL
172
+ #
173
+ # This is the self-contained graph for finding/generating assets.
174
+ # It will be used as a tool by the "Brain" AND by the Azure pipeline.
175
+ # ----------------------------------------------------------------------
176
+ # ----------------------------------------------------------------------
177
+
178
+ ### --- Utilities from asset_tool.py ---
179
def load_image(path: str) -> Image.Image:
    """Open the image at *path* and return it converted to RGB."""
    img = Image.open(path)
    return img.convert("RGB")
181
+
182
def b64img(pil_img: Image.Image) -> str:
    """Serialize *pil_img* as PNG and return the base64-encoded ASCII string."""
    buffer = io.BytesIO()
    pil_img.save(buffer, format="PNG")
    encoded = base64.b64encode(buffer.getvalue())
    return encoded.decode("utf-8")
186
+
187
+ ### --- State from asset_tool.py ---
188
class AssetGraphState(TypedDict):
    """State for the asset-finding subgraph."""
    instructions: str                # user's description of the asset wanted
    bounding_box: Tuple[int, int]    # max (width, height) for the resized asset
    search_query: str                # refined Pexels query (set by the prepare node)
    found_image_url: Optional[str]   # candidate URL from Pexels, or None on miss
    final_asset_path: Optional[str]  # HTML-relative path of the saved asset, or None
195
+
196
+ ### --- Nodes from asset_tool.py ---
197
def asset_prepare_search_query_node(state: AssetGraphState) -> dict:
    """Distill the user's instructions into a terse, quote-free image search query."""
    print("---(Asset Tool) NODE: Prepare Search Query---")
    prompt = f"""You are an expert at refining search queries. Extract only the essential visual keywords.
**CRITICAL INSTRUCTIONS:**
- DO NOT include words related to licensing.
- DO NOT include quotation marks.
User's request: "{state['instructions']}"
Respond with ONLY the refined search query."""
    refined = models.chat_llm(prompt).strip().replace('"', '')
    print(f"Prepared search query: '{refined}'")
    return {"search_query": refined}
209
+
210
def asset_generate_image_node(state: AssetGraphState) -> dict:
    """Fallback node: synthesize an image locally and save it under Outputs/Assets."""
    print("---(Asset Tool) NODE: Generate Image---")
    generated_image = models.generate_image(state["instructions"])
    output_dir = pathlib.Path("Outputs/Assets")
    output_dir.mkdir(parents=True, exist_ok=True)
    filename = f"generated_{uuid4()}.png"
    full_save_path = output_dir / filename
    generated_image.save(full_save_path)
    print(f"Image generated and saved to {full_save_path}")
    # The HTML references assets relative to the Outputs/ root, so return
    # "Assets/<file>" rather than the on-disk "Outputs/Assets/<file>".
    return {"final_asset_path": (pathlib.Path("Assets") / filename).as_posix()}
223
+
224
def asset_download_and_resize_node(state: AssetGraphState) -> dict:
    """Fetch the found image URL, shrink it to the bounding box, save under Outputs/Assets.

    On any failure returns {"final_asset_path": None} so the caller can react.
    """
    print("---(Asset Tool) NODE: Download and Resize---")
    image_url = state.get("found_image_url")
    try:
        response = requests.get(image_url, timeout=10)
        response.raise_for_status()
        img = Image.open(BytesIO(response.content))
        img.thumbnail(state['bounding_box'])  # in-place resize, keeps aspect ratio
        output_dir = pathlib.Path("Outputs/Assets")
        output_dir.mkdir(parents=True, exist_ok=True)
        filename = f"asset_{uuid4()}.png"
        full_save_path = output_dir / filename
        img.save(full_save_path)
        print(f"Image saved and resized to {full_save_path}")
        # HTML-relative path (the page lives under Outputs/).
        return {"final_asset_path": (pathlib.Path("Assets") / filename).as_posix()}
    except Exception as e:
        print(f"Error processing image: {e}")
        return {"final_asset_path": None}
244
+
245
def asset_route_after_search(state: AssetGraphState) -> str:
    """Route to resizing when the search produced a URL, otherwise to local generation."""
    if not state.get("found_image_url"):
        print("Search failed. Routing to generate a new image.")
        return "generate_image"
    return "download_and_resize"
251
+
252
def asset_pexels_search_node(state: AssetGraphState) -> dict:
    """Query the Pexels search API for one photo matching the prepared query.

    Returns {"found_image_url": url} on a hit, or {"found_image_url": None} when
    the API key/query is missing, the request fails, or no photos come back.
    """
    print("---(Asset Tool) TOOL: Searching Pexels---")
    api_key = os.getenv("PEXELS_API_KEY")
    search_query = state.get("search_query")
    if not api_key:
        print("Warning: PEXELS_API_KEY not set. Skipping search.")
        return {"found_image_url": None}
    if not search_query:
        return {"found_image_url": None}

    headers = {"Authorization": api_key}
    params = {"query": search_query, "per_page": 1}
    try:
        response = requests.get("https://api.pexels.com/v1/search", headers=headers, params=params, timeout=10)
        response.raise_for_status()
        # Parse the body once (original parsed response.json() twice).
        payload = response.json()
        if payload.get('photos'):
            image_url = payload['photos'][0]['src']['original']
            print(f"Found a candidate image: {image_url}")
            return {"found_image_url": image_url}
    except requests.exceptions.RequestException as e:
        print(f"Pexels API Error: {e}")
    # Fall-through: request failed or no photos returned.
    return {"found_image_url": None}
274
+
275
+ ### --- Graph Builder from asset_tool.py ---
276
def build_asset_graph():
    """Wire up the asset subgraph: query prep -> Pexels search -> (resize | generate)."""
    graph = StateGraph(AssetGraphState)
    for node_name, node_fn in (
        ("prepare_search_query", asset_prepare_search_query_node),
        ("pexels_search", asset_pexels_search_node),
        ("generate_image", asset_generate_image_node),
        ("download_and_resize", asset_download_and_resize_node),
    ):
        graph.add_node(node_name, node_fn)
    graph.set_entry_point("prepare_search_query")
    graph.add_edge("prepare_search_query", "pexels_search")
    # Router picks "download_and_resize" on a search hit, "generate_image" otherwise.
    graph.add_conditional_edges("pexels_search", asset_route_after_search)
    graph.add_edge("generate_image", END)
    graph.add_edge("download_and_resize", END)
    return graph.compile()
288
+
289
+ # --- Compile the graph ---
290
+ asset_agent_app = build_asset_graph()
291
+
292
+ # ----------------------------------------------------------------------
293
+ # ----------------------------------------------------------------------
294
+ # ## SECTION 3: CODE EDITOR TOOL
295
+ #
296
+ # This is the self-contained graph for editing HTML.
297
+ # It will be used as a tool by the "Brain".
298
+ # ----------------------------------------------------------------------
299
+ # ----------------------------------------------------------------------
300
+
301
class CodeEditorState(TypedDict):
    """State for the single-node HTML code-editing graph."""
    html_code: str      # full HTML document being edited
    user_request: str   # pending edit instruction; cleared after a successful edit
    model_choice: Literal["gpt-4o-mini-2", "qwen-local"]  # which editor backend to use
    messages: list[str]  # accumulated status/error log lines
306
+
307
+ EDITOR_SYSTEM_PROMPT = """
308
+ You are an expert senior web developer specializing in HTML, CSS, and JavaScript.
309
+ Your task is to take an existing HTML file, a user's request for changes, and to output the *new, complete, and updated HTML file*.
310
+
311
+ **CRITICAL RULES:**
312
+ 1. **Output ONLY the Code:** Your entire response MUST be *only* the raw, updated HTML code.
313
+ 2. **No Conversation:** Do NOT include "Here is the updated code:", "I have made the following changes:", or any other explanatory text, comments, or markdown formatting.
314
+ 3. **Return the Full File:** Always return the complete HTML file, from `<!DOCTYPE html>` to `</html>`, incorporating the requested changes. Do not return just a snippet.
315
+ """
316
+
317
+ def _clean_llm_output(code: str) -> str:
318
+ """Removes common markdown formatting."""
319
+ code = code.strip()
320
+ if code.startswith("```html"):
321
+ code = code[7:]
322
+ if code.endswith("```"):
323
+ code = code[:-3]
324
+ return code.strip()
325
+
326
def _call_gpt_editor(html_code: str, user_request: str, model: str) -> str:
    """Ask an OpenAI chat model to apply *user_request* to *html_code*."""
    user_prompt = (
        f"**User Request:**\n{user_request}\n\n"
        f"**Original HTML Code:**\n```html\n{html_code}\n```\n\n"
        f"**Your updated HTML Code:**"
    )
    try:
        client = models.get_openai_client()
        response = client.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": EDITOR_SYSTEM_PROMPT},
                {"role": "user", "content": user_prompt},
            ],
            temperature=0.0,
            max_tokens=8192,
        )
        return _clean_llm_output(response.choices[0].message.content)
    except Exception as e:
        print(f"Error calling OpenAI API: {e}")
        # On any API failure, fall back to the (newline-prefixed) original markup.
        return f"\n{html_code}"
345
+
346
def _call_qwen_editor(html_code: str, user_request: str) -> str:
    """Ask the local Qwen VLM to apply *user_request* to *html_code*."""
    user_prompt = (
        f"**User Request:**\n{user_request}\n\n"
        f"**Original HTML Code:**\n```html\n{html_code}\n```\n\n"
        f"**Your updated HTML Code:**"
    )
    chat = [
        {"role": "system", "content": [{"type": "text", "text": EDITOR_SYSTEM_PROMPT}]},
        {"role": "user", "content": [{"type": "text", "text": user_prompt}]},
    ]
    try:
        raw = models.chat_vlm(chat, temperature=0.0, max_new_tokens=8192)
        return _clean_llm_output(raw)
    except Exception as e:
        print(f"Error calling local Qwen VLM: {e}")
        # On any failure, fall back to the (newline-prefixed) original markup.
        return f"\n{html_code}"
359
+
360
def node_edit_code(state: CodeEditorState) -> dict:
    """Single graph node: dispatch the edit request to GPT or local Qwen.

    Returns a partial state update; on success the user_request is cleared so a
    re-run of the graph does not re-apply the same edit.
    """
    print("---(Edit Tool) NODE: Edit Code---")
    html_code = state['html_code']
    user_request = state['user_request']
    model_choice = state['model_choice']
    messages = state.get('messages', [])

    if not user_request:
        return {"messages": messages + ["No user request provided. Skipping edit."]}

    try:
        # "gpt" anywhere in the choice selects the OpenAI path; anything else is Qwen.
        if "gpt" in model_choice.lower():
            new_html_code = _call_gpt_editor(html_code, user_request, model_choice)
        else:
            new_html_code = _call_qwen_editor(html_code, user_request)

        msg = f"Code edit complete using {model_choice}."
        print(msg)
        return {"html_code": new_html_code, "user_request": "", "messages": messages + [msg]}
    except Exception as e:
        error_msg = f"Error in code editing node: {e}"
        print(error_msg)
        # Keep the original HTML untouched on failure.
        return {"html_code": html_code, "messages": messages + [error_msg]}
381
+
382
def build_edit_graph():
    """Compile the one-node editing graph (checkpointed so threads can resume)."""
    graph = StateGraph(CodeEditorState)
    graph.add_node("edit_code", node_edit_code)
    graph.set_entry_point("edit_code")
    graph.add_edge("edit_code", END)
    return graph.compile(checkpointer=MemorySaver())
388
+
389
+ # --- Compile the graph ---
390
+ edit_agent_app = build_edit_graph()
391
+
392
+ # ----------------------------------------------------------------------
393
+ # ----------------------------------------------------------------------
394
+ # ## SECTION 4: AZURE VLM PIPELINE (RE-ORDERED)
395
+ #
396
+ # This pipeline is reordered to be much faster.
397
+ # 1. CodeGen runs FIRST, creating placeholders.
398
+ # 2. A fast regex parser finds the placeholders.
399
+ # 3. Asset search runs.
400
+ # 4. A patcher node inserts the asset paths.
401
+ # 5. Scoring & Refinement run as normal.
402
+ #
403
+ # This completely removes the slow local VLM call from this graph.
404
+ # ----------------------------------------------------------------------
405
+ # ----------------------------------------------------------------------
406
+
407
+ ## --- Helpers ---
408
+ _SCORE_KEYS = ["aesthetics","completeness","layout_fidelity","text_legibility","visual_balance"]
409
+
410
+ def _section(text: str, name: str) -> str:
411
+ pat = rf"{name}:\s*\n(.*?)(?=\n[A-Z_]+:\s*\n|\Z)"
412
+ m = re.search(pat, text, flags=re.S)
413
+ return m.group(1).strip() if m else ""
414
+
415
+ def _score_val(block: str, key: str, default: int = 0) -> int:
416
+ m = re.search(rf"{key}\s*:\s*(-?\d+)", block, flags=re.I)
417
+ try:
418
+ return int(m.group(1)) if m else default
419
+ except:
420
+ return default
421
+
422
def encode_image_to_data_url(path: str) -> str:
    """Read the file at *path* and return it as a base64 ``data:`` URL.

    Falls back to image/png when the MIME type cannot be guessed from the name.
    """
    mime = mimetypes.guess_type(path)[0] or "image/png"
    with open(path, "rb") as f:
        payload = base64.b64encode(f.read()).decode("utf-8")
    return f"data:{mime};base64,{payload}"
427
+
428
def extract_html(text: str) -> str:
    """Pull an HTML document out of an LLM reply.

    Preference order: a ```html fenced block, then everything from the first
    "<html" tag onward, else the whole reply stripped.
    """
    fenced = re.search(r"```html(.*?)```", text, flags=re.S | re.I)
    if fenced:
        return fenced.group(1).strip()
    start = text.lower().find("<html")
    if start == -1:
        return text.strip()
    return text[start:].strip()
434
+
435
+ ### --- Prompts ---
436
+ RELDESC_SYSTEM = "You are a meticulous UI analyst who describes layouts as a single dense paragraph of relative relationships."
437
+ RELDESC_PROMPT = """
438
+ From the provided wireframe image, produce ONE detailed paragraph (no bullets, no lists, no headings, no JSON, no code)
439
+ that states the RELATIVE layout and styling of the page so a code generator can rebuild it.
440
+
441
+ Requirements for the paragraph:
442
+ - Mention the overall background color and the dominant text color (infer reasonable #HEX).
443
+ - Describe the NAV BAR first: its position (top row), left/center/right alignment of brand/logo and items, the item order,
444
+ which (if any) is emphasized/active, approximate pill/underline treatment, and colors for default vs active.
445
+ - Describe the CONTENT in reading order by rows:
446
+ * For each row, say HOW MANY items appear side-by-side and their approximate relative widths (e.g., equal thirds, 1/3–2/3).
447
+ * For “cards” or boxes: state the title text (or short descriptor), the presence of body text, any CTA (button/link) labels,
448
+ button shape (rounded/square), fill/outline, and inferred #HEX colors for surface, border, and CTA.
449
+ * Call out approximate spacing (e.g., tight/medium/roomy), gutters/gaps (px if you can), and typical radii (px).
450
+ - Describe the FOOTER last: alignment (center/left/right), text size relative to body, and background/text colors.
451
+ - Include one sentence on typography: font family category (system sans/serif), approximate base size (px), and headings hierarchy.
452
+ - Keep everything in ONE paragraph. Do not use line breaks except normal wrapping.
453
+
454
+ Return ONLY that single paragraph.
455
+ """
456
+ BRIEF_SYSTEM = "You are a senior product designer who converts wireframes into precise UI design briefs."
457
+ BRIEF_PROMPT = """
458
+ Using the RELATIVE LAYL'OUT DESCRIPTION (authoritative for relative structure) and the wireframe image,
459
+ write a **UI DESIGN BRIEF** using EXACTLY these section headings, in this order. Be concise but specific. Infer reasonable hex colors.
460
+ If there is any conflict, prefer the wireframe image but keep structure consistent with the relative description.
461
+
462
+ ### UI_SUMMARY
463
+ One paragraph that states page purpose and the major regions.
464
+
465
+ ### COLOR_PALETTE
466
+ List 6–10 colors as `name: #HEX` including background, surface/card, text, muted text, primary,
467
+ secondary/accent, link, button-default, button-active.
468
+
469
+ ### TYPOGRAPHY
470
+ Font family (system stack), base font-size (px), title sizes (h1/h2/h3 in px), and weight rules.
471
+ Line-heights.
472
+
473
+ ### LAYOUT
474
+ Container max-width (px), global padding (px), section gaps (px), and the overall structure
475
+ (header/nav, content rows/columns, footer). State **how many items appear side-by-side** in each row
476
+ and at which breakpoint they stack.
477
+
478
+ ### NAVBAR
479
+ Exact order of items, which one is ACTIVE, and pill styling (padding, radius, default vs active
480
+ background/text colors).
481
+
482
+ ### CARDS
483
+ For each card in the content row: title text, body summary, CTA label and style (button/link),
484
+ card padding, radius, shadow, spacing between title/body/cta.
485
+
486
+ ### RESPONSIVE_RULES
487
+ Breakpoints (sm/md/lg in px) and what changes at each (column stack rules, spacing adjustments).
488
+
489
+ ### SPACING_AND_BORDERS
490
+ Numbers (px) for margins, gaps, radii used across elements.
491
+
492
+ Output ONLY the brief text with those headings (no code fences, no JSON).
493
+ """
494
+
495
+ # *** NEW UPDATED CODE_PROMPT ***
496
+ CODE_SYSTEM = "You are a meticulous frontend engineer who writes clean, modern, responsive HTML+CSS."
497
+ CODE_PROMPT = """
498
+ Using the following **RELATIVE LAYOUT DESCRIPTION** and **UI DESIGN BRIEF**, generate a SINGLE, self-contained HTML document:
499
+
500
+ Requirements:
501
+ - Semantic tags: header/nav/main/section/footer.
502
+ - One <style> block; no external CSS/JS.
503
+ - Define CSS variables from the palette and use them consistently.
504
+ - Implement the layout: container max-width, gaps, grid columns, and stacking rules per breakpoints.
505
+ - **CRITICAL ASSET RULE**: If you need an image (logo, hero, card image, etc.), you MUST use a placeholder in this **exact** format:
506
+ <img src="placeholder" data-asset-id="a-unique-id-for-this-image" data-asset-description="a detailed description for an image search engine">
507
+ (Example: <img src="placeholder" data-asset-id="hero-image" data-asset-description="photo of a modern office building">)
508
+ - **DO NOT** use the `ASSET_PATHS` variable, it will be empty.
509
+ - Output ONLY valid HTML starting with <html> and ending with </html>.
510
+ """
511
+
512
+ SCORING_RUBRIC = r"""
513
+ You are an experienced front-end engineer. Compare two images: (A) the original wireframe, and (B) the generated HTML rendering,
514
+ and read the HTML/CSS code used for (B).
515
+
516
+ Return a PLAIN-TEXT report with the following sections EXACTLY in this order
517
+ (no JSON, no code fences around the whole report):
518
+
519
+ SCORES:
520
+ aesthetics: <0-10>
521
+ completeness: <0-10>
522
+ layout_fidelity: <0-10> # be harsh; row alignment, relative sizes and positions must match A
523
+ text_legibility: <0-10>
524
+ visual_balance: <0-10>
525
+ aggregate: <float> # mean of the five scores
526
+
527
+ ISSUES_TOP3:
528
+ - <short, specific issue 1>
529
+ - <issue 2>
530
+ - <issue 3>
531
+
532
+ LAYOUT_DIFFS:
533
+ - component: <nav|grid|card[1]|card[2]|footer>
534
+ a_bbox_pct: [x,y,w,h] # approx percentages (0–100) of page width/height in A
535
+ b_bbox_pct: [x,y,w,h] # same for B
536
+ fix: <one sentence with exact px/cols/gaps>
537
+
538
+ CSS_PATCH:
539
+ ```css
540
+ /* <= 40 lines, use existing selectors where possible; use px and hex colors */
541
+ .selector { property: value; }
542
+ /* ... */
543
+ ```
544
+
545
+ HTML_EDITS:
546
+ - <one edit per line; selector + action, e.g., "add-class .card --class=wide":
547
+ - <allowed actions: move-before, move-after, insert-before, insert-after, set-attr, replace-text, add-class, remove-class>
548
+
549
+ REGENERATE_PROMPT:
550
+ <1–4 lines with exact grid, gaps (px), radii (px), hex colors, and font sizes to rebuild if needed>
551
+
552
+ FEEDBACK:
553
+ <one dense paragraph prioritizing layout_fidelity with exact px/cols/gaps/hex values>
554
+ """
555
+ REFINE_SYSTEM = "You are a senior frontend engineer who strictly applies critique to improve HTML/CSS while matching the wireframe."
556
+ REFINE_PROMPT = """
557
+ You are given:
558
+ 1) (A) the original wireframe image
559
+ 2) The CURRENT HTML (single-file) that produced (B) the rendering
560
+ 3) A critique ("feedback") produced by a rubric-based comparison of A vs B
561
+
562
+ Task:
563
+ - Produce a NEW single-file HTML that addresses EVERY feedback point while staying faithful to A.
564
+ - Fix layout fidelity (columns, spacing, alignment), completeness (ensure all components in A exist),
565
+ typography/contrast for legibility, and overall aesthetics and balance.
566
+ - Keep it self-contained (inline <style>; no external CSS/JS).
567
+ - Output ONLY valid HTML starting with <html> and ending with </html>.
568
+ """
569
+
570
@dataclass
class CodeRefineState:
    """Pipeline state for the wireframe -> HTML generate/score/refine loop."""
    # CLI inputs
    image_path: str          # wireframe image to reproduce
    out_rel_desc: str        # output path for the relative-layout description
    out_brief: str           # output path for the UI design brief
    out_html: str            # output path for the generated HTML
    vision_deployment: str   # Azure deployment name for image-capable calls
    text_deployment: str     # Azure deployment name for text-only calls
    reldesc_tokens: int      # max_tokens budget for the relative-description stage
    brief_tokens: int        # max_tokens budget for the brief stage
    code_tokens: int         # max_tokens budget for code generation
    judge_tokens: int        # max_tokens budget for the scoring judge
    temp: float              # sampling temperature for Azure calls
    refine_max_iters: int    # hard cap on refinement iterations
    refine_threshold: int    # aggregate score at which refinement stops
    shot_width: int          # presumably the screenshot viewport width — confirm against renderer
    shot_height: int         # presumably the screenshot viewport height — confirm against renderer

    # Runtime state
    image_data_url: Optional[str] = None   # wireframe encoded as a data: URL
    rel_desc: Optional[str] = None         # output of the relative-description stage
    brief: Optional[str] = None            # output of the brief stage
    html: Optional[str] = None             # current best HTML
    current_iteration: int = 0             # refinement iterations completed so far
    scores: Optional[Dict[str, Any]] = None  # parsed judge report (see parse_text_report)
    stop_refinement: bool = False          # set when threshold or iteration cap is hit

    asset_plan: List[Dict[str, Any]] = field(default_factory=list)  # parsed <img> placeholders
    asset_paths: Dict[str, str] = field(default_factory=dict)       # asset-id -> saved path

    messages: List[str] = field(default_factory=list)  # human-readable progress log
602
+
603
def parse_text_report(report: str) -> Dict[str, Any]:
    """Parse the judge's plain-text rubric report into a structured dict.

    Args:
        report: Raw model output containing the SCORES, CSS_PATCH, HTML_EDITS,
            REGENERATE_PROMPT, FEEDBACK, ISSUES_TOP3 and LAYOUT_DIFFS sections.

    Returns:
        Dict with per-key integer scores, a float ``aggregate``, the extracted
        patch/edit/feedback sections, and the raw report under ``raw``.
    """
    sb = _section(report, "SCORES")
    scores = {k: _score_val(sb, k, 0) for k in _SCORE_KEYS}
    m_agg = re.search(r"aggregate\s*:\s*([0-9]+(?:\.[0-9]+)?)", sb, flags=re.I)
    # Prefer the judge's own aggregate; otherwise fall back to the mean of the
    # parsed scores. Use len(scores) instead of a hard-coded 5 so the fallback
    # stays correct if _SCORE_KEYS ever changes (and never divides by zero).
    aggregate = float(m_agg.group(1)) if m_agg else sum(scores.values()) / max(len(scores), 1)
    css_patch = ""
    css_match = re.search(r"CSS_PATCH:\s*```css\s+(.*?)\s+```", report, flags=re.S | re.I)
    if css_match:
        css_patch = css_match.group(1).strip()
    html_edits = _section(report, "HTML_EDITS")
    regenerate_prompt = _section(report, "REGENERATE_PROMPT")
    feedback = _section(report, "FEEDBACK")
    issues = _section(report, "ISSUES_TOP3")
    layout_diffs = _section(report, "LAYOUT_DIFFS")
    return {
        "scores": scores, "aggregate": aggregate, "css_patch": css_patch, "html_edits": html_edits,
        "regenerate_prompt": regenerate_prompt, "feedback": feedback, "issues_top3": issues,
        "layout_diffs": layout_diffs, "raw": report,
    }
622
+
623
def refine_with_feedback(vision_deployment: str, wireframe_image: str, current_html: str, feedback: str,
                         css_patch: str = "", html_edits: str = "", regenerate_prompt: str = "",
                         temperature: float = 0.12, max_tokens: int = 2200) -> str:
    """Ask the vision model to rewrite the HTML so it addresses the judge's critique.

    Args:
        vision_deployment: Azure deployment name for the vision chat model.
        wireframe_image: Path to the original wireframe image (A).
        current_html: The single-file HTML that produced the last render (B).
        feedback: Dense critique paragraph from the judge report.
        css_patch: Optional exact CSS patch extracted from the report.
        html_edits: Optional structured HTML edit instructions from the report.
        regenerate_prompt: Optional compact layout spec (grid/gaps/colors/sizes).
        temperature: Sampling temperature for the refinement call.
        max_tokens: Token budget for the refined HTML.

    Returns:
        A complete single-file HTML document (wrapped in a minimal skeleton if
        the model returned only a fragment).

    Fix: the critique arguments were previously accepted but never included in
    the prompt (the instruction string was left as a "..." placeholder), so
    every refinement pass ran blind; they are now appended as explicit sections.
    """
    data_a = encode_image_to_data_url(wireframe_image)

    sections = [REFINE_PROMPT.strip()]
    if feedback and feedback.strip():
        sections.append("FEEDBACK (address every point):\n" + feedback.strip())
    if css_patch and css_patch.strip():
        sections.append("APPLY THIS CSS PATCH EXACTLY:\n```css\n" + css_patch.strip() + "\n```")
    if html_edits and html_edits.strip():
        sections.append("APPLY THESE HTML EDITS EXACTLY:\n" + html_edits.strip())
    if regenerate_prompt and regenerate_prompt.strip():
        sections.append("REGENERATE SPEC (authoritative layout values):\n" + regenerate_prompt.strip())
    refine_instructions = "\n\n".join(sections)

    messages = [
        {"role": "system", "content": REFINE_SYSTEM},
        {"role": "user", "content": [
            {"type": "image_url", "image_url": {"url": data_a}},
            {"type": "text", "text": refine_instructions + "\n\nCURRENT_HTML:\n```html\n" + current_html + "\n```"}
        ]},
    ]
    out = models.chat_complete_azure(vision_deployment, messages, temperature, max_tokens)
    html = extract_html(out)
    # Guarantee a full document even if the model emitted only a body fragment.
    if "<html" not in html.lower():
        html = f"<!DOCTYPE html>\n<html>\n<head><meta charset='utf-8'><title>Refined</title></head>\n<body>\n{html}\n</body>\n</html>"
    return html
640
+
641
def node_stage0(state: CodeRefineState) -> CodeRefineState:
    """Stage 0: describe the wireframe's relative layout and persist it to disk."""
    state.image_data_url = encode_image_to_data_url(state.image_path)
    user_parts = [
        {"type": "image_url", "image_url": {"url": state.image_data_url}},
        {"type": "text", "text": RELDESC_PROMPT.strip()},
    ]
    chat = [
        {"role": "system", "content": RELDESC_SYSTEM},
        {"role": "user", "content": user_parts},
    ]
    state.rel_desc = models.chat_complete_azure(
        state.vision_deployment, chat, state.temp, state.reldesc_tokens
    )
    out_path = pathlib.Path(state.out_rel_desc)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    # The description is stripped on disk; the in-memory copy keeps raw output.
    out_path.write_text(state.rel_desc.strip(), encoding="utf-8")
    state.messages.append("Stage0: Generated relative layout description.")
    return state
655
+
656
def node_stage1(state: CodeRefineState) -> CodeRefineState:
    """Stage 1: turn the wireframe plus relative description into a UI design brief."""
    prompt_text = BRIEF_PROMPT.strip() + "\n\nRELATIVE LAYOUT DESCRIPTION:\n" + state.rel_desc.strip()
    chat = [
        {"role": "system", "content": BRIEF_SYSTEM},
        {"role": "user", "content": [
            {"type": "image_url", "image_url": {"url": state.image_data_url}},
            {"type": "text", "text": prompt_text},
        ]},
    ]
    state.brief = models.chat_complete_azure(
        state.vision_deployment, chat, state.temp, state.brief_tokens
    )
    out_path = pathlib.Path(state.out_brief)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text(state.brief, encoding="utf-8")
    state.messages.append("Stage1: Generated UI design brief.")
    return state
669
+
670
def node_stage2(state: CodeRefineState) -> CodeRefineState:
    """Stage 2: generate the single-file HTML (with asset placeholders) from the brief."""
    text_parts = [
        CODE_PROMPT.strip(),
        "RELATIVE LAYOUT DESCRIPTION:\n" + state.rel_desc.strip(),
        "UI DESIGN BRIEF:\n" + state.brief.strip(),
        "ASSET_PATHS:\n{}",  # deliberately empty: assets are resolved and patched later
    ]
    chat = [
        {"role": "system", "content": CODE_SYSTEM},
        {"role": "user", "content": [{"type": "text", "text": t} for t in text_parts]},
    ]
    raw = models.chat_complete_azure(state.text_deployment, chat, state.temp, state.code_tokens)
    state.html = extract_html(raw)

    # The HTML is not saved yet: it still contains placeholder <img> tags.
    state.messages.append("Stage2: Generated HTML (with placeholders).")
    return state
686
+
687
def node_plan_assets_from_html(state: CodeRefineState) -> CodeRefineState:
    """Build the asset plan by scanning the generated HTML for placeholder tags.

    Fast, deterministic replacement for the old VLM-based asset planner: each
    <img ... data-asset-id="..." data-asset-description="..."> becomes one plan
    entry with a default 512x512 bounding box.
    """
    print("---(Azure VLM) NODE: Planning assets from HTML placeholders---")
    # Matches: <img ... data-asset-id="..." data-asset-description="...">
    pattern = r'<img[^>]+data-asset-id="([^"]+)"[^>]+data-asset-description="([^"]+)"'

    state.asset_plan = [
        {
            "component_id": cid,
            "description": desc,
            "bounding_box": {"width": 512, "height": 512},  # default size
        }
        for cid, desc in re.findall(pattern, state.html, re.I)
    ]

    state.messages.append(f"Stage2.5: Planned {len(state.asset_plan)} assets from HTML.")
    print(f"Asset plan: {state.asset_plan}")
    return state
710
+
711
def node_stage1_find_assets(state: CodeRefineState) -> CodeRefineState:
    """Resolve each planned asset by invoking the asset agent; record found paths."""
    print("---(Azure VLM) NODE: Finding Assets---")
    if not state.asset_plan:
        state.messages.append("Stage2.6: No assets to find.")
        return state

    resolved: Dict[str, str] = {}
    for request in state.asset_plan:
        cid = request.get('component_id')
        description = request.get('description')
        box = request.get('bounding_box', {})
        # Skip malformed plan entries rather than failing the whole stage.
        if not all([cid, description, box]):
            continue

        print(f"-> Finding asset for '{cid}': {description}")
        try:
            w = int(box.get('width', 512))
            h = int(box.get('height', 512))
        except (ValueError, TypeError):
            w, h = 512, 512

        # Delegate search/generation to the dedicated asset agent graph.
        outcome = asset_agent_app.invoke({"instructions": description, "bounding_box": (w, h)})

        resolved_path = outcome.get("final_asset_path")
        if resolved_path:
            resolved[cid] = resolved_path
            note = f"Asset resolved for {cid}: {resolved_path}"
            state.messages.append(note)
            print(f" ✅ {note}")
        else:
            note = f"Asset process failed for {cid}."
            state.messages.append(note)
            print(f" ❌ {note}")

    state.asset_paths = resolved
    return state
744
+
745
# Patch HTML with found assets
def node_patch_html_with_assets(state: CodeRefineState) -> CodeRefineState:
    """Swap each placeholder ``src`` in the HTML for its resolved asset path, then save."""
    print("---(Azure VLM) NODE: Patching HTML with assets---")
    out_file = pathlib.Path(state.out_html)
    if not state.asset_paths:
        state.messages.append("Stage2.7: No assets to patch.")
        # Nothing to substitute: persist the HTML exactly as generated.
        out_file.parent.mkdir(parents=True, exist_ok=True)
        with open(state.out_html, "w", encoding="utf-8") as f:
            f.write(state.html)
        state.messages.append(f"Saved un-patched HTML -> {state.out_html}")
        return state

    patched = state.html
    for cid, asset_path in state.asset_paths.items():
        # Locate <img ... data-asset-id="cid" ... src="placeholder" and rewrite
        # only the src attribute, leaving the rest of the tag untouched.
        pattern = rf'(<img[^>]+data-asset-id="{re.escape(cid)}"[^>]+)src="placeholder"'
        patched, hits = re.subn(pattern, rf'\1src="{asset_path}"', patched, flags=re.I)
        if hits:
            state.messages.append(f"Patched {cid} -> {asset_path}")
        else:
            state.messages.append(f"Warning: Could not find placeholder for {cid} to patch.")

    state.html = patched
    # Save the *patched* HTML
    out_file.parent.mkdir(parents=True, exist_ok=True)
    with open(state.out_html, "w", encoding="utf-8") as f:
        f.write(state.html)
    state.messages.append(f"Stage2.7: Saved patched HTML -> {state.out_html}")
    return state
780
+
781
def node_stage3_score(state: CodeRefineState) -> CodeRefineState:
    """Stage 3: render the saved HTML in headless Chromium, screenshot it, and
    have the vision judge score the render (B) against the wireframe (A).

    Side effects: writes a per-iteration PNG next to the HTML file, fills
    ``state.scores`` with the parsed rubric report, and sets
    ``state.stop_refinement`` once every rubric dimension reaches the threshold.
    """
    html_path = pathlib.Path(state.out_html)
    # One screenshot per iteration so earlier renders remain inspectable.
    shot_png_path = html_path.with_name(html_path.stem + f"_iter{state.current_iteration}.png")
    with sync_playwright() as p:
        browser = p.chromium.launch()
        # device_scale_factor=2.0 renders at 2x resolution for a sharper judge input.
        ctx = browser.new_context(viewport={"width": state.shot_width, "height": state.shot_height}, device_scale_factor=2.0)
        page = ctx.new_page()
        page.goto(pathlib.Path(state.out_html).resolve().as_uri())
        # Wait for fonts/images/network to settle before capturing.
        page.wait_for_load_state("networkidle")
        page.screenshot(path=shot_png_path, full_page=True)
        ctx.close()
        browser.close()

    data_a = encode_image_to_data_url(state.image_path)
    data_b = encode_image_to_data_url(shot_png_path)
    # Judge input: rubric + both images + the HTML that produced image B.
    messages = [
        {"role": "system", "content": "Return the specified PLAIN-TEXT report exactly as instructed."},
        {"role": "user", "content": [
            {"type": "text", "text": SCORING_RUBRIC.strip()},
            {"type": "image_url", "image_url":{"url": data_a}},
            {"type": "image_url", "image_url":{"url": data_b}},
            {"type": "text", "text": "HTML/CSS code used to produce image (B):\n" + state.html}
        ]},
    ]
    # Temperature 0.0: judging should be as deterministic as possible.
    resp = models.chat_complete_azure(state.vision_deployment, messages, 0.0, state.judge_tokens)
    state.scores = parse_text_report(resp)
    state.messages.append(f"Stage3: Scoring done (Iter {state.current_iteration}).")

    # Stop only when the WORST rubric dimension clears the threshold.
    min_score = min(int(state.scores["scores"][k]) for k in _SCORE_KEYS)
    if min_score >= state.refine_threshold:
        state.stop_refinement = True
    return state
813
+
814
def node_refine_loop(state: CodeRefineState) -> CodeRefineState:
    """Run one refinement pass: feed the judge's critique back into the model,
    save the result as a versioned file, and point ``out_html`` at it so the
    next scoring pass renders the newest version.

    Fix: the versioned name is derived from the base stem with any previous
    ``_vN`` suffix stripped, so iteration 2 yields ``page_v2.html`` instead of
    the compounding ``page_v1_v2.html`` the old code produced (``out_html`` is
    re-pointed at the versioned file on every iteration).
    """
    if state.stop_refinement or state.current_iteration >= state.refine_max_iters:
        state.messages.append("Refinement loop ended.")
        return state

    state.current_iteration += 1
    state.html = refine_with_feedback(
        vision_deployment=state.vision_deployment,
        wireframe_image=state.image_path,
        current_html=state.html,
        feedback=state.scores.get("feedback", ""),
        css_patch=state.scores.get("css_patch", ""),
        html_edits=state.scores.get("html_edits", ""),
        regenerate_prompt=state.scores.get("regenerate_prompt", ""),
        temperature=state.temp,
        max_tokens=state.code_tokens
    )
    base = pathlib.Path(state.out_html)
    # Strip a previous _vN suffix so version numbers never accumulate.
    clean_stem = re.sub(r"_v\d+$", "", base.stem)
    versioned_path = base.with_name(f"{clean_stem}_v{state.current_iteration}{base.suffix}")
    with open(versioned_path, "w", encoding="utf-8") as f:
        f.write(state.html)
    state.out_html = str(versioned_path)  # Update state to use new path for next scoring
    state.messages.append(f"Saved refined HTML v{state.current_iteration} -> {versioned_path}")
    return state
836
+
837
def decide_next(state: CodeRefineState) -> str:
    """Conditional edge: keep refining until told to stop or iterations run out."""
    keep_going = (not state.stop_refinement) and state.current_iteration < state.refine_max_iters
    return "refine_loop" if keep_going else "end"
841
+
842
def decide_if_assets_found(state: CodeRefineState) -> str:
    """Conditional edge: search for assets when the plan is non-empty,
    otherwise skip straight to the (no-op) patching step."""
    if not state.asset_plan:
        return "patch_html"
    return "find_assets"
849
+
850
def build_azure_vlm_graph():
    """Assemble and compile the wireframe -> HTML pipeline graph.

    Flow: stage0 -> stage1 -> stage2 -> plan assets -> (optional asset search)
    -> patch HTML -> score -> refine loop (back to scoring until done).
    """
    g = StateGraph(CodeRefineState)
    for node_name, node_fn in [
        ("stage0", node_stage0),
        ("stage1", node_stage1),
        ("stage2", node_stage2),
        ("plan_assets_from_html", node_plan_assets_from_html),
        ("stage1_find_assets", node_stage1_find_assets),
        ("patch_html", node_patch_html_with_assets),
        ("stage3_score", node_stage3_score),
        ("refine_loop", node_refine_loop),
    ]:
        g.add_node(node_name, node_fn)

    g.set_entry_point("stage0")
    g.add_edge("stage0", "stage1")
    g.add_edge("stage1", "stage2")
    g.add_edge("stage2", "plan_assets_from_html")

    # Skip asset search entirely when the HTML declared no placeholders;
    # patch_html itself no-ops in that case.
    g.add_conditional_edges(
        "plan_assets_from_html",
        decide_if_assets_found,
        {"find_assets": "stage1_find_assets", "patch_html": "patch_html"},
    )

    g.add_edge("stage1_find_assets", "patch_html")
    g.add_edge("patch_html", "stage3_score")

    # Score -> refine loop, cycling back to scoring until decide_next says "end".
    g.add_edge("stage3_score", "refine_loop")
    g.add_conditional_edges("refine_loop", decide_next, {"refine_loop": "stage3_score", "end": END})

    return g.compile(checkpointer=MemorySaver())


# --- Compile the graph ---
azure_vlm_app = build_azure_vlm_graph()
886
+
887
+ # ----------------------------------------------------------------------
888
+ # ----------------------------------------------------------------------
889
+ # ## SECTION 5: MAIN "BRAIN" AGENT (Command Center)
890
+ #
891
+ # This new agent uses the local Qwen VLM to decide which
892
+ # pipeline to run using standard conditional routing.
893
+ # ----------------------------------------------------------------------
894
+ # ----------------------------------------------------------------------
895
+
896
class _BrainStateRouting(TypedDict, total=False):
    """Routing keys that are absent until the router node populates them.

    A TypedDict cannot carry runtime defaults — the original ``= None``
    assignments on these fields were never applied to instances — so they are
    modeled as optional (``total=False``) keys instead.
    """
    next_task: Optional[str]            # name of the helper to run, or "end"
    task_args: Optional[Dict[str, Any]] # kwargs for the chosen helper
    task_result: Optional[str]          # JSON string returned by the helper


class BrainState(_BrainStateRouting):
    """State carried through the brain graph: chat history plus CLI context."""
    messages: List[Dict[str, Any]]  # chat transcript as {"role", "content"} dicts
    cli_args: argparse.Namespace    # parsed CLI context (prompt, image/html paths)
904
+
905
+
906
+ # --- Pipeline Functions ---
907
def helper_run_azure_vlm_pipeline(image_path: str) -> str:
    """
    Use this tool to generate a new HTML webpage from a wireframe image.
    This runs the full Azure VLM pipeline, including asset finding and refinement.

    Args:
        image_path (str): The file path to the input wireframe image.

    Returns:
        str: JSON string with "status", "message", "final_html_path" and the
        pipeline's progress log on success, or an error payload on failure.
    """
    print(f"--- BRAIN: Invoking Azure VLM Pipeline for {image_path} ---")
    try:
        # --- Define defaults INSIDE the function ---
        default_out_html = "Outputs/default_vlm_output.html"
        default_out_brief = "Outputs/default_vlm_brief.txt"
        default_out_reldesc = "Outputs/default_vlm_reldesc.txt"

        # --- Hardcode model choices ---
        vision_deployment = "gpt-4.1-mini"  # Using default
        text_deployment = "gpt-4.1-mini"  # Using default

        pathlib.Path(default_out_html).parent.mkdir(parents=True, exist_ok=True)

        state = CodeRefineState(
            image_path=image_path,
            out_rel_desc=default_out_reldesc,
            out_brief=default_out_brief,
            out_html=default_out_html,
            vision_deployment=vision_deployment,  # Hardcoded
            text_deployment=text_deployment,  # Hardcoded
            # Hardcode pipeline defaults
            reldesc_tokens=700,
            brief_tokens=1100,
            code_tokens=2200,
            judge_tokens=900,
            temp=0.12,
            refine_max_iters=3,
            refine_threshold=8,
            shot_width=1536,
            shot_height=900
        )

        # Fresh thread id per run so the MemorySaver checkpointer never
        # resumes a previous run's state.
        run_id = f"wireframe-{uuid4()}"
        config = {"configurable": {"thread_id": run_id}}
        result = azure_vlm_app.invoke(state, config=config)

        # out_html is re-pointed at the latest versioned file during
        # refinement, so this is the newest HTML, not the default path.
        final_path = result.get('out_html', default_out_html)
        return json.dumps({
            "status": "success",
            "message": "Azure VLM pipeline completed.",
            "final_html_path": final_path,
            "messages": result.get("messages", [])
        })
    except Exception as e:
        print(f"Error in Azure VLM helper: {e}")
        return json.dumps({"status": "error", "message": str(e)})
961
+
962
def helper_run_code_editor(html_path: str, edit_request: str) -> str:
    """
    Use this tool to edit an existing HTML file based on a user's text request.

    Args:
        html_path (str): The file path to the HTML file you want to edit.
        edit_request (str): The user's instruction (e.g., "Make the h1 tag blue").

    Returns:
        str: JSON string with "status", "message", "final_html_path" (the
        edited copy, written next to the original with an ``_edited`` suffix)
        and the edit agent's log, or an error payload on failure.
    """
    print(f"--- BRAIN: Invoking Code Editor for {html_path} ---")
    try:
        # --- Hardcode model choice ---
        model_choice = "qwen-local"

        # Never overwrite the input file: write to <stem>_edited.html.
        p = pathlib.Path(html_path)
        output_path = str(p.parent / f"{p.stem}_edited.html")

        with open(html_path, "r", encoding="utf-8") as f:
            original_html = f.read()

        initial_state = {
            "html_code": original_html,
            "user_request": edit_request,
            "model_choice": model_choice,  # Hardcoded
            "messages": []
        }

        # Fresh thread id per edit so checkpointed state is never reused.
        config = {"configurable": {"thread_id": f"edit-thread-{uuid4()}"}}
        final_state = edit_agent_app.invoke(initial_state, config=config)

        new_html_code = final_state['html_code']

        pathlib.Path(output_path).parent.mkdir(parents=True, exist_ok=True)
        with open(output_path, "w", encoding="utf-8") as f:
            f.write(new_html_code)

        return json.dumps({
            "status": "success",
            "message": "Code editing complete.",
            "final_html_path": output_path,
            "messages": final_state.get("messages", [])
        })
    except Exception as e:
        print(f"Error in Code Editor helper: {e}")
        return json.dumps({"status": "error", "message": str(e)})
1005
+
1006
def helper_run_asset_search(description: str, width: int = 512, height: int = 512) -> str:
    """
    Use this tool to find or generate a single image asset.

    Args:
        description (str): A detailed description of the desired image.
        width (int): Target width in pixels (default 512).
        height (int): Target height in pixels (default 512).

    Returns:
        str: JSON string with "status", "message" and "asset_path" on success,
        or an error payload when the asset agent fails.
    """
    print(f"--- BRAIN: Invoking Asset Search for '{description}' ---")
    try:
        outcome = asset_agent_app.invoke({"instructions": description, "bounding_box": (width, height)})

        asset_path = outcome.get("final_asset_path")
        if not asset_path:
            return json.dumps({"status": "error", "message": "Asset process failed."})
        return json.dumps({
            "status": "success",
            "message": "Asset found/generated.",
            "asset_path": asset_path
        })
    except Exception as e:
        print(f"Error in Asset Search helper: {e}")
        return json.dumps({"status": "error", "message": str(e)})
1026
+
1027
# --- List of all helper functions for the Brain ---
# Registered with QwenRouterAgent, which keys them by __name__ so the LLM's
# chosen function name can be validated against real callables.
helper_functions = [
    helper_run_azure_vlm_pipeline,
    helper_run_code_editor,
    helper_run_asset_search,
]
1033
+
1034
# --- Brain Agent Definition (Router) ---
class QwenRouterAgent:
    """Routes a user request to one registered helper via a JSON tool-call prompt.

    The LLM is asked to reply with a single JSON object naming the helper to
    run and its arguments; missing arguments are back-filled from the CLI
    context before the routing decision is returned to the graph.

    Fix: the original built a ``qwen_messages`` chat list from the state but
    never sent it (only ``router_prompt`` is passed to ``chat_llm``); that dead
    code — which also silently dropped ``system_prompt`` — has been removed.
    """

    def __init__(self, model_manager, functions, system_prompt=""):
        self.model = model_manager
        # Map function name -> callable so the LLM's choice can be validated.
        self.functions = {f.__name__: f for f in functions}
        # Kept for API compatibility; the router sends a self-contained prompt.
        self.system_prompt = system_prompt

    def __call__(self, state: BrainState):
        messages = state['messages']
        last_user_message = messages[-1]['content']

        # *** UPDATED PROMPT ***
        router_prompt = f"""
You are a "command center" agent. Your job is to route a user's request to the correct function
by providing a single, valid JSON object.

**Function Schemas:**

1. **Generate a new page from an image:**
{{
  "function_name": "helper_run_azure_vlm_pipeline",
  "function_args": {{
    "image_path": "<string: path to the input image>"
  }}
}}

2. **Edit an existing HTML file:**
{{
  "function_name": "helper_run_code_editor",
  "function_args": {{
    "html_path": "<string: path to the HTML file to edit>",
    "edit_request": "<string: the user's full instruction for changes>"
  }}
}}

3. **Find or generate a single image asset:**
{{
  "function_name": "helper_run_asset_search",
  "function_args": {{
    "description": "<string: a detailed description of the image>",
    "width": "<int, default: 512>",
    "height": "<int, default: 512>"
  }}
}}

4. **No function needed:**
{{
  "function_name": "end",
  "function_args": {{}}
}}

**Instructions:**
1. Analyze the "User Request" and "Context".
2. Choose the *one* function from the schemas above that best matches the request.
3. **CRITICAL RULE: `helper_run_azure_vlm_pipeline` *already includes* asset searching. If the user asks to "generate a page" or "build a wireframe", you MUST choose *only* that function. DO NOT choose `helper_run_asset_search` separately.**
4. Populate the `function_args` with the *exact argument names* from the schema.
5. Respond with ONLY the valid JSON object for your chosen function.

---
**User Request:** "{last_user_message}"

**Context:**
"""
        cli_args = state['cli_args']
        if cli_args.image:
            router_prompt += f"- An image path was provided: {cli_args.image}\n"
        if cli_args.html:
            router_prompt += f"- An HTML path was provided: {cli_args.html}\n"
        router_prompt += "\n**Your JSON Response:**"

        print("--- BRAIN: Routing prompt ---")

        vlm_response = self.model.chat_llm(router_prompt)

        print(f"--- BRAIN: VLM Response ---\n{vlm_response}\n-------------------------")

        try:
            # Extract the outermost {...} span: local models often wrap the
            # JSON object in prose or code fences.
            call_json = json.loads(vlm_response[vlm_response.find("{"):vlm_response.rfind("}") + 1])
            func_name = call_json.get("function_name")
            func_args = call_json.get("function_args", {})

            if func_name == "end" or func_name not in self.functions:
                print("--- BRAIN: No operation selected. Ending task. ---")
                return {"next_task": "end", "task_result": "No operation selected."}

            # Back-fill any missing argument values from the CLI context.
            if func_name == "helper_run_azure_vlm_pipeline":
                func_args['image_path'] = func_args.get('image_path', cli_args.image)

            elif func_name == "helper_run_code_editor":
                func_args['html_path'] = func_args.get('html_path', cli_args.html)
                func_args['edit_request'] = func_args.get('edit_request', last_user_message)

            elif func_name == "helper_run_asset_search":
                func_args['description'] = func_args.get('description', last_user_message)
                func_args['width'] = func_args.get('width', 512)
                func_args['height'] = func_args.get('height', 512)

            return {
                "next_task": func_name,
                "task_args": func_args
            }

        except Exception as e:
            print(f"--- BRAIN: Error parsing VLM response. Ending task. --- \n{e}")
            return {"next_task": "end", "task_result": f"Error parsing VLM response: {e}"}
1152
+
1153
+ # --- Graph Nodes ---
1154
def node_run_vlm_pipeline(state: BrainState) -> dict:
    """Graph node: run the Azure VLM pipeline helper and log its JSON result."""
    print("---(Brain Graph) NODE: node_run_vlm_pipeline ---")
    outcome = helper_run_azure_vlm_pipeline(**state['task_args'])
    updated_log = state['messages'] + [{"role": "assistant", "content": outcome}]
    return {"task_result": outcome, "messages": updated_log}
1159
+
1160
def node_run_code_editor(state: BrainState) -> dict:
    """Graph node: run the code-editor helper and log its JSON result."""
    print("---(Brain Graph) NODE: node_run_code_editor ---")
    outcome = helper_run_code_editor(**state['task_args'])
    updated_log = state['messages'] + [{"role": "assistant", "content": outcome}]
    return {"task_result": outcome, "messages": updated_log}
1165
+
1166
def node_run_asset_search(state: BrainState) -> dict:
    """Graph node: run the asset-search helper and log its JSON result."""
    print("---(Brain Graph) NODE: node_run_asset_search ---")
    outcome = helper_run_asset_search(**state['task_args'])
    updated_log = state['messages'] + [{"role": "assistant", "content": outcome}]
    return {"task_result": outcome, "messages": updated_log}
1171
+
1172
+ # --- Router Function ---
1173
def brain_router(state: BrainState) -> str:
    """Routes to the correct node based on the 'next_task' state."""
    print(f"---(Brain Graph) ROUTER: Next task is '{state['next_task']}' ---")
    routes = {
        "helper_run_azure_vlm_pipeline": "run_vlm",
        "helper_run_code_editor": "run_edit",
        "helper_run_asset_search": "run_asset",
    }
    # Anything unrecognized (including "end") terminates the graph.
    return routes.get(state['next_task'], "end")
1184
+
1185
+ # --- Graph Builder ---
1186
def build_brain_graph():
    """Assemble and compile the brain graph: router agent -> one worker node -> END."""
    router_agent = QwenRouterAgent(models, helper_functions)

    workflow = StateGraph(BrainState)

    # Add nodes: the routing agent plus one worker per pipeline.
    workflow.add_node("agent", router_agent)
    for worker_name, worker_fn in [
        ("run_vlm", node_run_vlm_pipeline),
        ("run_edit", node_run_code_editor),
        ("run_asset", node_run_asset_search),
    ]:
        workflow.add_node(worker_name, worker_fn)

    workflow.set_entry_point("agent")

    # The router's choice selects exactly one worker (or terminates).
    workflow.add_conditional_edges(
        "agent",
        brain_router,
        {
            "run_vlm": "run_vlm",
            "run_edit": "run_edit",
            "run_asset": "run_asset",
            "end": END,
        },
    )

    # Every worker ends the run after producing its result.
    for worker_name in ("run_vlm", "run_edit", "run_asset"):
        workflow.add_edge(worker_name, END)

    return workflow.compile()


# --- Compile the brain graph ---
brain_app = build_brain_graph()
1221
+
1222
+ # ----------------------------------------------------------------------
1223
+ # ----------------------------------------------------------------------
1224
+ # ## SECTION 6: CLI RUNNER (MAIN) - SIMPLIFIED FOR TESTING
1225
+ #
1226
+ # This entrypoint is modified to run a full test suite.
1227
+ # ----------------------------------------------------------------------
1228
+ # ----------------------------------------------------------------------
1229
+
1230
def run_test(test_name: str, initial_state: BrainState):
    """Helper function to run a single test case against the brain."""
    print("\n" + "=" * 70)
    print(f"--- STARTING TEST: {test_name} ---")
    print(f"--- User Prompt: {initial_state['messages'][0]['content']} ---")

    # Fresh thread id per test so no checkpointed state leaks between runs.
    config = {"configurable": {"thread_id": f"brain-test-{uuid4()}"}}

    # Invoke the brain
    final_state = brain_app.invoke(initial_state, config=config)

    print("\n" + ("-" * 25) + " BRAIN INVOCATION COMPLETE " + ("-" * 25))

    print("--- Final Task Result ---")
    task_result_str = final_state.get('task_result', "No result found (task may have ended early).")

    if not task_result_str:
        print("No final task result recorded.")
    else:
        # Helpers return JSON strings; pretty-print when possible.
        try:
            print(json.dumps(json.loads(task_result_str), indent=2))
        except json.JSONDecodeError:
            print(task_result_str)

    print("=" * 70)
1257
+
1258
+
1259
def main():
    """
    Main function modified to run a test suite for all 3 pipeline paths.

    Each test builds an argparse.Namespace that mimics real CLI input and
    routes it through the brain graph; image-dependent tests are skipped
    when their fixture files are missing.
    """

    # --- Test Data Setup (Simplified) ---

    # For Test 1: Azure VLM Pipeline
    cli_args_vlm = argparse.Namespace(
        prompt="Generate a new HTML page from the wireframe at Images/2.png",
        image="Images/2.png",
        html=None
    )

    # For Test 1.5: Azure internal asset search
    asset_test_image_path = "Images/asset_test.png"
    cli_args_vlm_assets = argparse.Namespace(
        prompt=f"Generate the page from {asset_test_image_path}, and find all image assets.",
        image=asset_test_image_path,
        html=None
    )

    # For Test 2: Code Editor Pipeline
    # NOTE(review): this file is created by the __main__ block before main() runs.
    test_edit_file = "Outputs/test_page_to_edit.html"
    cli_args_edit = argparse.Namespace(
        prompt="Change the title to 'Edited by Brain Agent' and make the h1 tag red.",
        image=None,
        html=test_edit_file
    )

    # For Test 3: Asset Search Pipeline
    cli_args_asset = argparse.Namespace(
        prompt="Find a high-quality photo of a 'modern office desk with a laptop'",
        image=None,
        html=None
    )

    # =================================================================
    # --- TEST 1: AZURE VLM PIPELINE (Image-to-Code) ---
    # =================================================================
    initial_state_vlm = {
        "messages": [{"role": "user", "content": cli_args_vlm.prompt}],
        "cli_args": cli_args_vlm
    }
    if not pathlib.Path(cli_args_vlm.image).exists():
        print(f"--- WARNING: Skipping Test 1 ---")
        print(f"Test image not found at: {cli_args_vlm.image}")
    else:
        run_test("Test 1: Azure VLM Pipeline (Image-to-Code)", initial_state_vlm)

    # =================================================================
    # --- TEST 1.5: AZURE VLM PIPELINE (with Asset Search) ---
    # =================================================================
    initial_state_vlm_assets = {
        "messages": [{"role": "user", "content": cli_args_vlm_assets.prompt}],
        "cli_args": cli_args_vlm_assets
    }
    if not pathlib.Path(cli_args_vlm_assets.image).exists():
        print(f"\n--- WARNING: Skipping Test 1.5 ---")
        print(f"Asset test image not found at: {cli_args_vlm_assets.image}")
        print("This is the specific test for the internal asset-search pipeline.")
    else:
        run_test("Test 1.5: Azure VLM Pipeline (with Asset Search)", initial_state_vlm_assets)

    # =================================================================
    # --- TEST 2: CODE EDITOR PIPELINE (Edit HTML) ---
    # =================================================================
    initial_state_edit = {
        "messages": [{"role": "user", "content": cli_args_edit.prompt}],
        "cli_args": cli_args_edit
    }
    run_test("Test 2: Code Editor Pipeline (qwen-local)", initial_state_edit)

    # =================================================================
    # --- TEST 3: ASSET SEARCH PIPELINE (Find Image) ---
    # =================================================================
    initial_state_asset = {
        "messages": [{"role": "user", "content": cli_args_asset.prompt}],
        "cli_args": cli_args_asset
    }
    run_test("Test 3: Asset Search Pipeline", initial_state_asset)
1340
+
1341
+
1342
if __name__ == "__main__":
    # Ensure output directories exist
    pathlib.Path("Outputs/Assets").mkdir(parents=True, exist_ok=True)
    pathlib.Path("Outputs/").mkdir(parents=True, exist_ok=True)

    # --- Create a dummy HTML file for Test 2 ---
    # Written fresh on every run so Test 2 always starts from a known document.
    test_edit_file_path = "Outputs/test_page_to_edit.html"
    dummy_html = """
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <title>Original Test Title</title>
</head>
<body>
    <h1>This is the original headline.</h1>
    <p>This is a paragraph.</p>
</body>
</html>
"""
    with open(test_edit_file_path, "w", encoding="utf-8") as f:
        f.write(dummy_html)
    print(f"Created dummy file for editing at: {test_edit_file_path}")

    # Check for test image for Test 1
    if not pathlib.Path("Images/2.png").exists():
        print("\n--- WARNING ---")
        print("Test 1 (Azure VLM) requires an image at 'Images/2.png'.")
        print("Please add an image there or Test 1 will be skipped.")
        print("---------------\n")

    # Check for Test 1.5 image
    if not pathlib.Path("Images/asset_test.png").exists():
        print("\n--- WARNING ---")
        print("Test 1.5 (Azure VLM + Asset Search) requires an image at 'Images/asset_test.png'.")
        print("This test *specifically* verifies the internal asset search.")
        print("Please add an image with clear image placeholders, or Test 1.5 will be skipped.")
        print("---------------\n")

    main()
app (1).py ADDED
@@ -0,0 +1,244 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Production FastAPI Backend for Image2Code
3
+ Wraps the VLM agent into a REST API for HuggingFace Spaces deployment
4
+ """
5
+ from fastapi import FastAPI, File, UploadFile, Form, HTTPException
6
+ from fastapi.middleware.cors import CORSMiddleware
7
+ from fastapi.responses import JSONResponse
8
+ import uvicorn
9
+ import tempfile
10
+ import os
11
+ from pathlib import Path
12
+ from typing import Optional
13
+ from uuid import uuid4
14
+ import sys
15
+ import shutil
16
+
17
+ # Import agent components
18
+ from agent_azure_vlm_tools import ModelManager, azure_vlm_app, CodeRefineState
19
+
20
+ app = FastAPI(
21
+ title="Image2Code VLM API",
22
+ description="Backend API for Image2Code using Qwen2.5-VL-7B-Instruct",
23
+ version="1.0.0"
24
+ )
25
+
26
+ # CORS configuration for production
27
+ app.add_middleware(
28
+ CORSMiddleware,
29
+ allow_origins=[
30
+ "http://localhost:5173", # Local dev
31
+ "http://localhost:3000",
32
+ "https://*.vercel.app", # Vercel deployments
33
+ # Add your specific production domain here
34
+ ],
35
+ allow_credentials=True,
36
+ allow_methods=["*"],
37
+ allow_headers=["*"],
38
+ )
39
+
40
+ # Global model manager (singleton)
41
+ models = None
42
+
43
+
44
# NOTE(review): @app.on_event is deprecated in recent FastAPI releases in
# favor of lifespan handlers — consider migrating when touching app setup.
@app.on_event("startup")
async def startup_event():
    """Load models on server startup"""
    # `models` is the module-level singleton read by every endpoint; it stays
    # None (and endpoints 503) until this coroutine completes.
    global models
    print("🚀 Starting Image2Code VLM API...")
    print("📦 Loading Qwen2.5-VL-7B-Instruct model... This may take 2-3 minutes.")
    # ModelManager comes from agent_azure_vlm_tools; presumably this is where
    # the VLM weights are downloaded/loaded — TODO confirm against that module.
    models = ModelManager()
    print("✅ Models loaded successfully!")
52
+
53
+
54
@app.get("/")
async def root():
    """Liveness probe: static service metadata plus a model-readiness flag."""
    payload = {
        "status": "healthy",
        "service": "Image2Code VLM API",
        "model": "Qwen2.5-VL-7B-Instruct",
        "version": "1.0.0",
    }
    # `models` is assigned by the startup hook; None means still loading.
    payload["ready"] = models is not None
    return payload
64
+
65
+
66
@app.get("/health")
async def health_check():
    """Detailed health check.

    Returns:
        dict with overall status, whether the model singleton has finished
        loading, and whether a CUDA GPU is actually visible.
    """
    # FIX: gpu_available was hard-coded to True. Probe CUDA for real; a local
    # import keeps this endpoint working even if torch is absent.
    try:
        import torch
        gpu_available = torch.cuda.is_available()
    except ImportError:
        gpu_available = False
    return {
        "status": "healthy" if models is not None else "starting",
        "models_loaded": models is not None,
        "gpu_available": gpu_available,
    }
74
+
75
+
76
@app.post("/api/generate")
async def generate_ui(
    prompt: str = Form(..., description="Text prompt describing the UI to generate"),
    image: UploadFile = File(..., description="Wireframe/screenshot image"),
):
    """
    Generate UI code from a prompt and wireframe image using Qwen VLM

    Args:
        prompt: Description of the desired UI
        image: Uploaded image file (PNG, JPG, etc.)

    Returns:
        JSON response with generated code, plan, and reasoning

    Raises:
        HTTPException: 503 while models are still loading, 400 for a
            non-image upload, 500 if the agent pipeline fails.
    """
    # NOTE(review): `prompt` is accepted but never passed into the agent
    # state below — confirm whether CodeRefineState should receive it.
    if models is None:  # FIX: explicit None check — the guard's intent is "not yet loaded"
        raise HTTPException(
            status_code=503,
            detail="Models are still loading. Please wait a moment and try again."
        )

    # Validate image type. FIX: content_type may be None when the client
    # omits the header; guard before startswith() to avoid AttributeError.
    if not image.content_type or not image.content_type.startswith("image/"):
        raise HTTPException(
            status_code=400,
            detail=f"Invalid file type: {image.content_type}. Please upload an image."
        )

    # Per-request scratch directory; always removed in the finally block.
    temp_dir = tempfile.mkdtemp()

    try:
        # Persist the uploaded wireframe so the agent can read it from disk.
        image_path = os.path.join(temp_dir, "input_wireframe.png")
        contents = await image.read()
        with open(image_path, "wb") as f:
            f.write(contents)

        # Paths for the artifacts the agent writes.
        out_html = os.path.join(temp_dir, "generated_output.html")
        out_brief = os.path.join(temp_dir, "generated_brief.txt")
        out_reldesc = os.path.join(temp_dir, "generated_reldesc.txt")

        # Build state for agent.
        # NOTE(review): deployments are Azure "gpt-4.1-mini" although the API
        # metadata advertises Qwen2.5-VL — confirm which backend is intended.
        state = CodeRefineState(
            image_path=image_path,
            out_rel_desc=out_reldesc,
            out_brief=out_brief,
            out_html=out_html,
            vision_deployment="gpt-4.1-mini",
            text_deployment="gpt-4.1-mini",
            reldesc_tokens=700,
            brief_tokens=1100,
            code_tokens=2200,
            judge_tokens=900,
            temp=0.12,
            refine_max_iters=3,
            refine_threshold=8,
            shot_width=1536,
            shot_height=900
        )

        # Run the agent with a unique thread id so checkpoints don't collide.
        run_id = f"api-{uuid4()}"
        config = {"configurable": {"thread_id": run_id}}

        print(f"🎨 Processing wireframe with run_id: {run_id}")
        # Result object is currently unused; outputs are read back from disk.
        agent_result = azure_vlm_app.invoke(state, config=config)
        print(f"✅ Agent completed for run_id: {run_id}")

        # Read generated code (empty string when the agent produced nothing).
        generated_code = ""
        if os.path.exists(out_html):
            with open(out_html, 'r', encoding='utf-8') as f:
                generated_code = f.read()

        # Build the chat-style message list the frontend renders.
        messages = []

        messages.append({
            "id": f"msg-{uuid4()}",
            "role": "assistant",
            "variant": "accent",
            "content": f"**Plan:** Generated UI from your wireframe using Qwen2.5-VL-7B-Instruct"
        })

        messages.append({
            "id": f"msg-{uuid4()}",
            "role": "assistant",
            "variant": "subtle",
            "content": """**Process:**
• Loaded and analyzed wireframe structure
• Identified UI components and layout
• Generated semantic HTML/CSS code
• Applied multi-stage refinement pipeline"""
        })

        if generated_code:
            messages.append({
                "id": f"msg-{uuid4()}",
                "role": "assistant",
                "variant": "accent",
                "content": f"```html\n{generated_code}\n```"
            })
        else:
            messages.append({
                "id": f"msg-{uuid4()}",
                "role": "assistant",
                "variant": "subtle",
                "content": "⚠️ Code generation completed but output file was empty."
            })

        return JSONResponse(content={
            "messages": messages,
            "status": {
                "kind": "success",
                "text": "UI code generated successfully",
                "detail": f"Run ID: {run_id}"
            },
            "usedFallback": False
        })

    except Exception as e:
        # Log the full traceback server-side; return only the message.
        import traceback
        error_details = traceback.format_exc()
        print(f"❌ Error processing request: {error_details}")

        raise HTTPException(
            status_code=500,
            detail=f"Error generating UI: {str(e)}"
        )

    finally:
        # Best-effort cleanup of the per-request scratch directory.
        try:
            shutil.rmtree(temp_dir)
        except Exception as e:
            print(f"⚠️ Failed to cleanup temp dir: {e}")
214
+
215
+
216
@app.post("/api/chat")
async def chat_with_vlm(
    prompt: str = Form(..., description="Chat message/question"),
    image: Optional[UploadFile] = File(None, description="Optional image for vision tasks"),
):
    """
    Simple chat endpoint for VLM queries
    """
    # Guard clause: reject requests until the startup hook has populated
    # the module-level model singleton.
    if not models:
        raise HTTPException(status_code=503, detail="Models not loaded yet")

    # Placeholder payload — real chat inference is not wired up yet.
    payload = {
        "response": "Chat endpoint is available but not fully implemented yet.",
        "status": "success",
    }
    return JSONResponse(content=payload)
234
+
235
+
236
if __name__ == "__main__":
    # Local entry point. HF Spaces injects PORT (defaults to 7860 there too).
    listen_port = int(os.environ.get("PORT", 7860))
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=listen_port,
        # Generation requests can run for minutes; keep connections alive.
        timeout_keep_alive=300,
    )
config/config.toml ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [server]
2
+ headless = true
3
+ port = 7860
4
+ enableCORS = false
5
+ enableXsrfProtection = false
6
+
7
+ [browser]
8
+ gatherUsageStats = false
9
+
10
+ [theme]
11
+ primaryColor = "#667eea"
12
+ backgroundColor = "#ffffff"
13
+ secondaryBackgroundColor = "#f0f2f6"
14
+ textColor = "#262730"
15
+ font = "sans serif"
requirements (1).txt ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Core Dependencies
2
+ streamlit==1.51.0
3
+ fastapi==0.115.6
4
+ uvicorn[standard]==0.34.0
5
+ python-multipart==0.0.20
6
+ python-dotenv==1.0.1
7
+
8
+ # ML/AI Libraries
9
+ torch==2.6.0
10
+ torchvision==0.21.0
11
+ transformers==4.57.1
12
+ diffusers==0.32.2
13
+ accelerate==1.2.1
14
+ bitsandbytes==0.45.0
15
+ qwen-vl-utils==0.0.8
16
+
17
+ # LangGraph & Agent
18
+ langgraph==0.2.59
19
+ langgraph-checkpoint==2.0.10
20
+
21
+ # Azure OpenAI
22
+ openai==1.59.5
23
+
24
+ # Image Processing
25
+ Pillow==11.0.0
26
+ playwright==1.48.0
27
+
28
+ # Other utilities
29
+ requests==2.32.3
30
+ aiofiles==24.1.0
31
+ pydantic==2.10.4
32
+ pydantic-settings==2.7.0