mxiean commited on
Commit
e7c586c
·
verified ·
1 Parent(s): 6e2c108

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +120 -30
app.py CHANGED
@@ -1,36 +1,126 @@
1
  import gradio as gr
2
- # from transformers import pipeline
3
-
4
- # 初始化文本生成管道
5
- try:
6
- generator = pipeline("text-generation", model="distilgpt2")
7
- except Exception as e:
8
- print(f"模型加载失败: {e}")
9
- generator = None
10
-
11
- def generate_advice(style):
12
- if not generator:
13
- return "系统初始化失败,请检查transformers库是否安装", ""
14
-
15
- prompt = f"如何将Airbnb房间装修成{style}风格?请给出3条具体建议"
16
- try:
17
- result = generator(prompt, max_length=200)
18
- return result[0]["generated_text"], ""
19
- except Exception as e:
20
- return f"生成建议时出错: {str(e)}", ""
21
-
22
- # 创建简单界面
23
- with gr.Blocks() as demo:
24
- gr.Markdown("## 🏡 Airbnb装修助手")
25
- style_input = gr.Textbox(label="输入想要的风格")
26
- submit_btn = gr.Button("生成建议")
27
- advice_output = gr.Textbox(label="装修建议")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
  submit_btn.click(
30
- fn=generate_advice,
31
- inputs=style_input,
32
- outputs=advice_output
33
  )
34
 
 
35
  if __name__ == "__main__":
36
- demo.launch()
 
1
  import gradio as gr
2
+ from transformers import pipeline
3
+ from PIL import Image
4
+ import requests
5
+ from io import BytesIO
6
+ import numpy as np
7
+
8
+ # 初始化两个pipeline
9
+ item_classifier = pipeline(
10
+ "image-classification",
11
+ model="google/vit-base-patch16-224"
12
+ )
13
+
14
+ scene_analyzer = pipeline(
15
+ "image-to-text",
16
+ model="nlpconnect/vit-gpt2-image-captioning"
17
+ )
18
+
19
+ # 风格映射字典 (可根据实际需求扩展)
20
+ STYLE_MAPPING = {
21
+ "modern": "现代风格",
22
+ "contemporary": "当代风格",
23
+ "minimalist": "极简风格",
24
+ "industrial": "工业风格",
25
+ "scandinavian": "北欧风格",
26
+ "bohemian": "波西米亚风格",
27
+ "rustic": "乡村风格",
28
+ "traditional": "传统风格",
29
+ "coastal": "海岸风格",
30
+ "mid-century": "中世纪现代风格"
31
+ }
32
+
33
+ def determine_overall_style(items, description):
34
+ """
35
+ 根据物品分类和场景描述确定整体风格
36
+ """
37
+ # 从物品分类中提取关键词
38
+ item_keywords = " ".join([item['label'] for item in items])
39
+
40
+ # 合并所有文本信息
41
+ combined_text = f"{item_keywords} {description}".lower()
42
+
43
+ # 简单的关键词匹配确定风格
44
+ detected_styles = []
45
+ for style_en, style_cn in STYLE_MAPPING.items():
46
+ if style_en in combined_text:
47
+ detected_styles.append(style_cn)
48
+
49
+ # 如果没有匹配到任何风格,返回一个默认值
50
+ if not detected_styles:
51
+ return "混合风格"
52
+
53
+ # 返回匹配到的所有风格
54
+ return "、".join(detected_styles)
55
+
56
+ def analyze_room_style(image):
57
+ """
58
+ 分析房间风格的主函数
59
+ """
60
+ # 如果是URL,下载图片
61
+ if isinstance(image, str) and image.startswith(('http://', 'https://')):
62
+ response = requests.get(image)
63
+ image = Image.open(BytesIO(response.content))
64
+
65
+ # 物品级别分析 (取前5个最相关的物品)
66
+ item_results = item_classifier(image, top_k=5)
67
+
68
+ # 场景级别分析
69
+ scene_description = scene_analyzer(image)[0]['generated_text']
70
+
71
+ # 综合判断风格
72
+ style = determine_overall_style(item_results, scene_description)
73
+
74
+ # 格式化结果
75
+ items_formatted = "\n".join([f"- {item['label']} ({item['score']:.2f})" for item in item_results])
76
+
77
+ return {
78
+ "detected_items": items_formatted,
79
+ "scene_description": scene_description,
80
+ "predicted_style": style
81
+ }
82
+
83
+ def predict_style(image):
84
+ """
85
+ Gradio接口使用的预测函数
86
+ """
87
+ result = analyze_room_style(image)
88
+
89
+ output = f"""🏠 预测风格: {result['predicted_style']}
90
+
91
+ 📝 场景描述: {result['scene_description']}
92
+
93
+ 🛋️ 检测到的物品:
94
+ {result['detected_items']}
95
+ """
96
+ return output
97
+
98
+ # 创建Gradio界面
99
+ with gr.Blocks(title="Airbnb房屋风格识别") as demo:
100
+ gr.Markdown("# 🏡 Airbnb房屋风格识别")
101
+ gr.Markdown("上传您的房间照片,AI将分析您的房屋装饰风格")
102
+
103
+ with gr.Row():
104
+ with gr.Column():
105
+ image_input = gr.Image(type="filepath", label="上传房间照片")
106
+ submit_btn = gr.Button("分析风格")
107
+ with gr.Column():
108
+ output = gr.Textbox(label="分析结果", lines=10)
109
+
110
+ examples = gr.Examples(
111
+ examples=[
112
+ ["https://example.com/room1.jpg"], # 替换为实际示例图片URL
113
+ ["https://example.com/room2.jpg"] # 替换为实际示例图片URL
114
+ ],
115
+ inputs=image_input
116
+ )
117
 
118
  submit_btn.click(
119
+ fn=predict_style,
120
+ inputs=image_input,
121
+ outputs=output
122
  )
123
 
124
+ # 启动应用
125
  if __name__ == "__main__":
126
+ demo.launch(server_name="0.0.0.0", server_port=7860)