kaitongg commited on
Commit
3377508
·
verified ·
1 Parent(s): ec726a6

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +243 -0
app.py ADDED
@@ -0,0 +1,243 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import shutil
3
+ import json
4
+ import zipfile
5
+ import torch
6
+ import timm
7
+ import pickle
8
+ import gradio as gr
9
+ import pandas as pd
10
+ import sentence_transformers
11
+ import torchvision.transforms as T
12
+ from PIL import Image
13
+ from autogluon.tabular import TabularPredictor
14
+ from huggingface_hub import hf_hub_download, snapshot_download
15
+ from llama_cpp import Llama
16
+
# ----------------------
# Load Image Classification Model
# ----------------------
REPO_ID = "keerthikoganti/architecture-design-stages-compact-cnn"
pkl_path = hf_hub_download(repo_id=REPO_ID, filename="model_bundle.pkl")

# SECURITY NOTE(review): pickle.load executes arbitrary code from the
# downloaded artifact — acceptable only because the repo is trusted.
with open(pkl_path, "rb") as f:
    bundle = pickle.load(f)

# The bundle carries everything needed to rebuild the classifier.
architecture = bundle["architecture"]
num_classes = bundle["num_classes"]
class_names = bundle["class_names"]
state_dict = bundle["state_dict"]

device = "cuda" if torch.cuda.is_available() else "cpu"
model = timm.create_model(architecture, pretrained=False, num_classes=num_classes)
model.load_state_dict(state_dict)
model.eval().to(device)

