orgoflu committed on
Commit
208dd23
·
verified ·
1 Parent(s): 276cd92

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -129
app.py CHANGED
@@ -1,144 +1,50 @@
1
- import nltk
2
- nltk.download("punkt")
3
-
4
  import gradio as gr
5
- import trafilatura
6
- import requests
7
- from markdownify import markdownify as md
8
- from sumy.parsers.plaintext import PlaintextParser
9
- from sumy.nlp.tokenizers import Tokenizer
10
- from sumy.summarizers.text_rank import TextRankSummarizer
11
- import re
12
  import torch
13
- from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
14
 
15
# ===== The three models available for rewriting =====
MODEL_OPTIONS = {
    "Qwen2.5-1.5B-Instruct": "Qwen/Qwen2.5-1.5B-Instruct",
    "Gemma-3-4B-it": "google/gemma-3-4b-it",
    "HyperCLOVA-X-Seed-3B": "naver-clova/HyperCLOVA-X-Seed-3B",
}
21
 
22
# ===== Model loading =====
def load_model(model_name):
    """Load *model_name* as a CPU, fp32 text-generation pipeline."""
    tok = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    lm = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float32,
        trust_remote_code=True,
    )
    lm = lm.to("cpu")
    # device=-1 keeps the pipeline on CPU.
    return pipeline("text-generation", model=lm, tokenizer=tok, device=-1)
31
-
32
# ===== Text preprocessing =====
def clean_text(text: str) -> str:
    """Squeeze every whitespace run to a single space and trim the ends."""
    squeezed = re.sub(r"\s+", " ", text)
    return squeezed.strip()
35
-
36
def remove_duplicates(sentences):
    """Strip each sentence, drop empties, keep first occurrences in order."""
    # A dict comprehension dedupes while preserving first-seen order.
    ordered = {sentence.strip(): None for sentence in sentences if sentence.strip()}
    return list(ordered)
44
-
45
# ===== Automatic summarization =====
def summarize_text(text):
    """Extract a short TextRank summary from *text*.

    The sentence budget scales with the cleaned text length (1 sentence
    for <300 chars up to 4 for >=1500).  Parsing falls back from the
    Korean tokenizer to the English one, and finally to a naive regex
    sentence split when sumy cannot parse the text at all.

    Returns:
        A list of sentence strings, re-ordered to match their position
        in the source text.
    """
    text = clean_text(text)
    length = len(text)
    if length < 300:
        sentence_count = 1
    elif length < 800:
        sentence_count = 2
    elif length < 1500:
        sentence_count = 3
    else:
        sentence_count = 4

    # Fix: the original used bare `except:`, which also swallows
    # SystemExit/KeyboardInterrupt; narrowed to Exception.
    parser = None
    for language in ("korean", "english"):
        try:
            candidate = PlaintextParser.from_string(text, Tokenizer(language))
            if len(candidate.document.sentences) > 0:
                parser = candidate
                break
        except Exception:
            continue

    if parser is None:
        # Last resort: split on sentence-ending punctuation.
        sentences = re.split(r'(?<=[.!?])\s+', text)
        return sentences[:sentence_count]

    summarizer = TextRankSummarizer()
    summary_sentences = summarizer(parser.document, sentence_count)
    summary_list = [str(sentence) for sentence in summary_sentences]
    summary_list = remove_duplicates(summary_list)
    # TextRank ranks by importance; restore original document order.
    summary_list.sort(key=lambda s: text.find(s))
    return summary_list
77
-
78
# ===== LLM rewriting =====
def rewrite_with_llm(sentences, model_choice):
    """Rewrite the summary *sentences* more naturally with the chosen LLM.

    Args:
        sentences: list of summary sentence strings.
        model_choice: display-name key into MODEL_OPTIONS.

    Returns:
        The model's rewritten text with the echoed prompt stripped off.
    """
    model_name = MODEL_OPTIONS[model_choice]
    llm_pipeline = load_model(model_name)

    joined_text = "\n".join(sentences)
    # Prompt restored to proper Korean from a mojibake-damaged source.
    # Meaning: "Keep the meaning, add nothing not in the original, just
    # make the sentences more natural; write no extra explanations."
    prompt = f"""다음 문장을 의미는 유지하되, 원문에 없는 내용은 절대 추가하지 말고,
문장만 더 자연스럽게 바꿔주세요. 다른 설명이나 부연 문장은 쓰지 마세요.

문장:
{joined_text}
"""
    # Fix: `temperature=0` is unused (and rejected by newer transformers)
    # when do_sample=False, so it is no longer passed.
    result = llm_pipeline(prompt, max_new_tokens=150, do_sample=False)
    # text-generation pipelines echo the prompt; remove it from the output.
    return result[0]["generated_text"].replace(prompt, "").strip()
