| import gradio as gr |
| import pandas as pd |
| from PIL import Image |
| import torch |
| import torchvision.transforms as T |
| import json |
| import sentence_transformers |
| import os |
| import tempfile |
| import shutil |
| |
|
|
| |
|
|
| |
| |
| |
|
|
| |
# ---------------------------------------------------------------------------
# Image classification model (architecture design-stage classifier).
# Downloads a pickled bundle from the Hub and rebuilds the timm model from it.
# On any failure the globals are set to safe fallbacks so the UI still loads.
# ---------------------------------------------------------------------------
try:
    from huggingface_hub import hf_hub_download
    import pickle
    import timm

    REPO_ID_IMG = "keerthikoganti/architecture-design-stages-compact-cnn"
    pkl_path = hf_hub_download(repo_id=REPO_ID_IMG, filename="model_bundle.pkl")
    # SECURITY NOTE: pickle.load executes arbitrary code from the downloaded
    # artifact — keep this only if the Hub repo is fully trusted.
    with open(pkl_path, "rb") as f:
        bundle = pickle.load(f)

    # Bundle carries everything needed to rebuild the classifier.
    architecture = bundle["architecture"]
    num_classes = bundle["num_classes"]
    class_names = bundle["class_names"]
    state_dict = bundle["state_dict"]

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = timm.create_model(architecture, pretrained=False, num_classes=num_classes)
    model.load_state_dict(state_dict)
    model.eval().to(device)

    # Standard ImageNet preprocessing: 224x224 center crop + normalization.
    TFM = T.Compose([
        T.Resize(224),
        T.CenterCrop(224),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    print("Image Classification Model loaded successfully!")

except Exception as e:
    print(f"Error loading Image Classification Model: {e}")
    # Fallbacks checked by perform_classification_and_format before inference.
    model = None
    TFM = None
    device = None
    class_names = []
|
|
|
|
| |
# ---------------------------------------------------------------------------
# Text classification model (AutoGluon tabular predictor over text embeddings).
# Downloads the whole predictor directory and loads it; falls back to None.
# ---------------------------------------------------------------------------
try:
    from huggingface_hub import snapshot_download
    from autogluon.tabular import TabularPredictor
    # (redundant `import os` removed — os is already imported at module top)

    repo_id_text = "kaitongg/my-autogluon-model"
    download_dir = "downloaded_predictor"

    print(f"Downloading text model files from {repo_id_text}...")
    downloaded_path = snapshot_download(
        repo_id=repo_id_text,
        repo_type="model",
        local_dir=download_dir,
        # NOTE(review): deprecated and ignored in recent huggingface_hub
        # releases; kept for compatibility with older versions.
        local_dir_use_symlinks=False,
    )
    print(f"Text model files downloaded to: {downloaded_path}")

    # The predictor was saved under a subdirectory of the repo.
    predictor_path = os.path.join(downloaded_path, "autogluon_predictor")
    loaded_predictor_from_hub = TabularPredictor.load(predictor_path)
    print("Text Classification Model loaded successfully from Hugging Face Hub!")

except Exception as e:
    print(f"Error loading Text Classification Model: {e}")
    loaded_predictor_from_hub = None
|
|
|
|
| |
# Sentence-embedding model used to featurize the input text for the predictor.
try:
    embedding_model = sentence_transformers.SentenceTransformer(
        "sentence-transformers/all-MiniLM-L6-v2"
    )
    print("Sentence Transformer model loaded successfully!")
except Exception as e:
    print(f"Error loading Sentence Transformer model: {e}")
    # Downstream code checks for None before attempting to encode.
    embedding_model = None
|
|
|
|
| |
# Maps the predicted design stage to the tone the LLM should adopt when
# critiquing the student's work; "random" is the catch-all fallback attitude.
llm_attitude_mapping = dict(
    brainstorm="creative and encouraging",
    design_iteration="constructive and detailed, focusing on improvements",
    design_optimization="critical and focused on efficiency and refinement",
    final_review="thorough and critical, evaluating completeness and adherence to requirements",
    random="neutral and informative, perhaps suggesting a relevant stage",
)
print("LLM attitude mapping defined successfully!")
|
|
|
|
| |
|
|
| |
def perform_text_classification_and_format(text: str) -> tuple[dict, str]:
    """
    Classify *text* as containing a "high concept" or not, and format the result.

    The text is embedded with the sentence-transformer model and fed to the
    AutoGluon tabular predictor, whose binary labels "0"/"1" are mapped to
    "No High Concept"/"High Concept".

    Args:
        text: The input text string.

    Returns:
        A tuple containing:
        - text_classification_probabilities (dict): per-class probabilities,
          or an "error"/"info" entry when classification was not possible.
        - text_classification_formatted (str): human-readable summary.
    """
    # Bug fix: empty input is always reported as "No text provided".
    # Previously, when the predictor had also failed to load, the elif chain
    # returned a misleading "Text predictor model not loaded." message instead.
    if not text:
        return {"info": "No text provided"}, "No text provided"

    if loaded_predictor_from_hub is None or embedding_model is None:
        print("Text predictor or embedding model not loaded for text classification.")
        return (
            {"error": "Text predictor or embedding model not loaded"},
            "Text predictor or embedding model not loaded.",
        )

    try:
        # Embed the single text into a fixed-size vector (shape (1, d)).
        embeddings = embedding_model.encode(
            [text],
            batch_size=1,
            show_progress_bar=False,
            convert_to_numpy=True,
            normalize_embeddings=False,
        )

        # The predictor was trained on feature columns named e0..e{d-1}.
        n, d = embeddings.shape
        text_df_processed = pd.DataFrame(embeddings, columns=[f"e{i}" for i in range(d)])

        text_proba_df = loaded_predictor_from_hub.predict_proba(text_df_processed)

        # Predictor classes are the string labels "0" (no) and "1" (yes);
        # missing columns default to probability 0.0.
        text_classification_probabilities = {
            "No High Concept": float(text_proba_df.iloc[0]["0"]) if "0" in text_proba_df.columns else 0.0,
            "High Concept": float(text_proba_df.iloc[0]["1"]) if "1" in text_proba_df.columns else 0.0,
        }

        has_high_concept = "Cannot Determine"
        confidence = 0.0
        if not text_proba_df.empty and len(text_proba_df.columns) > 0:
            predicted_text_label = str(loaded_predictor_from_hub.predict(text_df_processed).iloc[0])
            if predicted_text_label == "1":
                has_high_concept = "Yes"
                confidence = text_classification_probabilities.get("High Concept", 0.0)
            elif predicted_text_label == "0":
                has_high_concept = "No"
                confidence = text_classification_probabilities.get("No High Concept", 0.0)
            else:
                has_high_concept = f"Unknown Label: {predicted_text_label}"
                print(f"Warning: Predictor returned unexpected label: {predicted_text_label}")
        else:
            has_high_concept = "Cannot Determine (No Prediction Output)"

        print(f"Text classified as having high concept: {has_high_concept}")
        print(f"Text classification probabilities: {text_classification_probabilities}")

        text_classification_formatted = f"High Concept: {has_high_concept} (Confidence: {confidence:.2f})"
        return text_classification_probabilities, text_classification_formatted

    except Exception as e:
        print(f"Error during text classification: {e}")
        return (
            {"error": f"Text classification failed: {e}"},
            f"Text classification failed: {e}",
        )


print("perform_text_classification_and_format function defined.")
|
|
|
|
| |
| |
def perform_classification_and_format(image: Image.Image, text: str) -> tuple[dict, dict, str]:
    """
    Run image and text classification and format the results.

    Text handling is delegated to ``perform_text_classification_and_format``.

    Args:
        image: The input PIL Image (or None when the user uploaded nothing).
        text: The input text string.

    Returns:
        A tuple containing:
        - image_classification_results (dict): per-class image probabilities.
        - text_classification_probabilities (dict): per-class text probabilities.
        - text_classification_formatted (str): formatted text-classification summary.
    """
    image_model_ready = (
        model is not None and TFM is not None and device is not None and bool(class_names)
    )

    if image is None:
        print("No image provided for image classification.")
        image_classification_results = {"info": "No image provided"}
    elif not image_model_ready:
        print("Image model components not loaded.")
        image_classification_results = {"error": "Image model or components not loaded"}
    else:
        try:
            # Preprocess and add a batch dimension before the forward pass.
            batch = TFM(image).unsqueeze(0).to(device)
            with torch.no_grad():
                logits = model(batch)

            probabilities = torch.softmax(logits, dim=1)[0]
            predicted_stage = class_names[torch.argmax(probabilities).item()]
            image_classification_results = {
                name: float(probabilities[idx]) for idx, name in enumerate(class_names)
            }

            print(f"Image classified as: {predicted_stage}")
            print(f"Image classification probabilities: {image_classification_results}")
        except Exception as e:
            print(f"Error processing image: {e}")
            image_classification_results = {"error": f"Image classification failed: {e}"}

    # Text classification is handled by the dedicated helper.
    text_classification_probabilities, text_classification_formatted = (
        perform_text_classification_and_format(text)
    )
    print(f"Text classification formatted result: {text_classification_formatted}")
    print(f"Text classification raw probabilities: {text_classification_probabilities}")

    return image_classification_results, text_classification_probabilities, text_classification_formatted


print("perform_classification_and_format function defined.")
|
|
|
|
| |
def generate_prompt_only(image_classification_results: dict, text_classification_probabilities: dict, text: str) -> str:
    """
    Generates a prompt for the LLM based on image and text classification results.

    Args:
        image_classification_results: Dictionary of image class probabilities
            (may instead carry an "error"/"info" entry, or be None/empty).
        text_classification_probabilities: Dictionary of text class probabilities
            (may instead carry an "error"/"info" entry, or be None/empty).
        text: The original input text string.

    Returns:
        A string containing the generated prompt for the LLM.
    """
    # Robustness fix: coerce None to {} so the "info"/"error" membership tests
    # below cannot raise TypeError — the original truthiness check only
    # protected the first branch, and `"info" in None` would crash the elifs.
    img_results = image_classification_results or {}
    txt_probs = text_classification_probabilities or {}

    # Derive the design stage from the highest-probability image class.
    design_stage = "unknown"
    if img_results and "error" not in img_results and "info" not in img_results:
        try:
            valid_results = {k: v for k, v in img_results.items() if k not in ["error", "info"]}
            if valid_results:
                design_stage = max(valid_results, key=valid_results.get)
        except Exception:
            design_stage = "unknown"
    elif "info" in img_results:
        design_stage = "no_image"
    elif "error" in img_results:
        design_stage = "image_classification_failed"

    # Derive the high-concept verdict by comparing the two class probabilities.
    has_high_concept = "Cannot Determine"
    if txt_probs and "error" not in txt_probs and "info" not in txt_probs:
        try:
            high_concept_prob = txt_probs.get("High Concept", 0.0)
            no_high_concept_prob = txt_probs.get("No High Concept", 0.0)
            has_high_concept = "Yes" if high_concept_prob > no_high_concept_prob else "No"
        except Exception:
            has_high_concept = "Cannot Determine"
    elif "info" in txt_probs:
        has_high_concept = "no_text"
    elif "error" in txt_probs:
        has_high_concept = "text_classification_failed"

    # Fall back to the neutral attitude whenever either classification is
    # missing or failed; otherwise use the attitude mapped to the stage.
    if design_stage in ["unknown", "no_image", "image_classification_failed"] or has_high_concept in ["Cannot Determine", "no_text", "text_classification_failed"]:
        llm_attitude = llm_attitude_mapping.get("random", "neutral and informative")
    else:
        llm_attitude = llm_attitude_mapping.get(design_stage, llm_attitude_mapping.get("random", "neutral and informative"))

    prompt = f"""User is a low-level architecture student struggling with critical architectural reviews. You are an abstract architecture critique interpreter. Your response must be in English.
Given that the user is in the {design_stage} design stage, your attitude should be {llm_attitude}.
Given that the user input result (Yes/No) contains abstract architectural concepts: {has_high_concept}.
If the user input contains abstract architectural concepts, you need to explain the abstract concept to the user and then provide actionable advice. If not, you can directly provide actionable advice.
User input text content: {text} You need to explain abstract concepts to the user using language that a child can understand, provide examples from daily life, and offer actionable advice.
"""

    return prompt


print("generate_prompt_only function defined.")
|
|
|
|
| |
|
|
|
|
| |
| |
# Example (image URL, description) pairs preloaded into the Gradio UI.
examples = [
    # Floor plan paired with an abstract, concept-heavy description.
    [
        "https://balancedarchitecture.com/wp-content/uploads/2021/11/EXISTING-FIRST-FLOOR-PRES-scaled-e1635965923983.jpg",
        "Exploring spatial relationships and material palettes.",
    ],
    # 3D rendering paired with a concrete, literal remark.
    [
        "https://cdn.prod.website-files.com/5894a32730554b620f7bf36d/5e848c2d622e7abe1ad48504_5e01ce9f0d272014d0353cd1_Things-You-Need-to-Organize-a-3D-Rendering-Architectural-Project-EASY-RENDER.jpeg",
        "The window size is too small.",
    ],
    # Concept sketch paired with a contextual, abstract statement.
    [
        "https://architectelevator.com/assets/img/bilbao_sketch.png",
        "The facade expresses the building's relationship with the urban context.",
    ],
]
|
|
# ---------------------------------------------------------------------------
# Gradio UI: two-step pipeline (classification, then prompt generation).
# ---------------------------------------------------------------------------
with gr.Blocks() as demo_step_by_step:
    gr.Markdown("# Architecture Feedback Generator (Classification & Prompt Only)")
    gr.Markdown("""
    Upload an architectural image and provide a text description or question to see classification results and the generated prompt.
    (LLM feedback generation is excluded from this version).
    """)

    with gr.Row():
        image_input = gr.Image(type="pil", label="Upload Architectural Image")
        text_input = gr.Textbox(label="Enter Text Description or Question")

    classify_and_prompt_button = gr.Button("Perform Classification & Generate Prompt")

    with gr.Row():
        # Show every class when the image model loaded; otherwise default to 5.
        image_output_label = gr.Label(
            num_top_classes=len(class_names) if 'class_names' in globals() and class_names else 5,
            label="Image Classification Results",
        )
        text_output_textbox = gr.Textbox(label="Text Classification Results")

    # Hidden state carrying the raw text probabilities between the two steps.
    text_classification_probabilities_state = gr.State()

    prompt_output_textbox = gr.Textbox(label="Generated Prompt for LLM", interactive=True)

    # Bug fix: .click() returns an event dependency, not a tuple of output
    # values, so it cannot be indexed (classification_outputs[2] raised
    # TypeError) and its elements cannot be used as inputs. Chain .then() on
    # the event itself, and pass the *components* populated by step 1 as the
    # inputs of step 2.
    classify_event = classify_and_prompt_button.click(
        fn=perform_classification_and_format,
        inputs=[image_input, text_input],
        outputs=[image_output_label, text_classification_probabilities_state, text_output_textbox],
    )
    classify_event.then(
        fn=generate_prompt_only,
        inputs=[image_output_label, text_classification_probabilities_state, text_input],
        outputs=prompt_output_textbox,
    )

    def generate_full_chain_output_step_by_step(img, txt):
        """Run both pipeline steps in one call (used by gr.Examples)."""
        img_res, txt_prob, txt_fmt = perform_classification_and_format(img, txt)
        prompt = generate_prompt_only(img_res, txt_prob, txt)
        return img_res, txt_fmt, prompt

    gr.Examples(
        examples=examples,
        inputs=[image_input, text_input],
        outputs=[image_output_label, text_output_textbox, prompt_output_textbox],
        fn=generate_full_chain_output_step_by_step,
        cache_examples=False,
    )


demo_step_by_step.launch()
|
|