English
Chinese
kaitongg commited on
Commit
300f260
·
verified ·
1 Parent(s): 118f2e2

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +436 -0
app.py ADDED
@@ -0,0 +1,436 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import pandas as pd
3
+ from PIL import Image
4
+ import torch
5
+ import torchvision.transforms as T
6
+ import json
7
+ import sentence_transformers
8
+ import os
9
+ import tempfile
10
+ import shutil
11
+ # Removed google.generativeai import as Gemini is excluded
12
+
13
+ # --- Model Loading (Consolidated from previous cells, excluding Gemini) ---
14
+
15
+ # Load Hugging Face Token (Needed for private repos or some operations)
16
+ # In Hugging Face Spaces, secrets are accessed via environment variables
17
+ # HF_TOKEN = os.environ.get('HF_TOKEN_WRITE') # Commented out - usually not needed for public model downloads
18
+
19
+ # Load Image Classification Model (from TTx28yjzHMgR)
20
+ try:
21
+ from huggingface_hub import hf_hub_download
22
+ import pickle
23
+ import timm # Ensure timm is imported if used
24
+
25
+ REPO_ID_IMG = "keerthikoganti/architecture-design-stages-compact-cnn"
26
+ pkl_path = hf_hub_download(repo_id=REPO_ID_IMG, filename="model_bundle.pkl")
27
+ with open(pkl_path, "rb") as f:
28
+ bundle = pickle.load(f)
29
+
30
+ architecture = bundle["architecture"]
31
+ num_classes = bundle["num_classes"]
32
+ class_names = bundle["class_names"]
33
+ state_dict = bundle["state_dict"]
34
+
35
+ device = "cuda" if torch.cuda.is_available() else "cpu"
36
+ model = timm.create_model(architecture, pretrained=False, num_classes=num_classes)
37
+ model.load_state_dict(state_dict)
38
+ model.eval().to(device)
39
+
40
+ TFM = T.Compose([T.Resize(224), T.CenterCrop(224), T.ToTensor(), T.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])])
41
+ print("Image Classification Model loaded successfully!")
42
+
43
+ except Exception as e:
44
+ print(f"Error loading Image Classification Model: {e}")
45
+ model = None
46
+ TFM = None
47
+ device = None
48
+ class_names = []
49
+
50
+
51
+ # Load Text Classification Model (from VysWLxnGItBa)
52
+ try:
53
+ from huggingface_hub import snapshot_download
54
+ from autogluon.tabular import TabularPredictor
55
+ import os # Ensure os is imported
56
+
57
+ repo_id_text = "kaitongg/my-autogluon-model"
58
+ download_dir = "downloaded_predictor"
59
+
60
+ # Download the entire model repository
61
+ print(f"Downloading text model files from {repo_id_text}...")
62
+ # Use HF_TOKEN if the repo is private: token=os.environ.get('HF_TOKEN_WRITE')
63
+
64
+ downloaded_path = snapshot_download(
65
+ repo_id=repo_id_text,
66
+ repo_type="model",
67
+ local_dir=download_dir,
68
+ local_dir_use_symlinks=False,
69
+ # token=HF_TOKEN # Uncomment if repo is private and HF_TOKEN is needed
70
+ )
71
+ print(f"Text model files downloaded to: {downloaded_path}")
72
+
73
+ # Load the predictor from the subdirectory 'autogluon_predictor'
74
+ predictor_path = os.path.join(downloaded_path, "autogluon_predictor")
75
+ loaded_predictor_from_hub = TabularPredictor.load(predictor_path)
76
+ print("Text Classification Model loaded successfully from Hugging Face Hub!")
77
+
78
+ except Exception as e:
79
+ print(f"Error loading Text Classification Model: {e}")
80
+ loaded_predictor_from_hub = None
81
+
82
+
83
+ # Load Sentence Transformer Model (from OJ9wke1CrK1S/global scope)
84
+ try:
85
+ embedding_model = sentence_transformers.SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
86
+ print("Sentence Transformer model loaded successfully!")
87
+ except Exception as e:
88
+ print(f"Error loading Sentence Transformer model: {e}")
89
+ embedding_model = None
90
+
91
+
92
+ # --- LLM Attitude Mapping (from 74905474) ---
93
+ llm_attitude_mapping = {
94
+ "brainstorm": "creative and encouraging",
95
+ "design_iteration": "constructive and detailed, focusing on improvements",
96
+ "design_optimization": "critical and focused on efficiency and refinement",
97
+ "final_review": "thorough and critical, evaluating completeness and adherence to requirements",
98
+ "random": "neutral and informative, perhaps suggesting a relevant stage",
99
+ }
100
+ print("LLM attitude mapping defined successfully!")
101
+
102
+
103
+ # --- Function Definitions (Consolidated from jKIkOPByaN3Z and OJ9wke1CrK1S) ---
104
+
105
+ # Define the specific text classification function (from OJ9wke1CrK1S/jKIkOPByaN3Z)
106
+ def perform_text_classification_and_format(text: str) -> tuple[dict, str]:
107
+ """
108
+ Performs text classification using the loaded predictor and embedding model,
109
+ and formats the results.
110
+
111
+ Args:
112
+ text: The input text string.
113
+
114
+ Returns:
115
+ A tuple containing:
116
+ - text_classification_probabilities (dict): Probabilities for each class.
117
+ - text_classification_formatted (str): Formatted string of classification results.
118
+ """
119
+ text_classification_probabilities = {"error": "No text provided or model not loaded"}
120
+ text_classification_formatted = "No text provided or model not loaded"
121
+ has_high_concept = "Cannot Determine" # Translated
122
+ confidence = 0.0
123
+
124
+ # Check if models are loaded before proceeding
125
+ if text and loaded_predictor_from_hub is not None and embedding_model is not None:
126
+ try:
127
+ # Encode the text using the embedding model
128
+ embeddings = embedding_model.encode(
129
+ [text],
130
+ batch_size=1,
131
+ show_progress_bar=False,
132
+ convert_to_numpy=True,
133
+ normalize_embeddings=False,
134
+ )
135
+
136
+ # Create a DataFrame with 'eX' column names from embeddings
137
+ n, d = embeddings.shape
138
+ text_df_processed = pd.DataFrame(embeddings, columns=[f"e{i}" for i in range(d)])
139
+
140
+ # Get text model prediction probabilities
141
+ text_proba_df = loaded_predictor_from_hub.predict_proba(text_df_processed)
142
+
143
+ # Assuming your predictor returns probabilities for class 0 and class 1
144
+ text_classification_probabilities = {
145
+ "No High Concept": float(text_proba_df.iloc[0]["0"]) if "0" in text_proba_df.columns else 0.0,
146
+ "High Concept": float(text_proba_df.iloc[0]["1"]) if "1" in text_proba_df.columns else 0.0,
147
+ }
148
+
149
+ # Determine the predicted class label (0 or 1) as a string
150
+ if not text_proba_df.empty and len(text_proba_df.columns) > 0:
151
+ predicted_text_label = str(loaded_predictor_from_hub.predict(text_df_processed).iloc[0])
152
+
153
+ # Correctly compare the predicted label as a string
154
+ if predicted_text_label == "1":
155
+ has_high_concept = "Yes" # Translated
156
+ confidence = text_classification_probabilities.get("High Concept", 0.0)
157
+ elif predicted_text_label == "0":
158
+ has_high_concept = "No" # Translated
159
+ confidence = text_classification_probabilities.get("No High Concept", 0.0)
160
+ else: # Handle unexpected labels
161
+ has_high_concept = f"Unknown Label: {predicted_text_label}" # Translated
162
+ confidence = 0.0
163
+ print(f"Warning: Predictor returned unexpected label: {predicted_text_label}")
164
+ else:
165
+ has_high_concept = "Cannot Determine (No Prediction Output)" # Translated
166
+
167
+
168
+ print(f"Text classified as having high concept: {has_high_concept}")
169
+ print(f"Text classification probabilities: {text_classification_probabilities}")
170
+
171
+ # Format the text classification results for display
172
+ text_classification_formatted = f"High Concept: {has_high_concept} (Confidence: {confidence:.2f})"
173
+
174
+ except Exception as e:
175
+ print(f"Error during text classification: {e}")
176
+ text_classification_probabilities = {"error": f"Text classification failed: {e}"}
177
+ text_classification_formatted = f"Text classification failed: {e}"
178
+ elif text:
179
+ print("Text predictor or embedding model not loaded for text classification.")
180
+ text_classification_probabilities = {"error": "Text predictor or embedding model not loaded"}
181
+ text_classification_formatted = "Text predictor or embedding model not loaded."
182
+ elif loaded_predictor_from_hub is None:
183
+ print("Text predictor model not loaded for text classification.")
184
+ text_classification_probabilities = {"error": "Text predictor model not loaded"}
185
+ text_classification_formatted = "Text predictor model not loaded."
186
+ else: # text is None or empty
187
+ text_classification_probabilities = {"info": "No text provided"}
188
+ text_classification_formatted = "No text provided"
189
+
190
+
191
+ return text_classification_probabilities, text_classification_formatted
192
+
193
+ print("perform_text_classification_and_format function defined.")
194
+
195
+
196
+ # Define the combined classification function (from jKIkOPByaN3Z)
197
+ # This function calls perform_text_classification_and_format defined above
198
+ def perform_classification_and_format(image: Image.Image, text: str) -> tuple[dict, dict, str]:
199
+ """
200
+ Performs image and text classification and formats the results.
201
+ Calls perform_text_classification_and_format for text classification.
202
+
203
+ Args:
204
+ image: The input PIL Image.
205
+ text: The input text string.
206
+
207
+ Returns:
208
+ A tuple containing:
209
+ - image_classification_results (dict): Probabilities for image classes.
210
+ - text_classification_probabilities (dict): Probabilities for text classes.
211
+ - text_classification_formatted (str): Formatted string of text classification results.
212
+ """
213
+ # Initialize output variables with default values
214
+ image_classification_results = {"error": "No image provided"}
215
+ # Text classification results will be obtained from perform_text_classification_and_format
216
+
217
+ # --- Process Image Input ---
218
+ design_stage = "unknown"
219
+ # Check if image model components are loaded
220
+ if image is not None and model is not None and TFM is not None and device is not None and class_names:
221
+ try:
222
+ # Apply the transformation
223
+ img_tensor = TFM(image).unsqueeze(0).to(device)
224
+
225
+ # Get the image model output
226
+ with torch.no_grad():
227
+ img_output = model(img_tensor)
228
+
229
+ # Get probabilities and predict the design stage
230
+ img_probabilities = torch.softmax(img_output, dim=1)[0]
231
+ predicted_class_index = torch.argmax(img_probabilities).item()
232
+ design_stage = class_names[predicted_class_index]
233
+
234
+ # Create a dictionary of class names and probabilities for Gradio Label output
235
+ image_classification_results = {class_names[i]: float(img_probabilities[i]) for i in range(len(class_names))}
236
+
237
+ print(f"Image classified as: {design_stage}")
238
+ print(f"Image classification probabilities: {image_classification_results}")
239
+
240
+ except Exception as e:
241
+ print(f"Error processing image: {e}")
242
+ design_stage = "error during classification"
243
+ image_classification_results = {"error": f"Image classification failed: {e}"}
244
+ elif image is not None:
245
+ print("Image model components not loaded.")
246
+ design_stage = "model_not_loaded"
247
+ image_classification_results = {"error": "Image model or components not loaded"}
248
+ else: # image is None
249
+ print("No image provided for image classification.")
250
+ image_classification_results = {"info": "No image provided"}
251
+ design_stage = "no_image"
252
+
253
+
254
+ # --- Process Text Input using the dedicated function ---
255
+ # perform_text_classification_and_format is defined above and returns (probabilities_dict, formatted_string)
256
+ text_classification_probabilities, text_classification_formatted = perform_text_classification_and_format(text)
257
+ print(f"Text classification formatted result: {text_classification_formatted}")
258
+ print(f"Text classification raw probabilities: {text_classification_probabilities}")
259
+
260
+
261
+ # Return image classification probabilities (dict), text classification probabilities (dict), and formatted text classification string
262
+ return image_classification_results, text_classification_probabilities, text_classification_formatted
263
+
264
+ print("perform_classification_and_format function defined.")
265
+
266
+
267
+ # Define a function to generate the prompt based on classification results and text (from jKIkOPByaN3Z)
268
+ def generate_prompt_only(image_classification_results: dict, text_classification_probabilities: dict, text: str) -> str:
269
+ """
270
+ Generates a prompt for the LLM based on image and text classification results.
271
+
272
+ Args:
273
+ image_classification_results: Dictionary of image class probabilities.
274
+ text_classification_probabilities: Dictionary of text class probabilities.
275
+ text: The original input text string.
276
+
277
+ Returns:
278
+ A string containing the generated prompt for the LLM.
279
+ """
280
+ # Extract design stage from image classification results
281
+ design_stage = "unknown"
282
+ if image_classification_results and "error" not in image_classification_results and "info" not in image_classification_results:
283
+ try:
284
+ # Find the class with the highest probability, excluding error/info keys
285
+ valid_results = {k: v for k, v in image_classification_results.items() if k not in ["error", "info"]}
286
+ if valid_results:
287
+ design_stage = max(valid_results, key=valid_results.get)
288
+ else:
289
+ design_stage = "unknown" # Fallback if no valid results
290
+ except Exception:
291
+ design_stage = "unknown"
292
+ elif "info" in image_classification_results:
293
+ design_stage = "no_image" # Special case if no image was provided
294
+ elif "error" in image_classification_results:
295
+ design_stage = "image_classification_failed" # Special case if image classification failed
296
+
297
+
298
+ # Extract high concept status from text classification probabilities
299
+ has_high_concept = "Cannot Determine" # Translated
300
+ if text_classification_probabilities and "error" not in text_classification_probabilities and "info" not in text_classification_probabilities:
301
+ try:
302
+ # Determine has_high_concept based on which probability is higher
303
+ high_concept_prob = text_classification_probabilities.get("High Concept", 0.0)
304
+ no_high_concept_prob = text_classification_probabilities.get("No High Concept", 0.0)
305
+ if high_concept_prob > no_high_concept_prob:
306
+ has_high_concept = "Yes" # Translated
307
+ else:
308
+ has_high_concept = "No" # Translated
309
+ except Exception:
310
+ has_high_concept = "Cannot Determine" # Translated
311
+ elif "info" in text_classification_probabilities:
312
+ has_high_concept = "no_text" # Special case if no text was provided
313
+ elif "error" in text_classification_probabilities:
314
+ has_high_concept = "text_classification_failed" # Special case if text classification failed
315
+
316
+
317
+ # --- Generate Dynamic Prompt for LLM ---
318
+ # Note: The prompt is still generated, but the LLM interaction part is removed.
319
+ # The prompt structure is based on previous requirements.
320
+ # Use a default attitude if design_stage or has_high_concept are special error/info states
321
+ if design_stage in ["unknown", "no_image", "image_classification_failed"] or has_high_concept in ["Cannot Determine", "no_text", "text_classification_failed"]: # Translated
322
+ llm_attitude = llm_attitude_mapping.get("random", "neutral and informative") # Use random or a default neutral attitude
323
+ else:
324
+ llm_attitude = llm_attitude_mapping.get(design_stage, llm_attitude_mapping.get("random", "neutral and informative"))
325
+
326
+
327
+ # Translated prompt components
328
+ prompt = f"""User is a low-level architecture student struggling with critical architectural reviews. You are an abstract architecture critique interpreter. Your response must be in English.
329
+ Given that the user is in the {design_stage} design stage, your attitude should be {llm_attitude}.
330
+ Given that the user input result (Yes/No) contains abstract architectural concepts: {has_high_concept}.
331
+ If the user input contains abstract architectural concepts, you need to explain the abstract concept to the user and then provide actionable advice. If not, you can directly provide actionable advice.
332
+ User input text content: {text} You need to explain abstract concepts to the user using language that a child can understand, provide examples from daily life, and offer actionable advice.
333
+ """ # Use full text input
334
+
335
+ return prompt
336
+
337
+ print("generate_prompt_only function defined.")
338
+
339
+
340
+ # Removed generate_feedback_from_prompt function as Gemini LLM is excluded
341
+
342
+
343
+ # --- Create Gradio Interface (Consolidated from jKIkOPByaN3Z, excluding LLM feedback parts) ---
344
+ # Define example inputs for the Gradio interface
345
+ examples = [
346
+ # Example 1: Brainstorm stage, text with high concept
347
+ ["https://balancedarchitecture.com/wp-content/uploads/2021/11/EXISTING-FIRST-FLOOR-PRES-scaled-e1635965923983.jpg", "Exploring spatial relationships and material palettes."],
348
+ # Example 2: Design Iteration stage, text without high concept
349
+ ["https://cdn.prod.website-files.com/5894a32730554b620f7bf36d/5e848c2d622e7abe1ad48504_5e01ce9f0d272014d0353cd1_Things-You-Need-to-Organize-a-3D-Rendering-Architectural-Project-EASY-RENDER.jpeg", "The window size is too small."],
350
+ # Example 3: Final Review stage, text with some concept
351
+ ["https://architectelevator.com/assets/img/bilbao_sketch.png", "The facade expresses the building's relationship with the urban context."],
352
+ ]
353
+
354
+ with gr.Blocks() as demo_step_by_step:
355
+ gr.Markdown("# Architecture Feedback Generator (Classification & Prompt Only)") # Translated
356
+ gr.Markdown("""
357
+ Upload an architectural image and provide a text description or question to see classification results and the generated prompt.
358
+ (LLM feedback generation is excluded from this version).
359
+ """)
360
+
361
+ with gr.Row():
362
+ image_input = gr.Image(type="pil", label="Upload Architectural Image") # Translated
363
+ text_input = gr.Textbox(label="Enter Text Description or Question") # Translated
364
+
365
+ classify_and_prompt_button = gr.Button("Perform Classification & Generate Prompt") # Translated
366
+
367
+
368
+ with gr.Row():
369
+ # Assuming class_names is loaded, otherwise provide a default like 5
370
+ image_output_label = gr.Label(num_top_classes=len(class_names) if 'class_names' in globals() and class_names else 5, label="Image Classification Results") # Translated
371
+ text_output_textbox = gr.Textbox(label="Text Classification Results") # Translated
372
+
373
+ # Use gr.State components to store intermediate results needed for subsequent steps
374
+ text_classification_probabilities_state = gr.State()
375
+
376
+ prompt_output_textbox = gr.Textbox(label="Generated Prompt for LLM", interactive=True) # Translated - Allow user to inspect/edit prompt
377
+
378
+ # Removed LLM feedback output component and button
379
+
380
+
381
+ # Define the event chain
382
+ # 1. When classify_and_prompt_button is clicked, perform classification and format results
383
+ # perform_classification_and_format returns:
384
+ # (image_classification_results, text_classification_probabilities, text_classification_formatted)
385
+ # Map outputs to image_output_label, text_classification_probabilities_state, and text_output_textbox
386
+ classification_outputs = classify_and_prompt_button.click(
387
+ fn=perform_classification_and_format,
388
+ inputs=[image_input, text_input],
389
+ outputs=[image_output_label, text_classification_probabilities_state, text_output_textbox], # Corrected outputs list
390
+ # queue=False # Consider if queuing is needed
391
+ )
392
+
393
+ # 2. Then, use the outputs of the first step to generate and display the prompt
394
+ # Trigger when any of the classification outputs are updated. Use the State component for text probs.
395
+ classification_outputs[2].then( # Trigger when text_output_textbox (output[2]) is updated
396
+ fn=generate_prompt_only,
397
+ inputs=[
398
+ classification_outputs[0], # References the output component holding img_res
399
+ classification_outputs[1], # References the State component holding txt_prob
400
+ text_input # Original text input component
401
+ ],
402
+ outputs=prompt_output_textbox,
403
+ # queue=False # Consider if queuing is needed
404
+ )
405
+
406
+ # Removed LLM feedback generation button click event
407
+
408
+ # Add examples - Examples should trigger the classification -> prompt generation chain
409
+ # This requires a function that performs both steps for a given example input.
410
+ def generate_full_chain_output_step_by_step(img, txt):
411
+ # Step 1: Classification
412
+ img_res, txt_prob, txt_fmt = perform_classification_and_format(img, txt)
413
+ # Step 2: Prompt Generation
414
+ prompt = generate_prompt_only(img_res, txt_prob, txt)
415
+ # Return the outputs expected by gr.Examples outputs
416
+ # The outputs for examples are: image_output_label, text_output_textbox, prompt_output_textbox.
417
+ # Need to return img_res, txt_fmt, prompt in that order.
418
+ return img_res, txt_fmt, prompt
419
+
420
+ # Note: The examples outputs need to match the outputs of the fn.
421
+ # The outputs from generate_full_chain_output_step_by_step are img_res, txt_fmt, prompt.
422
+ # The Gradio outputs defined are image_output_label, text_output_textbox, prompt_output_textbox.
423
+ # The order should match.
424
+ gr.Examples(
425
+ examples=examples,
426
+ inputs=[image_input, text_input],
427
+ # Outputs to update for examples: Image Classification, Text Classification, Prompt
428
+ outputs=[image_output_label, text_output_textbox, prompt_output_textbox],
429
+ fn=generate_full_chain_output_step_by_step,
430
+ cache_examples=False, # Set to False to re-run the function on example click
431
+ )
432
+
433
+
434
+ # Launch the interface
435
+ # if __name__ == "__main__": # Remove this block for deployment to Spaces
436
+ demo_step_by_step.launch()