mxiean commited on
Commit
92ae4a5
·
verified ·
1 Parent(s): cbfb2ea

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +110 -13
app.py CHANGED
@@ -1,14 +1,111 @@
1
- # 测试文本模型
2
- # import part
3
- import streamlit as st
4
- from transformers import pipeline
5
  import torch
6
- # pipe = pipeline("text-classification", model="distilbert-base-uncased")
7
- # print(pipe("This is a test sentence."))
8
-
9
- # # 测试视觉模型
10
- # from PIL import Image
11
- # import requests
12
- # pipe = pipeline("image-classification", model="google/vit-base-patch16-224")
13
- # url = "http://images.cocodataset.org/val2017/000000039769.jpg"
14
- # print(pipe(Image.open(requests.get(url, stream=True).raw)))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
 
 
 
2
  import torch
3
+ from transformers import pipeline
4
+ from datasets import load_dataset
5
+ from PIL import Image
6
+ import numpy as np
7
+ from collections import Counter
8
+ import functools
9
+
10
+ # 使用标准库的缓存装饰器替代Gradio缓存
11
+ @functools.lru_cache(maxsize=None)
12
+ def load_models():
13
+ return {
14
+ "detector": pipeline(
15
+ "object-detection",
16
+ model="facebook/detr-resnet-50",
17
+ device=0 if torch.cuda.is_available() else -1
18
+ ),
19
+ "generator": pipeline(
20
+ "text2text-generation",
21
+ model="google/flan-t5-base", # 改用基础版降低资源需求
22
+ device=0 if torch.cuda.is_available() else -1
23
+ )
24
+ }
25
+
26
+ # 数据集加载函数(移除Gradio缓存)
27
+ def load_dataset_data():
28
+ ds = load_dataset("AntZet/home_decoration_objects_images")
29
+ return ds['train'].to_pandas()
30
+
31
+ # 颜色分析函数保持不变
32
+ def get_dominant_colors(img, n_colors=3):
33
+ arr = np.array(img.resize((100,100)))
34
+ pixels = arr.reshape(-1,3)
35
+ from sklearn.cluster import KMeans
36
+ kmeans = KMeans(n_clusters=n_colors)
37
+ kmeans.fit(pixels)
38
+ return [f"#{int(c[0]):02x}{int(c[1]):02x}{int(c[2]):02x}" for c in kmeans.cluster_centers_]
39
+
40
+ # 核心处理函数
41
+ def generate_recommendation(target_style):
42
+ try:
43
+ models = load_models()
44
+ df = load_dataset_data()
45
+
46
+ style_df = df[df['style'] == target_style.lower()]
47
+ if len(style_df) < 3:
48
+ return f"⚠️ Not enough samples for {target_style} style"
49
+
50
+ sample_images = style_df.sample(5)['image']
51
+
52
+ all_objects = []
53
+ color_palette = []
54
+
55
+ for img in sample_images:
56
+ detected = models["detector"](img)
57
+ all_objects += [obj['label'] for obj in detected if obj['score'] > 0.9]
58
+ color_palette += get_dominant_colors(img)
59
+
60
+ top_objects = Counter(all_objects).most_common(3)
61
+ top_colors = Counter(color_palette).most_common(3)
62
+
63
+ prompt = f"""Create interior design recommendations for {target_style} style:
64
+ Key objects: {[o[0] for o in top_objects]}
65
+ Color palette: {[c[0] for c in top_colors]}
66
+ Include: 3 essentials, 2 budget tips, common mistakes"""
67
+
68
+ advice = models["generator"](prompt, max_length=300)[0]['generated_text']
69
+
70
+ output = f"## 🎨 {target_style.title()} Style Guide\n\n"
71
+ output += "### 🪑 Key Objects\n" + "\n".join(
72
+ [f"- {o[0]} ({o[1]}x)" for o in top_objects]) + "\n\n"
73
+ output += "### 🎨 Colors\n" + "\n".join(
74
+ [f"<span style='color:{c[0]};'>■</span> {c[0]}" for c in top_colors]) + "\n\n"
75
+ output += "### 💡 Advice\n" + advice.replace(". ", ".\n")
76
+
77
+ return output
78
+
79
+ except Exception as e:
80
+ return f"❌ Error: {str(e)}"
81
+
82
+ # Gradio界面保持不变
83
+ with gr.Blocks(title="Design Assistant") as demo:
84
+ gr.Markdown("# 🏡 AI Design Advisor")
85
+
86
+ with gr.Row():
87
+ style_input = gr.Dropdown(
88
+ label="Select Style",
89
+ choices=["Industrial", "Scandinavian", "Bohemian", "Modern"],
90
+ value="Industrial"
91
+ )
92
+
93
+ submit_btn = gr.Button("Generate Plan", variant="primary")
94
+
95
+ with gr.Row():
96
+ output = gr.Markdown()
97
+ gallery = gr.Gallery(
98
+ label="Examples",
99
+ object_fit="contain",
100
+ height="300px"
101
+ )
102
+
103
+ def update_gallery(style):
104
+ df = load_dataset_data()
105
+ return df[df['style'] == style.lower()].sample(3)['image'].tolist()
106
+
107
+ style_input.change(update_gallery, inputs=style_input, outputs=gallery)
108
+ submit_btn.click(generate_recommendation, inputs=style_input, outputs=output)
109
+
110
+ if __name__ == "__main__":
111
+ demo.launch()