# app.py — Architecture Feedback Generator (Gradio Space)
import gradio as gr
import pandas as pd
from PIL import Image
import torch
import torchvision.transforms as T
import json
import sentence_transformers
import os
import tempfile
import shutil
# Removed google.generativeai import as Gemini is excluded
# --- Model Loading (consolidated from previous notebook cells; Gemini excluded) ---
# In Hugging Face Spaces, secrets are exposed as environment variables.
# HF_TOKEN = os.environ.get('HF_TOKEN_WRITE')  # usually not needed for public model downloads

# Load the image classification model: a timm CNN whose architecture name,
# class list and weights are pickled together as a single "bundle" on the Hub.
try:
    from huggingface_hub import hf_hub_download
    import pickle
    import timm  # needed for timm.create_model below

    REPO_ID_IMG = "keerthikoganti/architecture-design-stages-compact-cnn"
    pkl_path = hf_hub_download(repo_id=REPO_ID_IMG, filename="model_bundle.pkl")
    # NOTE(review): pickle.load executes arbitrary code from the downloaded
    # file; acceptable only because this repo is trusted.
    with open(pkl_path, "rb") as f:
        bundle = pickle.load(f)
    architecture = bundle["architecture"]  # timm model name
    num_classes = bundle["num_classes"]
    class_names = bundle["class_names"]    # design-stage labels, index-aligned with model logits
    state_dict = bundle["state_dict"]
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = timm.create_model(architecture, pretrained=False, num_classes=num_classes)
    model.load_state_dict(state_dict)
    model.eval().to(device)
    # Standard ImageNet preprocessing: 224x224 crop + per-channel normalization.
    TFM = T.Compose([T.Resize(224), T.CenterCrop(224), T.ToTensor(), T.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])])
    print("Image Classification Model loaded successfully!")
except Exception as e:
    # Fall back to sentinels so the UI can still start; downstream code
    # checks these before attempting image classification.
    print(f"Error loading Image Classification Model: {e}")
    model = None
    TFM = None
    device = None
    class_names = []
# Load the text classification model: an AutoGluon TabularPredictor trained on
# sentence-embedding features, stored as a full predictor directory on the Hub.
try:
    from huggingface_hub import snapshot_download
    from autogluon.tabular import TabularPredictor
    import os  # ensure os is in scope even if top-of-file imports change

    repo_id_text = "kaitongg/my-autogluon-model"
    download_dir = "downloaded_predictor"
    # Download the entire model repository (predictor needs its full directory tree).
    print(f"Downloading text model files from {repo_id_text}...")
    # Use HF_TOKEN if the repo is private: token=os.environ.get('HF_TOKEN_WRITE')
    downloaded_path = snapshot_download(
        repo_id=repo_id_text,
        repo_type="model",
        local_dir=download_dir,
        local_dir_use_symlinks=False,  # NOTE(review): deprecated/no-op in newer huggingface_hub — confirm version
        # token=HF_TOKEN  # Uncomment if repo is private and HF_TOKEN is needed
    )
    print(f"Text model files downloaded to: {downloaded_path}")
    # The predictor itself lives in the 'autogluon_predictor' subdirectory of the repo.
    predictor_path = os.path.join(downloaded_path, "autogluon_predictor")
    loaded_predictor_from_hub = TabularPredictor.load(predictor_path)
    print("Text Classification Model loaded successfully from Hugging Face Hub!")
except Exception as e:
    # Keep the app usable without the model; downstream code checks for None.
    print(f"Error loading Text Classification Model: {e}")
    loaded_predictor_from_hub = None
# Load the sentence-transformer used to embed user text into the feature
# vector the AutoGluon predictor above was trained on.
try:
    embedding_model = sentence_transformers.SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
    print("Sentence Transformer model loaded successfully!")
except Exception as e:
    # Downstream code checks for None before encoding.
    print(f"Error loading Sentence Transformer model: {e}")
    embedding_model = None