92
-
93
# ===== Full pipeline =====
def extract_summarize_paraphrase(url, model_choice):
    """Fetch *url*, extract its body text, summarize it, then rewrite it.

    Args:
        url: page to fetch.
        model_choice: display-name key into MODEL_OPTIONS.

    Returns:
        (markdown body, summary text, paraphrased text); on any failure,
        a triple of Korean placeholder/error strings instead.
        (Strings restored to proper Korean from a mojibake-damaged source.)
    """
    headers = {"User-Agent": "Mozilla/5.0"}
    try:
        r = requests.get(url, headers=headers, timeout=10)
        r.raise_for_status()

        html_content = trafilatura.extract(
            r.text,
            output_format="html",
            include_tables=False,
            favor_recall=True,
        )

        # Fall back to converting the raw page if extraction found nothing.
        source_html = html_content if html_content else r.text
        markdown_text = md(source_html, heading_style="ATX")

        # "요약 없음" = "no summary" placeholder.
        summary_sentences = summarize_text(markdown_text) or ["요약 없음"]

        paraphrased_text = rewrite_with_llm(summary_sentences, model_choice)

        return (
            markdown_text or "본문 없음",
            "\n".join(summary_sentences),
            paraphrased_text,
        )
    except Exception as e:
        # UI boundary: always return three strings, never raise.
        return f"에러 발생: {e}", "요약 없음", "재작성 없음"
 
 
 
126
 
127
# ===== Gradio UI =====
# UI strings restored to proper Korean from a mojibake-damaged source.
iface = gr.Interface(
    fn=extract_summarize_paraphrase,
    inputs=[
        gr.Textbox(label="URL 입력", placeholder="https://example.com"),
        gr.Dropdown(
            choices=list(MODEL_OPTIONS.keys()),
            value="Qwen2.5-1.5B-Instruct",
            label="재작성 모델 선택",
        ),
    ],
    outputs=[
        gr.Markdown(label="추출된 본문"),
        gr.Textbox(label="자동 요약", lines=5),
        gr.Textbox(label="자동 재작성 (LLM)", lines=5),
    ],
    title="한국어 본문 추출 + 자동 요약 + LLM 재작성",
    description="Qwen 1.5B, Gemma 3 E4B, HyperCLOVA-X-Seed-3B 중 선택하여 재작성",
)

# Launch only when run as a script, keeping the module importable.
if __name__ == "__main__":
    iface.launch()
 
 
 
 
1
  import gradio as gr
 
 
 
 
 
 
 
2
  import torch
3
+ from transformers import pipeline, AutoTokenizer, AutoModelForVision2Seq
4
 
5
# ===== Model catalogue =====
MODEL_OPTIONS = {
    "Qwen2.5-1.5B-Instruct": "Qwen/Qwen2.5-1.5B-Instruct",
    "Gemma-3-4B-it": "google/gemma-3-4b-it",
    "CLOVA-Donut-CORDv2": "naver-clova-ix/donut-base-finetuned-cord-v2",
}
11
 
12
# ===== Model loading =====
# Cache built pipelines so repeated UI calls do not re-download or
# re-instantiate the same model.
_PIPELINE_CACHE = {}


def load_model(model_name):
    """Build (and memoize) an inference pipeline for *model_name*.

    Donut (CORD-v2) is a vision encoder-decoder and gets an
    "image-to-text" pipeline; every other entry is loaded as a CPU,
    fp32 "text-generation" pipeline.
    """
    if model_name in _PIPELINE_CACHE:
        return _PIPELINE_CACHE[model_name]

    if model_name == "naver-clova-ix/donut-base-finetuned-cord-v2":
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForVision2Seq.from_pretrained(model_name)
        pipe = pipeline("image-to-text", model=model, tokenizer=tokenizer)
    else:
        # Fix: AutoTokenizer is already imported at module level; only
        # the causal-LM class needs a local import here.
        from transformers import AutoModelForCausalLM
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float32,
            trust_remote_code=True,
        ).to("cpu")
        # device=-1 keeps the pipeline on CPU.
        pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device=-1)

    _PIPELINE_CACHE[model_name] = pipe
    return pipe
28
+
29
# ===== CLOVA image processing =====
def process_image_with_clova(image):
    """Run the Donut CORD-v2 image-to-text pipeline on *image*."""
    donut = load_model("naver-clova-ix/donut-base-finetuned-cord-v2")
    outputs = donut(image)
    first = outputs[0]
    return first["generated_text"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
# ===== Gradio UI =====
# UI strings restored to proper Korean from a mojibake-damaged source.
with gr.Blocks() as iface:
    gr.Markdown("## Qwen / Gemma / CLOVA Donut 테스트")

    with gr.Tab("텍스트 URL 요약/재작성"):
        url_input = gr.Textbox(label="URL 입력")
        # Fix: the dropdown had no label; restored the one the old UI used.
        model_choice = gr.Dropdown(
            choices=list(MODEL_OPTIONS.keys()),
            value="Qwen2.5-1.5B-Instruct",
            label="재작성 모델 선택",
        )
        output_text = gr.Textbox(label="출력")
        # TODO: this tab has no callback wired up yet — connect the
        # URL-processing function here once it is (re)implemented.

    with gr.Tab("CLOVA 이미지 → 텍스트"):
        image_input = gr.Image(type="pil", label="이미지 업로드")
        clova_output = gr.Textbox(label="인식 결과")
        image_input.change(
            process_image_with_clova,
            inputs=image_input,
            outputs=clova_output,
        )

# Fix: guard the launch so importing this module does not start a server.
if __name__ == "__main__":
    iface.launch()