import gradio as gr
import pandas as pd
from PIL import Image
import torch
import torchvision.transforms as T
import os
import json
import sentence_transformers
from huggingface_hub import hf_hub_download
import pickle
import timm
import google.generativeai as genai
# ============================================
# 1. LOAD IMAGE CLASSIFICATION MODEL
# ============================================
print("Loading image classification model...")

REPO_ID = "keerthikoganti/architecture-design-stages-compact-cnn"
pkl_path = hf_hub_download(repo_id=REPO_ID, filename="model_bundle.pkl")

with open(pkl_path, "rb") as f:
    bundle = pickle.load(f)

architecture = bundle["architecture"]
num_classes = bundle["num_classes"]
class_names = bundle["class_names"]
state_dict = bundle["state_dict"]

device = "cuda" if torch.cuda.is_available() else "cpu"
model = timm.create_model(architecture, pretrained=False, num_classes=num_classes)
model.load_state_dict(state_dict)
model.eval().to(device)

# Standard ImageNet preprocessing: resize the short side to 224, center-crop
# to 224x224, and normalize with ImageNet channel statistics.
TFM = T.Compose([
    T.Resize(224),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

print("✅ Image classification model loaded successfully!")
# ============================================
# 2. LOAD TEXT CLASSIFICATION MODEL
# ============================================
print("Loading text classification model...")

from autogluon.tabular import TabularPredictor
from huggingface_hub import snapshot_download
import shutil

text_repo_id = "kaitongg/my-autogluon-model"
download_dir = "downloaded_predictor"

# Start from a clean directory so a stale snapshot never shadows a new download.
if os.path.exists(download_dir):
    shutil.rmtree(download_dir)
os.makedirs(download_dir, exist_ok=True)

downloaded_path = snapshot_download(
    repo_id=text_repo_id,
    repo_type="model",
    local_dir=download_dir,
    local_dir_use_symlinks=False,  # deprecated and ignored by recent huggingface_hub; kept for older versions
)
predictor_path = os.path.join(downloaded_path, "autogluon_predictor")

# Bypass AutoGluon's environment checks (the model was trained on Python 3.12
# but runs here on 3.10).
loaded_predictor_from_hub = TabularPredictor.load(
    predictor_path,
    require_py_version_match=False,
    require_version_match=False,
)

embedding_model = sentence_transformers.SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
print("✅ Text classification model loaded successfully!")
# ============================================
# 3. INITIALIZE GEMINI API
# ============================================
print("Initializing Gemini API...")

# Get the API key from an environment variable (set in Hugging Face Spaces secrets).
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")

if GEMINI_API_KEY:
    genai.configure(api_key=GEMINI_API_KEY)

    # Try models in order of preference (using full model paths).
    model_names_to_try = [
        'models/gemini-2.5-flash',     # Latest stable flash model
        'models/gemini-flash-latest',  # Auto-updates to the latest flash
        'models/gemini-2.0-flash',     # Fallback to 2.0
        'models/gemini-pro-latest',    # Pro version
    ]

    gemini_model = None
    for model_name in model_names_to_try:
        try:
            # Only assign gemini_model once the test query succeeds; otherwise a
            # model that constructs but cannot generate would be kept by mistake.
            candidate = genai.GenerativeModel(model_name)
            candidate.generate_content("Test")  # trivial query to confirm access
            gemini_model = candidate
            print(f"✅ Gemini API initialized successfully with {model_name}!")
            break
        except Exception as e:
            print(f"Failed to load {model_name}: {str(e)[:100]}")
            continue

    if gemini_model is None:
        print("⚠️ Warning: could not initialize any Gemini model")
else:
    gemini_model = None
    print("⚠️ Warning: GEMINI_API_KEY not found in environment variables")
# ============================================
# 4. LLM ATTITUDE MAPPING
# ============================================
# Maps the predicted design stage to the feedback tone requested from the LLM;
# unrecognized stages fall back to the "random" entry (see section 7).
llm_attitude_mapping = {
    "brainstorm": "creative and encouraging",
    "design_iteration": "constructive and detailed, focusing on improvements",
    "design_optimization": "critical and focused on efficiency and refinement",
    "final_review": "thorough and critical, evaluating completeness and adherence to requirements",
    "random": "neutral and informative, perhaps suggesting a relevant stage",
}
# ============================================
# 5. TEXT CLASSIFICATION FUNCTION
# ============================================
def perform_text_classification_and_format(text: str) -> tuple:
    text_classification_formatted = "No text provided"
    text_classification_probabilities = {"No High Concept": 0.0, "High Concept": 0.0}
    predicted_text_label = "0"

    if text and loaded_predictor_from_hub is not None and embedding_model is not None:
        try:
            # Embed the text, then present it to the predictor with the same
            # e0..e{d-1} column layout used at training time.
            embeddings = embedding_model.encode(
                [text],
                batch_size=1,
                show_progress_bar=False,
                convert_to_numpy=True,
                normalize_embeddings=False,
            )
            _, d = embeddings.shape
            text_df_processed = pd.DataFrame(embeddings, columns=[f"e{i}" for i in range(d)])

            text_proba_df = loaded_predictor_from_hub.predict_proba(text_df_processed)
            # Probability columns may be labeled with str or int class names,
            # depending on how the predictor was trained; accept either.
            row = text_proba_df.iloc[0]
            text_classification_probabilities = {
                "No High Concept": float(row.get("0", row.get(0, 0.0))),
                "High Concept": float(row.get("1", row.get(1, 0.0))),
            }
            predicted_text_label = str(loaded_predictor_from_hub.predict(text_df_processed).iloc[0])

            if predicted_text_label == "1":
                has_high_concept = "Yes"
                confidence = text_classification_probabilities["High Concept"]
            else:
                has_high_concept = "No"
                confidence = text_classification_probabilities["No High Concept"]
            text_classification_formatted = f"High Concept: {has_high_concept} (Confidence: {confidence:.2f})"
        except Exception as e:
            print(f"Error processing text: {e}")
            text_classification_formatted = f"Text classification failed: {e}"

    return text_classification_formatted, text_classification_probabilities, predicted_text_label
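# Usage sketch (assumes the models above loaded successfully; uncomment to try):
# fmt, probs, label = perform_text_classification_and_format(
#     "The atrium mediates between the public street and the private courtyard."
# )
# print(fmt)  # -> 'High Concept: <Yes/No> (Confidence: <p>)'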
# ============================================
# 6. COMBINED CLASSIFICATION FUNCTION
# ============================================
def perform_classification_and_format(image: Image.Image, text: str) -> tuple:
    image_classification_results = {"error": "No image provided"}
    design_stage = "unknown"

    if image is not None and model is not None:
        try:
            # Force RGB: uploaded PNGs may be RGBA or grayscale, which would
            # break the 3-channel normalization in TFM.
            img_tensor = TFM(image.convert("RGB")).unsqueeze(0).to(device)
            with torch.no_grad():
                img_output = model(img_tensor)
                img_probabilities = torch.softmax(img_output, dim=1)[0]
            predicted_class_index = torch.argmax(img_probabilities).item()
            design_stage = class_names[predicted_class_index]
            image_classification_results = {
                class_names[i]: float(img_probabilities[i])
                for i in range(len(class_names))
            }
            print(f"✅ Image classified as: {design_stage}")
        except Exception as e:
            print(f"❌ Error processing image: {e}")
            image_classification_results = {"error": f"Image classification failed: {e}"}

    text_classification_formatted, text_classification_probabilities, predicted_text_label = perform_text_classification_and_format(text)
    return image_classification_results, text_classification_probabilities, text_classification_formatted, predicted_text_label
# ============================================
# 7. PROMPT GENERATION FUNCTION
# ============================================
def generate_prompt_only(image_classification_results: dict,
                         text_classification_probabilities: dict,
                         predicted_text_label: str,
                         text: str) -> str:
    design_stage = "unknown"
    if image_classification_results and "error" not in image_classification_results:
        try:
            # gr.Label may hand its value back as {"label": ..., "confidences": [...]}
            # rather than a plain class->probability dict; handle both shapes.
            if "label" in image_classification_results:
                design_stage = image_classification_results["label"]
            else:
                design_stage = max(image_classification_results, key=image_classification_results.get)
        except Exception:
            design_stage = "unknown"

    has_high_concept = "Unable to determine"
    confidence = 0.0
    if text_classification_probabilities and "error" not in text_classification_probabilities:
        try:
            if predicted_text_label == "1":
                has_high_concept = "Yes"
                confidence = text_classification_probabilities.get("High Concept", 0.0)
            else:
                has_high_concept = "No"
                confidence = text_classification_probabilities.get("No High Concept", 0.0)
        except Exception:
            has_high_concept = "Unable to determine"
            confidence = 0.0

    llm_attitude = llm_attitude_mapping.get(design_stage, llm_attitude_mapping["random"])

    prompt = f"""You are an architecture education assistant helping a student understand architectural concepts.

Context:
- Design stage: {design_stage}
- Your feedback style should be: {llm_attitude}
- Abstract concepts detected: {has_high_concept} (confidence: {confidence:.2f})

Student's input: "{text}"

Instructions:
Please provide educational feedback in 250-350 words that:
1. Uses simple, everyday examples and analogies
2. Explains any abstract architectural concepts in accessible language
3. Provides specific, actionable suggestions for improvement
4. Maintains an encouraging yet constructive tone
5. Ends with a complete sentence

Focus on being helpful and educational rather than critical.
"""
    return prompt
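# Usage sketch, chaining the two classifiers into a prompt by hand (assumes
# `img` is a PIL.Image; uncomment to trace the pipeline without the UI):
# img_res, txt_prob, _, label = perform_classification_and_format(img, "Open-plan living area.")
# print(generate_prompt_only(img_res, txt_prob, label, "Open-plan living area."))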
# ============================================
# 8. GEMINI FEEDBACK GENERATION
# ============================================
def generate_feedback_from_prompt(prompt_input: str) -> str:
    if gemini_model is None:
        return "⚠️ Gemini API not configured. Please set GEMINI_API_KEY in Hugging Face Spaces secrets."
    try:
        print("Generating feedback with Gemini...")

        # Extract just the student's text from the (possibly edited) prompt.
        user_text = prompt_input
        if "Student's input:" in prompt_input:
            parts = prompt_input.split("Student's input:")
            if len(parts) > 1:
                user_text = parts[1].strip().strip('"')

        # Ultra-simplified prompt: just the core request.
        simple_prompt = f"Provide brief educational feedback on this architectural description: {user_text}"
        print(f"Sending prompt ({len(simple_prompt)} chars)")

        # Minimal configuration: only what's absolutely necessary.
        response = gemini_model.generate_content(simple_prompt)
        print("Response received")

        # Extract the text, falling back to reading candidate parts directly.
        llm_response_text = None
        try:
            llm_response_text = response.text
            print(f"✅ Got text ({len(llm_response_text)} chars)")
        except Exception as e:
            print(f"Failed to get text: {str(e)[:100]}")
            if response.candidates and response.candidates[0].content:
                candidate = response.candidates[0]
                if candidate.content.parts:
                    texts = [part.text for part in candidate.content.parts if hasattr(part, 'text')]
                    if texts:
                        llm_response_text = "".join(texts)
                        print(f"✅ Got text from parts ({len(llm_response_text)} chars)")

        if not llm_response_text:
            return ("⚠️ No response generated. This may be an API limitation. Try:\n"
                    "- Shorter, simpler descriptions\n"
                    "- Removing technical terms\n"
                    "- Testing with basic input like 'large windows'")
        return llm_response_text.strip()
    except Exception as e:
        error_msg = str(e)
        print(f"Error: {error_msg[:200]}")
        return f"Error: {error_msg}"
# ============================================
# 9. GRADIO INTERFACE
# ============================================
examples = [
    ["https://huggingface.co/datasets/kaitongg/image/resolve/main/5e848c2d622e7abe1ad48504_5e01ce9f0d272014d0353cd1_Things-You-Need-to-Organize-a-3D-Rendering-Architectural-Project-EASY-RENDER.jpeg",
     "Exploring spatial relationships and material palettes."],
    ["https://huggingface.co/datasets/kaitongg/image/resolve/main/EXISTING-FIRST-FLOOR-PRES-scaled-e1635965923983.jpg",
     "The window size is too small."],
    ["https://huggingface.co/datasets/kaitongg/image/resolve/main/bilbao_sketch.png",
     "The facade expresses the building's relationship with the urban context."],
]
with gr.Blocks(css="""
.left-column, .middle-column, .right-column {min-width: 300px !important;}
.textbox-container textarea {min-height: 150px !important;}
""") as demo:
    gr.Markdown("# 🏛️ Architecture Feedback Generator (Powered by Gemini)")
    gr.Markdown("""
    Upload an architectural image and provide a text description or question.
    The system will classify the design stage, analyze the text for high-level concepts,
    generate a customized prompt, and provide AI-powered feedback using Google's Gemini.
    """)

    with gr.Row():
        # LEFT COLUMN - Input Section
        with gr.Column(scale=1, elem_classes="left-column"):
            gr.Markdown("### 📥 Input")
            image_input = gr.Image(type="pil", label="Upload Architectural Image", height=300)
            text_input = gr.Textbox(
                label="Enter Text Description or Question",
                placeholder="Describe your architectural design, ask questions, or provide context...",
                lines=6,
                elem_classes="textbox-container"
            )
            classify_button = gr.Button("🔍 Classify & Generate Prompt", variant="primary", size="lg")

        # MIDDLE COLUMN - Classification & Prompt Section
        with gr.Column(scale=1, elem_classes="middle-column"):
            gr.Markdown("### 📊 Classification Results & Prompt")
            image_output_label = gr.Label(
                num_top_classes=len(class_names),
                label="Image Classification (Design Stage)"
            )
            text_output_textbox = gr.Textbox(
                label="Text Classification (High Concept Detection)",
                lines=2,
                elem_classes="textbox-container"
            )
            prompt_output_textbox = gr.Textbox(
                label="Generated Prompt (Editable)",
                lines=10,
                interactive=True,
                elem_classes="textbox-container"
            )
            generate_feedback_button = gr.Button("✨ Generate AI Feedback", variant="primary", size="lg")

        # RIGHT COLUMN - Gemini Output Section
        with gr.Column(scale=1, elem_classes="right-column"):
            gr.Markdown("### 🤖 AI-Generated Feedback")
            llm_output_text = gr.Textbox(
                label="Gemini Response",
                lines=20,
                elem_classes="textbox-container",
                show_copy_button=True
            )
    # Hidden state variables carrying raw classifier outputs between steps
    text_classification_probabilities_state = gr.State()
    predicted_text_label_state = gr.State()

    # Step 1: Classification
    classification_outputs = classify_button.click(
        fn=perform_classification_and_format,
        inputs=[image_input, text_input],
        outputs=[
            image_output_label,
            text_classification_probabilities_state,
            text_output_textbox,
            predicted_text_label_state
        ]
    )

    # Step 2: Generate the prompt once classification has finished
    # (generate_prompt_only already takes the inputs in this order, so no
    # pass-through wrapper is needed).
    classification_outputs.then(
        fn=generate_prompt_only,
        inputs=[
            image_output_label,
            text_classification_probabilities_state,
            predicted_text_label_state,
            text_input
        ],
        outputs=prompt_output_textbox
    )

    # Step 3: Gemini Feedback
    generate_feedback_button.click(
        fn=generate_feedback_from_prompt,
        inputs=[prompt_output_textbox],
        outputs=llm_output_text
    )
    # Examples Section
    gr.Markdown("---")
    gr.Markdown("### 💡 Example Inputs")

    def generate_full_chain_output(img, txt):
        # Run the whole pipeline (classify -> prompt -> feedback) for an example.
        img_res, txt_prob, txt_fmt, predicted_label = perform_classification_and_format(img, txt)
        prompt = generate_prompt_only(img_res, txt_prob, predicted_label, txt)
        llm_res = generate_feedback_from_prompt(prompt)
        return img_res, txt_fmt, prompt, llm_res

    gr.Examples(
        examples=examples,
        inputs=[image_input, text_input],
        outputs=[image_output_label, text_output_textbox, prompt_output_textbox, llm_output_text],
        fn=generate_full_chain_output,
        cache_examples=False
    )

if __name__ == "__main__":
    demo.launch()