mxiean commited on
Commit
e900402
·
verified ·
1 Parent(s): 03e7c37

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -19
app.py CHANGED
@@ -2,28 +2,42 @@ import gradio as gr
2
  from transformers import pipeline
3
  from datasets import load_dataset
4
 
5
- # 加载数据集
6
- dataset = load_dataset("AntZet/home_decoration_objects_images", streaming=True)
7
- dataset = dataset['train'].take(100) # 只取前100个样本加快加载
 
 
 
 
8
 
9
- # 初始化模型
10
- style_advisor = pipeline("text-generation", model="gpt2") # 使用较小的gpt2模型
 
 
 
 
11
 
12
  def get_advice(style):
13
- # 从数据集中找出匹配风格的例子
14
- examples = [ex for ex in dataset if ex['style'].lower() == style.lower()]
15
 
16
- if not examples:
17
- return "未找到该风格,请尝试:工业风、北欧风等", []
18
-
19
- # 生成建议
20
- prompt = f"如何将Airbnb房间装修成{style}风格?请给出3条具体建议"
21
- advice = style_advisor(prompt, max_length=150)[0]['generated_text']
22
-
23
- # 获取示例图片
24
- example_images = [ex['image'] for ex in examples[:3]]
25
-
26
- return advice, example_images
 
 
 
 
 
 
27
 
28
  # 创建界面
29
  with gr.Blocks() as demo:
@@ -41,4 +55,5 @@ with gr.Blocks() as demo:
41
  outputs=[advice_output, gallery]
42
  )
43
 
44
- demo.launch()
 
 
2
  from transformers import pipeline
3
  from datasets import load_dataset
4
 
5
+ # 加载数据集(使用更小的样本集)
6
+ try:
7
+ dataset = load_dataset("AntZet/home_decoration_objects_images", streaming=True)
8
+ dataset = dataset['train'].take(50) # 只取50个样本加快加载
9
+ except Exception as e:
10
+ print(f"加载数据集失败: {e}")
11
+ dataset = []
12
 
13
+ # 初始化模型(使用更小的模型)
14
+ try:
15
+ style_advisor = pipeline("text-generation", model="distilgpt2") # 改用更小的distilgpt2
16
+ except Exception as e:
17
+ print(f"加载模型失败: {e}")
18
+ style_advisor = None
19
 
20
  def get_advice(style):
21
+ if not dataset or not style_advisor:
22
+ return "系统初始化失败,请检查后台日志", []
23
 
24
+ try:
25
+ # 从数据集中找出匹配风格的例子
26
+ examples = [ex for ex in dataset if ex['style'].lower() == style.lower()]
27
+
28
+ if not examples:
29
+ return "未找到该风格,请尝试:工业风、北欧风等", []
30
+
31
+ # 生成建议
32
+ prompt = f"如何将Airbnb房间装修成{style}风格?请给出3条具体建议"
33
+ advice = style_advisor(prompt, max_length=150)[0]['generated_text']
34
+
35
+ # 获取示例图片
36
+ example_images = [ex['image'] for ex in examples[:3]]
37
+
38
+ return advice, example_images
39
+ except Exception as e:
40
+ return f"生成建议时出错: {str(e)}", []
41
 
42
  # 创建界面
43
  with gr.Blocks() as demo:
 
55
  outputs=[advice_output, gallery]
56
  )
57
 
58
+ if __name__ == "__main__":
59
+ demo.launch()