# --- LLM attitude mapping ---
# Maps a predicted design stage to the tone the downstream LLM prompt
# should ask for; "random" doubles as the fallback for unknown stages.
llm_attitude_mapping = dict(
    brainstorm="creative and encouraging",
    design_iteration="constructive and detailed, focusing on improvements",
    design_optimization="critical and focused on efficiency and refinement",
    final_review="thorough and critical, evaluating completeness and adherence to requirements",
    random="neutral and informative, perhaps suggesting a relevant stage",
)
print("LLM attitude mapping defined successfully!")
# --- Function Definitions (Consolidated from jKIkOPByaN3Z and OJ9wke1CrK1S) ---
# Define the specific text classification function (from OJ9wke1CrK1S/jKIkOPByaN3Z)
def perform_text_classification_and_format(text: str) -> tuple[dict, str]:
    """
    Classify *text* as containing a "high concept" or not.

    The text is embedded with the module-level ``embedding_model``
    (sentence-transformer) and the embedding is scored by the module-level
    AutoGluon predictor ``loaded_predictor_from_hub``.

    Args:
        text: The input text string (may be empty or None).

    Returns:
        A tuple of:
        - text_classification_probabilities (dict): probability per class, or
          an ``{"error": ...}`` / ``{"info": ...}`` entry on failure / no input.
        - text_classification_formatted (str): human-readable summary string.
    """
    # BUG FIX: report missing input before inspecting model availability.
    # Previously an empty prompt with an unloaded predictor was misreported
    # as "Text predictor model not loaded." instead of "No text provided".
    if not text:
        return {"info": "No text provided"}, "No text provided"

    if loaded_predictor_from_hub is None or embedding_model is None:
        print("Text predictor or embedding model not loaded for text classification.")
        return (
            {"error": "Text predictor or embedding model not loaded"},
            "Text predictor or embedding model not loaded.",
        )

    try:
        # Encode the text; result shape is (1, d).
        embeddings = embedding_model.encode(
            [text],
            batch_size=1,
            show_progress_bar=False,
            convert_to_numpy=True,
            normalize_embeddings=False,
        )
        # The predictor was trained on feature columns named e0..e{d-1}.
        n, d = embeddings.shape
        text_df_processed = pd.DataFrame(embeddings, columns=[f"e{i}" for i in range(d)])

        text_proba_df = loaded_predictor_from_hub.predict_proba(text_df_processed)

        # BUG FIX: AutoGluon labels the probability columns with the training
        # label dtype — ints (0/1) or strings ("0"/"1"). The old check
        # `"0" in text_proba_df.columns` silently returned 0.0 for int
        # columns; compare column labels as strings instead.
        def _proba_for(label: str) -> float:
            # Return the probability in the first row for the column whose
            # label stringifies to *label*, or 0.0 if absent.
            for col in text_proba_df.columns:
                if str(col) == label:
                    return float(text_proba_df.iloc[0][col])
            return 0.0

        text_classification_probabilities = {
            "No High Concept": _proba_for("0"),
            "High Concept": _proba_for("1"),
        }

        has_high_concept = "Cannot Determine"
        confidence = 0.0
        if not text_proba_df.empty and len(text_proba_df.columns) > 0:
            # Compare the predicted label as a string so int and str labels behave alike.
            predicted_text_label = str(loaded_predictor_from_hub.predict(text_df_processed).iloc[0])
            if predicted_text_label == "1":
                has_high_concept = "Yes"
                confidence = text_classification_probabilities.get("High Concept", 0.0)
            elif predicted_text_label == "0":
                has_high_concept = "No"
                confidence = text_classification_probabilities.get("No High Concept", 0.0)
            else:  # unexpected label from the predictor
                has_high_concept = f"Unknown Label: {predicted_text_label}"
                print(f"Warning: Predictor returned unexpected label: {predicted_text_label}")
        else:
            has_high_concept = "Cannot Determine (No Prediction Output)"

        print(f"Text classified as having high concept: {has_high_concept}")
        print(f"Text classification probabilities: {text_classification_probabilities}")
        text_classification_formatted = f"High Concept: {has_high_concept} (Confidence: {confidence:.2f})"
        return text_classification_probabilities, text_classification_formatted
    except Exception as e:
        print(f"Error during text classification: {e}")
        return (
            {"error": f"Text classification failed: {e}"},
            f"Text classification failed: {e}",
        )
