Hellowish committed on
Commit
724cf70
·
verified ·
1 Parent(s): ab2113f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -64
app.py CHANGED
@@ -1,68 +1,34 @@
1
  import gradio as gr
2
- import torch
3
- from PIL import Image
4
- import pickle
5
- from torchvision import transforms
6
- from your_model_file import ImageCaptionAttentionModel, prepare_model, pred_caption # 你的模型和輔助函式
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
- # ------------------------
9
- # 1️⃣ 設定裝置
10
- # ------------------------
11
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
12
-
13
- # ------------------------
14
- # 2️⃣ 載入詞表
15
- # ------------------------
16
- with open('voc_dict.pkl', 'rb') as f:
17
- voc_dict = pickle.load(f)
18
-
19
- encoder = lambda x: voc_dict[x]
20
- decoder = lambda x: [key for key, value in voc_dict.items() if value == x][0]
21
-
22
- # ------------------------
23
- # 3️⃣ 圖像前處理
24
- # ------------------------
25
- image_transform = transforms.Compose([
26
- transforms.Resize(256),
27
- transforms.CenterCrop(224),
28
- transforms.ToTensor(),
29
- transforms.Normalize(mean=[0.485, 0.456, 0.406],
30
- std=[0.229, 0.224, 0.225])
31
- ])
32
-
33
- # ------------------------
34
- # 4️⃣ 載入模型
35
- # ------------------------
36
- voc_size = len(voc_dict)
37
- embed_dim = 256 # 根據你的訓練設定
38
- model = ImageCaptionAttentionModel(voc_size, embed_dim)
39
- model = prepare_model(model)
40
- model.load_state_dict(torch.load('image_caption_attention_model.pth', map_location=device))
41
- model.eval()
42
-
43
- # ------------------------
44
- # 5️⃣ Gradio 推理函式
45
- # ------------------------
46
- def generate_caption(image: Image.Image):
47
- # 前處理
48
- img_tensor = image_transform(image)
49
- # 生成 caption
50
- caption = pred_caption(model, img_tensor)
51
- return caption
52
-
53
- # ------------------------
54
- # 6️⃣ Gradio UI
55
- # ------------------------
56
- with gr.Blocks() as demo:
57
- gr.Markdown("# Image Captioning with Attention")
58
- image_input = gr.Image(type="pil", label="Upload Image")
59
- result = gr.Textbox(label="Generated Caption")
60
- run_button = gr.Button("Generate Caption")
61
-
62
- run_button.click(fn=generate_caption, inputs=image_input, outputs=result)
63
-
64
- # ------------------------
65
- # 7️⃣ 啟動介面
66
- # ------------------------
67
  if __name__ == "__main__":
68
  demo.launch()
 
1
import gradio as gr
from transformers import pipeline

# 1. Load the pretrained extractive-QA model fine-tuned on SQuAD v2.0.
#    The pipeline wraps tokenization, inference, and answer-span decoding.
qa_model = pipeline("question-answering", model="deepset/roberta-base-squad2")
+
7
# 2. Inference logic for the Gradio interface.
def predict(context, question):
    """Answer *question* from *context* using the extractive QA pipeline.

    Args:
        context: Document text in which to look for the answer.
        question: The question to answer.

    Returns:
        The extracted answer string, or a (Chinese) fallback message when
        input is missing or no confident answer is found in the document.
    """
    # Guard against missing OR whitespace-only inputs — the plain
    # `not context` check let "   " through to the model.
    if not context or not context.strip() or not question or not question.strip():
        return "請輸入文件內容與問題。"

    # Run the QA pipeline; result is a dict with 'score', 'start', 'end', 'answer'.
    result = qa_model(question=question, context=context)

    # SQuAD v2.0-style "cannot answer" handling: reject low-confidence
    # predictions, and also an empty/blank answer span, which would
    # otherwise be returned verbatim.
    if result['score'] < 0.05 or not result['answer'].strip():
        return "抱歉,在文件中找不到相關答案。"

    return result['answer']
20
+
21
# 3. Assemble the Gradio web UI from named components.
context_box = gr.Textbox(
    lines=10,
    label="Context (文件內容)",
    placeholder="請貼上文件內容...",
)
question_box = gr.Textbox(
    lines=2,
    label="Question (提問)",
    placeholder="請問這份文件關於什麼?",
)
answer_box = gr.Textbox(label="Model Answer (模型回答)")

demo = gr.Interface(
    fn=predict,
    inputs=[context_box, question_box],
    outputs=answer_box,
    title="Case Study: Document QA System",
    description="這是一個基於 SQuAD v2.0 訓練的模型,能根據提供的文本回答問題。",
)
32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
# 4. Start the Gradio server only when run as a script (not on import).
if __name__ == "__main__":
    demo.launch()