suncongcong commited on
Commit
a143c1a
·
verified ·
1 Parent(s): 4b92419

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -98
app.py CHANGED
@@ -1,98 +1,51 @@
1
- import gradio as gr
2
- import torch
3
- from transformers import AutoModel, AutoImageProcessor
4
- from PIL import Image
5
-
6
- # --- 1. 配置模型和设备 ---
7
-
8
- # 您的模型在 Hub 上的 ID
9
- repo_id = "suncongcong/AST_DeRain"
10
-
11
- # 自动选择使用 GPU 还是 CPU (Hugging Face Spaces 会提供免费的 CPU 或 T4 GPU)
12
- device = "cuda" if torch.cuda.is_available() else "cpu"
13
- print(f"正在使用的设备: {device}")
14
-
15
- # --- 2. 加载模型和图像处理器 ---
16
-
17
- # 加载图像处理器
18
- processor = AutoImageProcessor.from_pretrained(repo_id)
19
-
20
- # 加载模型
21
- # trust_remote_code=True 是必须的,因为它需要执行您上传的 modeling_ast.py 文件
22
- model = AutoModel.from_pretrained(
23
- repo_id,
24
- trust_remote_code=True
25
- ).to(device)
26
-
27
- print("模型加载成功,准备就绪!")
28
-
29
-
30
- # --- 3. 定义核心处理函数 ---
31
-
32
- def derain_image(input_image: Image.Image):
33
- """
34
- 接收一个 PIL Image 对象,返回一个处理后的 PIL Image 对象。
35
- """
36
- # 确保输入是 RGB 格式
37
- image = input_image.convert("RGB")
38
-
39
- # 使用处理器将图片转换为模型所需的 Tensor 格式
40
- inputs = processor(images=image, return_tensors="pt").to(device)
41
-
42
- # 在不计算梯度的模式下进行推理,以节省显存和提高速度
43
- with torch.no_grad():
44
- outputs = model(**inputs)
45
-
46
- # outputs[0] 就是我们得到的恢复后的图像 Tensor
47
- restored_tensor = outputs[0] if isinstance(outputs, tuple) else outputs
48
-
49
- # 将像素值限制在 0-1 之间
50
- restored_tensor = torch.clamp(restored_tensor, 0, 1)
51
-
52
- # 将 Tensor 转换回 PIL Image 以便显示
53
- restored_image = processor.post_process_image_to_image(restored_tensor.cpu())[0]
54
-
55
- return restored_image
56
-
57
-
58
- # --- 4. 创建并启动 Gradio 界面 ---
59
-
60
- # 使用 gr.Blocks() 可以更灵活地设计界面布局
61
- with gr.Blocks(theme=gr.themes.Soft()) as demo:
62
- gr.Markdown(
63
- """
64
- # 🖼️ AST 图像去雨模型在线演示
65
- 上传一张带雨的图片,模型将会自动去除雨水痕迹。
66
- 模型仓库地址: [suncongcong/AST_DeRain](https://huggingface.co/suncongcong/AST_DeRain)
67
- """
68
- )
69
-
70
- with gr.Row():
71
- # 定义输入和输出组件
72
- input_img = gr.Image(type="pil", label="输入带雨图片 (Input Rainy Image)")
73
- output_img = gr.Image(type="pil", label="输出清晰图片 (Output Deraided Image)")
74
-
75
- # 定义按钮
76
- submit_btn = gr.Button("开始去雨 (Start Deraining)", variant="primary")
77
-
78
- # 设置按钮的点击事件
79
- submit_btn.click(
80
- fn=derain_image, # 按钮点击时调用的函数
81
- inputs=input_img, # 函数的输入来自哪个组件
82
- outputs=output_img # 函数的输出显示在哪个组件
83
- )
84
-
85
- # 添加一些示例图片,让用户可以快速体验
86
- gr.Examples(
87
- examples=[
88
- ["https://i.imgur.com/a4y39hV.jpg"],
89
- ["https://i.imgur.com/KxYE1v4.jpg"]
90
- ],
91
- inputs=input_img,
92
- outputs=output_img,
93
- fn=derain_image,
94
- cache_examples=True # 缓存示例结果,加快加载速度
95
- )
96
-
97
- # 启动应用
98
- demo.launch()
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import CLIPImageProcessor
4
+ from modeling_ast import ASTForRestoration
5
+ from PIL import Image
6
+ import requests
7
+ from io import BytesIO
8
+ from torchvision.transforms.functional import to_pil_image
9
+
10
+ # --- 1. 配置模型和设备 ---
11
+ repo_id = "suncongcong/AST_DeRain"
12
+ device = "cuda" if torch.cuda.is_available() else "cpu"
13
+ print(f"正在使用的设备: {device}")
14
+
15
+ # --- 2. 加载模型和图像处理器 ---
16
+ print(f"正在从 '{repo_id}' 加载模型和处理器...")
17
+ processor = CLIPImageProcessor.from_pretrained(repo_id)
18
+ processor.size = {"height": 256, "width": 256}
19
+ processor.crop_size = {"height": 256, "width": 256}
20
+ print(f"图像处理器尺寸已强制设置为: {processor.size}")
21
+ model = ASTForRestoration.from_pretrained(
22
+ repo_id,
23
+ trust_remote_code=True
24
+ ).to(device)
25
+ print("✅ 模型加载成功,准备就绪!")
26
+
27
+ # --- 3. 定义核心处理函数 ---
28
+ def derain_image(input_image: Image.Image):
29
+ if input_image is None: return None
30
+ image = input_image.convert("RGB")
31
+ inputs = processor(images=image, return_tensors="pt").to(device)
32
+ with torch.no_grad():
33
+ outputs = model(**inputs)
34
+ restored_tensor = outputs[0] if isinstance(outputs, tuple) else outputs
35
+ restored_tensor = torch.clamp(restored_tensor, 0, 1)
36
+ restored_image = to_pil_image(restored_tensor.cpu().squeeze(0))
37
+ return restored_image
38
+
39
+ # --- 4. 创建并启动 Gradio 界面 ---
40
+ print("正在创建 Gradio Interface...")
41
+ demo = gr.Interface(
42
+ fn=derain_image,
43
+ inputs=gr.Image(type="pil", label="输入带雨图片 (Input Rainy Image)"),
44
+ outputs=gr.Image(type="pil", label="输出清晰图片 (Output Deraided Image)"),
45
+ title="AST 图像去雨模型在线演示",
46
+ description="上传一张带雨的图片,模型将会自动去除雨水痕迹。模型仓库地址: [suncongcong/AST_DeRain](https://huggingface.co/suncongcong/AST_DeRain)"
47
+ )
48
+
49
+ # --- 最终修正:添加 server_name 参数以适应容器环境 ---
50
+ print("正在启动 Demo...")
51
+ demo.launch(server_name="0.0.0.0")