print("perform_text_classification_and_format function defined.")
# Define the combined classification function (from jKIkOPByaN3Z)
# This function calls perform_text_classification_and_format defined above
def perform_classification_and_format(image: Image.Image, text: str) -> tuple[dict, dict, str]:
    """
    Run the image classifier and the text classifier on the given inputs.

    Delegates the text half to ``perform_text_classification_and_format``;
    the image half uses the module-level ``model``/``TFM``/``device``/
    ``class_names`` loaded at import time.

    Args:
        image: The input PIL Image (or None).
        text: The input text string (or None/empty).

    Returns:
        A tuple of:
        - image_classification_results (dict): per-class probabilities for the
          image, or an ``{"error"/"info": ...}`` entry.
        - text_classification_probabilities (dict): per-class probabilities
          for the text, or an ``{"error"/"info": ...}`` entry.
        - text_classification_formatted (str): human-readable text result.
    """
    # --- Image half ---
    if image is None:
        print("No image provided for image classification.")
        image_classification_results = {"info": "No image provided"}
    elif model is None or TFM is None or device is None or not class_names:
        print("Image model components not loaded.")
        image_classification_results = {"error": "Image model or components not loaded"}
    else:
        try:
            # Preprocess and add the batch dimension expected by the model.
            batch = TFM(image).unsqueeze(0).to(device)
            with torch.no_grad():
                logits = model(batch)
            probabilities = torch.softmax(logits, dim=1)[0]
            top_index = torch.argmax(probabilities).item()
            design_stage = class_names[top_index]
            # Map every class name to its probability for the gr.Label widget.
            image_classification_results = {
                name: float(probabilities[i]) for i, name in enumerate(class_names)
            }
            print(f"Image classified as: {design_stage}")
            print(f"Image classification probabilities: {image_classification_results}")
        except Exception as e:
            print(f"Error processing image: {e}")
            image_classification_results = {"error": f"Image classification failed: {e}"}

    # --- Text half (dedicated helper returns (probabilities_dict, formatted_string)) ---
    text_classification_probabilities, text_classification_formatted = perform_text_classification_and_format(text)
    print(f"Text classification formatted result: {text_classification_formatted}")
    print(f"Text classification raw probabilities: {text_classification_probabilities}")

    return image_classification_results, text_classification_probabilities, text_classification_formatted
print("perform_classification_and_format function defined.")
# Define a function to generate the prompt based on classification results and text (from jKIkOPByaN3Z)
def generate_prompt_only(image_classification_results: dict, text_classification_probabilities: dict, text: str) -> str:
    """
    Generate the LLM prompt from the image and text classification results.

    Args:
        image_classification_results: Dict of image class probabilities, or an
            ``{"error"/"info": ...}`` entry, or None (e.g. unset gr.State).
        text_classification_probabilities: Dict of text class probabilities,
            or an ``{"error"/"info": ...}`` entry, or None.
        text: The original input text string.

    Returns:
        The prompt string for the LLM (the LLM call itself is excluded here).
    """
    # BUG FIX: Gradio can hand us None (an unset gr.State or a failed
    # upstream step). The `elif "info" in ...` membership tests below raised
    # TypeError on None; normalize to empty dicts so every branch is safe.
    image_classification_results = image_classification_results or {}
    text_classification_probabilities = text_classification_probabilities or {}

    # Derive the design stage: the highest-probability image class, or a
    # sentinel string for the no-image / failed-classification cases.
    design_stage = "unknown"
    if "info" in image_classification_results:
        design_stage = "no_image"
    elif "error" in image_classification_results:
        design_stage = "image_classification_failed"
    elif image_classification_results:
        try:
            valid_results = {k: v for k, v in image_classification_results.items() if k not in ["error", "info"]}
            if valid_results:
                design_stage = max(valid_results, key=valid_results.get)
        except Exception:
            design_stage = "unknown"

    # Derive the high-concept verdict: whichever text probability is higher,
    # or a sentinel string for the no-text / failed-classification cases.
    has_high_concept = "Cannot Determine"
    if "info" in text_classification_probabilities:
        has_high_concept = "no_text"
    elif "error" in text_classification_probabilities:
        has_high_concept = "text_classification_failed"
    elif text_classification_probabilities:
        try:
            high_concept_prob = text_classification_probabilities.get("High Concept", 0.0)
            no_high_concept_prob = text_classification_probabilities.get("No High Concept", 0.0)
            has_high_concept = "Yes" if high_concept_prob > no_high_concept_prob else "No"
        except Exception:
            has_high_concept = "Cannot Determine"

    # Pick the LLM attitude: fall back to the neutral "random" tone whenever
    # either classification produced a sentinel/error state.
    if design_stage in ["unknown", "no_image", "image_classification_failed"] or has_high_concept in ["Cannot Determine", "no_text", "text_classification_failed"]:
        llm_attitude = llm_attitude_mapping.get("random", "neutral and informative")
    else:
        llm_attitude = llm_attitude_mapping.get(design_stage, llm_attitude_mapping.get("random", "neutral and informative"))

    prompt = f"""User is a low-level architecture student struggling with critical architectural reviews. You are an abstract architecture critique interpreter. Your response must be in English.
Given that the user is in the {design_stage} design stage, your attitude should be {llm_attitude}.
Given that the user input result (Yes/No) contains abstract architectural concepts: {has_high_concept}.
If the user input contains abstract architectural concepts, you need to explain the abstract concept to the user and then provide actionable advice. If not, you can directly provide actionable advice.
User input text content: {text} You need to explain abstract concepts to the user using language that a child can understand, provide examples from daily life, and offer actionable advice.
"""  # full user text is embedded verbatim in the prompt
    return prompt