# Standard ImageNet preprocessing: 224x224 centre crop + mean/std normalisation.
TFM = T.Compose([
    T.Resize(224),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
# ----------------------
# Load Text Classification Model
# ----------------------
repo_id = "kaitongg/my-autogluon-model"
download_dir = "downloaded_predictor"

# Start from a clean directory so stale files never mix with the fresh snapshot.
if os.path.exists(download_dir):
    shutil.rmtree(download_dir)
os.makedirs(download_dir, exist_ok=True)

# NOTE: local_dir_use_symlinks was dropped — the parameter is deprecated and
# ignored by recent huggingface_hub versions; local_dir always gets real files.
snapshot_download(
    repo_id=repo_id,
    repo_type="model",
    local_dir=download_dir,
)

predictor_path = os.path.join(download_dir, "autogluon_predictor")
loaded_predictor_from_hub = TabularPredictor.load(predictor_path)
# ----------------------
# Load LLM
# ----------------------
llm_model_id = "bartowski/Qwen_Qwen3-4B-Instruct-2507-GGUF"
llm_filename = "Qwen_Qwen3-4B-Instruct-2507-Q4_K_M.gguf"

# Quantised Qwen3-4B served through llama.cpp.
llm = Llama.from_pretrained(
    repo_id=llm_model_id,
    filename=llm_filename,
    n_ctx=4096,        # prompt + completion token window
    n_threads=None,    # let llama.cpp pick a thread count
    logits_all=False,  # logits only for sampled tokens
    verbose=False,
)
# Tone the LLM should adopt for each predicted design stage; "random" doubles
# as the fallback entry for unrecognised stages.
llm_attitude_mapping = dict(
    brainstorm="creative and encouraging",
    design_iteration="constructive and detailed, focusing on improvements",
    design_optimization="critical and focused on efficiency and refinement",
    final_review="thorough and critical, evaluating completeness and adherence to requirements",
    random="neutral and informative, perhaps suggesting a relevant stage",
)
# ----------------------
# Load Embedding Model
# ----------------------
# Best effort: when the sentence-transformer cannot be loaded (offline, etc.)
# embedding_model stays None, which disables text classification downstream.
try:
    embedding_model = sentence_transformers.SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
except Exception:
    embedding_model = None
# ----------------------
# Functions
# ----------------------
def perform_text_classification_and_format(text: str):
    """Classify *text* as containing a "high concept" or not.

    Embeds the text with the sentence-transformer, feeds the embedding to the
    AutoGluon tabular predictor, and formats the result for display.

    Returns a 3-tuple:
        (formatted summary string,
         {"No High Concept": p0, "High Concept": p1} probability dict,
         predicted label as a string, "0" or "1").
    Falls back to ("No text provided", {}, "0") when *text* is empty or when
    either model failed to load.
    """
    text_classification_formatted = "No text provided"
    text_classification_probabilities = {}
    predicted_text_label = "0"

    if text and loaded_predictor_from_hub is not None and embedding_model is not None:
        embeddings = embedding_model.encode([text], convert_to_numpy=True)
        # The predictor was trained on columns e0..e{dim-1}; rebuild that frame.
        dim = embeddings.shape[1]
        text_df_processed = pd.DataFrame(embeddings, columns=[f"e{i}" for i in range(dim)])

        text_proba_df = loaded_predictor_from_hub.predict_proba(text_df_processed)
        text_classification_probabilities = {
            "No High Concept": float(text_proba_df.iloc[0].get("0", 0.0)),
            "High Concept": float(text_proba_df.iloc[0].get("1", 0.0)),
        }

        predicted_text_label = str(loaded_predictor_from_hub.predict(text_df_processed).iloc[0])
        # 是/否 ("yes"/"no") feed the Chinese portion of the LLM prompt.
        if predicted_text_label == "1":
            has_high_concept = "是"
            confidence = text_classification_probabilities["High Concept"]
        else:
            has_high_concept = "否"
            confidence = text_classification_probabilities["No High Concept"]

        text_classification_formatted = f"High Concept: {has_high_concept} (Confidence: {confidence:.2f})"

    return text_classification_formatted, text_classification_probabilities, predicted_text_label
def perform_classification_and_format(image: Image.Image, text: str):
    """Run the image and text classifiers for the Gradio UI.

    Returns a 3-tuple:
        (image class→probability dict — or {"error": ...} when no image,
         text label→probability dict,
         formatted text-classification summary string).
    """
    image_classification_results = {"error": "No image provided"}

    if image is not None:
        img_tensor = TFM(image).unsqueeze(0).to(device)
        with torch.no_grad():
            img_output = model(img_tensor)
        img_probabilities = torch.softmax(img_output, dim=1)[0]
        # NOTE: the previous version also derived design_stage here, but it was
        # never used or returned (the prompt step recomputes it), so it is gone.
        image_classification_results = {class_names[i]: float(img_probabilities[i]) for i in range(len(class_names))}

    # Predicted label is recomputed downstream from the probabilities, so it
    # is deliberately discarded here.
    text_classification_formatted, text_classification_probabilities, _ = perform_text_classification_and_format(text)
    return image_classification_results, text_classification_probabilities, text_classification_formatted
def generate_prompt_only(image_classification_results, text_classification_probabilities, predicted_text_label, text: str):
    """Assemble the LLM critique prompt from both classifiers' outputs.

    The design stage is the highest-probability image class (default
    "unknown"); the stage selects the tone via llm_attitude_mapping.
    Returns the prompt string to feed to the LLM.
    """
    design_stage = "unknown"
    if image_classification_results and "error" not in image_classification_results:
        design_stage = max(image_classification_results, key=image_classification_results.get)

    # Label "1" means the text contains a high concept (是 = yes, 否 = no).
    # NOTE: the previous version also pulled a confidence value out of
    # text_classification_probabilities here, but it never appeared in the
    # prompt, so that dead computation was removed.
    has_high_concept = "是" if predicted_text_label == "1" else "否"

    llm_attitude = llm_attitude_mapping.get(design_stage, llm_attitude_mapping["random"])

    prompt = f"""You are an abstract architecture critique interpreter.
Your audience is a low-level architecture student.
已知用户处于{design_stage}设计阶段,所以你的态度应该要{llm_attitude}。
已知用户输入的结果(是/否)含有抽象建筑学概念:{has_high_concept}。
牢记规则:
- 撰写一段英文,严格控制在250-350字。
- 文末必须以完整句子收尾。
- 不得重复任何观点或句子。
- 禁止使用警句、口号或平行句式。
- 不得出现“最终输出”、‘输出结束’、“无后续文本”等元注释。
- 禁止添加自我反思或系统性备注。
- 段落末句结束后立即终止输出。
以下是用户输入的文本内容:{text}你需要用儿童都懂的语言,举生活中的例子给用户解释抽象概念,并且给出可操作的建议。
"""
    return prompt
def generate_feedback_from_prompt(prompt_input: str):
    """Send *prompt_input* to the local llama.cpp model and return its text.

    Returns a fallback error string when the LLM is unavailable or the
    completion payload does not contain any generated text.
    """
    fallback = "Error generating feedback from LLM."
    if llm is None:
        return fallback

    output = llm.create_completion(
        prompt=prompt_input,
        max_tokens=350,
        # Stop sequences curb the model's habit of appending meta commentary.
        stop=["\n\n", "<|im_end|>", "Final", "Output", "No more"],
        temperature=0.7,
    )
    if output and 'choices' in output and len(output['choices']) > 0 and 'text' in output['choices'][0]:
        return output['choices'][0]['text'].strip()
    return fallback
# ----------------------
# Gradio Interface
# ----------------------
# (image URL, accompanying text) pairs rendered as clickable example rows.
examples = [
    [
        "https://balancedarchitecture.com/wp-content/uploads/2021/11/EXISTING-FIRST-FLOOR-PRES-scaled-e1635965923983.jpg",
        "Exploring spatial relationships and material palettes.",
    ],
    [
        "https://cdn.prod.website-files.com/5894a32730554b620f7bf36d/5e848c2d622e7abe1ad48504_5e01ce9f0d272014d0353cd1_Things-You-Need-to-Organize-a-3D-Rendering-Architectural-Project-EASY-RENDER.jpeg",
        "The window size is too small.",
    ],
    [
        "https://architectelevator.com/assets/img/bilbao_sketch.png",
        "The facade expresses the building's relationship with the urban context.",
    ],
]
with gr.Blocks() as demo_step_by_step:
    gr.Markdown("# Architecture Feedback Generator (Step-by-Step)")

    with gr.Row():
        image_input = gr.Image(type="pil", label="Upload Architectural Image")
        text_input = gr.Textbox(label="Enter Text Description or Question")

    classify_button = gr.Button("Perform Classification & Generate Prompt")
    image_output_label = gr.Label(num_top_classes=len(class_names), label="Image Classification Results")
    text_output_textbox = gr.Textbox(label="Text Classification Results")
    # Raw probability dict is stashed in state so the prompt step can reuse it.
    text_classification_probabilities_state = gr.State()
    prompt_output_textbox = gr.Textbox(label="Generated Prompt for LLM", interactive=True)
    generate_feedback_button = gr.Button("Generate Feedback from Prompt")
    llm_output_text = gr.Textbox(label="Generated Feedback")

    def dynamic_generate_prompt(img_res, txt_prob, txt):
        # Recover the predicted label by comparing the two class probabilities.
        if txt_prob.get("High Concept", 0) > txt_prob.get("No High Concept", 0):
            predicted_label = "1"
        else:
            predicted_label = "0"
        return generate_prompt_only(img_res, txt_prob, predicted_label, txt)

    # Classify first, then derive the prompt from the classification outputs.
    classify_button.click(
        fn=perform_classification_and_format,
        inputs=[image_input, text_input],
        outputs=[image_output_label, text_classification_probabilities_state, text_output_textbox],
    ).then(
        fn=dynamic_generate_prompt,
        inputs=[image_output_label, text_classification_probabilities_state, text_input],
        outputs=prompt_output_textbox,
    )

    generate_feedback_button.click(
        fn=generate_feedback_from_prompt,
        inputs=[prompt_output_textbox],
        outputs=llm_output_text,
    )

    def generate_full_chain_output_step_by_step(img, txt):
        # Single-shot pipeline used when an example row is clicked.
        img_res, txt_prob, txt_fmt = perform_classification_and_format(img, txt)
        if txt_prob.get("High Concept", 0) > txt_prob.get("No High Concept", 0):
            predicted_label = "1"
        else:
            predicted_label = "0"
        prompt = generate_prompt_only(img_res, txt_prob, predicted_label, txt)
        llm_res = generate_feedback_from_prompt(prompt)
        return img_res, txt_fmt, prompt, llm_res

    gr.Examples(
        examples=examples,
        inputs=[image_input, text_input],
        outputs=[image_output_label, text_output_textbox, prompt_output_textbox, llm_output_text],
        fn=generate_full_chain_output_step_by_step,
        cache_examples=False,
    )

if __name__ == "__main__":
    demo_step_by_step.launch()