print("generate_prompt_only function defined.")
# Removed generate_feedback_from_prompt function as Gemini LLM is excluded
# --- Create Gradio Interface (Consolidated from jKIkOPByaN3Z, excluding LLM feedback parts) ---
# Define example inputs for the Gradio interface
# Example inputs for gr.Examples: [image URL, text] pairs that exercise the
# classification -> prompt pipeline end to end.
examples = [
    # Example 1: Brainstorm stage, text with high concept
    ["https://balancedarchitecture.com/wp-content/uploads/2021/11/EXISTING-FIRST-FLOOR-PRES-scaled-e1635965923983.jpg", "Exploring spatial relationships and material palettes."],
    # Example 2: Design Iteration stage, text without high concept
    ["https://cdn.prod.website-files.com/5894a32730554b620f7bf36d/5e848c2d622e7abe1ad48504_5e01ce9f0d272014d0353cd1_Things-You-Need-to-Organize-a-3D-Rendering-Architectural-Project-EASY-RENDER.jpeg", "The window size is too small."],
    # Example 3: Final Review stage, text with some concept
    ["https://architectelevator.com/assets/img/bilbao_sketch.png", "The facade expresses the building's relationship with the urban context."],
]
# --- Gradio interface (classification + prompt generation only; no LLM call) ---
with gr.Blocks() as demo_step_by_step:
    gr.Markdown("# Architecture Feedback Generator (Classification & Prompt Only)")
    gr.Markdown("""
Upload an architectural image and provide a text description or question to see classification results and the generated prompt.
(LLM feedback generation is excluded from this version).
""")
    with gr.Row():
        image_input = gr.Image(type="pil", label="Upload Architectural Image")
        text_input = gr.Textbox(label="Enter Text Description or Question")
    classify_and_prompt_button = gr.Button("Perform Classification & Generate Prompt")
    with gr.Row():
        # Fall back to 5 top classes if the image model failed to load.
        image_output_label = gr.Label(num_top_classes=len(class_names) if 'class_names' in globals() and class_names else 5, label="Image Classification Results")
        text_output_textbox = gr.Textbox(label="Text Classification Results")
    # State carries the raw text-class probabilities between the two event steps.
    text_classification_probabilities_state = gr.State()
    # Editable so the user can inspect/tweak the prompt before using it elsewhere.
    prompt_output_textbox = gr.Textbox(label="Generated Prompt for LLM", interactive=True)

    # Step 1: run both classifiers on button click.
    # BUG FIX: .click() returns an event (Dependency) object, NOT the list of
    # output components — the old code indexed it (`classification_outputs[2]`)
    # and passed those indexes as event inputs, which is invalid. Chain .then()
    # directly on the event and reference the actual components instead.
    classification_event = classify_and_prompt_button.click(
        fn=perform_classification_and_format,
        inputs=[image_input, text_input],
        outputs=[image_output_label, text_classification_probabilities_state, text_output_textbox],
    )
    # Step 2: once classification finishes, build the LLM prompt from the
    # freshly-updated label, state and original text.
    classification_event.then(
        fn=generate_prompt_only,
        inputs=[
            image_output_label,                       # image class probabilities
            text_classification_probabilities_state,  # raw text class probabilities
            text_input,                               # original user text
        ],
        outputs=prompt_output_textbox,
    )

    def generate_full_chain_output_step_by_step(img, txt):
        """Run both pipeline steps for gr.Examples (classification, then prompt)."""
        img_res, txt_prob, txt_fmt = perform_classification_and_format(img, txt)
        prompt = generate_prompt_only(img_res, txt_prob, txt)
        # Order must match the `outputs` list passed to gr.Examples below.
        return img_res, txt_fmt, prompt

    gr.Examples(
        examples=examples,
        inputs=[image_input, text_input],
        outputs=[image_output_label, text_output_textbox, prompt_output_textbox],
        fn=generate_full_chain_output_step_by_step,
        cache_examples=False,  # re-run the function on each example click
    )

# Launch unconditionally: Spaces imports app.py as the entry module.
demo_step_by_step.launch()