Spaces:
Running
on
Zero
Running
on
Zero
| import subprocess | |
import os

# Install flash-attn when the app starts. FLASH_ATTENTION_SKIP_CUDA_BUILD is
# passed to the installer through the child environment.
# The variable is merged into a copy of the current environment: handing
# subprocess.run a dict with only this one key would replace the whole
# environment of the child process (PATH, HOME, virtualenv variables, ...).
subprocess.run(
    "pip install flash-attn --no-build-isolation",
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    shell=True,
)
| import gradio as gr | |
| import torch | |
| from transformers import ( | |
| AutoModelForCausalLM, | |
| AutoTokenizer, | |
| ) | |
| import docx | |
| import PyPDF2 | |
| import spaces | |
def convert_to_txt(file):
    """Extract the text of an uploaded document.

    Args:
        file: Path of the document as a string. The type is taken from the
            text after the last "." and matched case-insensitively.

    Returns:
        The document text. Plain-text files (txt, md, py) are returned as-is;
        for pdf the pages, and for docx the paragraphs, are joined with a
        blank line between them.

    Raises:
        gr.Error: If the extension is not one of txt, md, py, pdf, docx.
        UnicodeDecodeError: If a plain-text file is not valid UTF-8.
    """
    doc_type = file.split(".")[-1].strip().lower()
    if doc_type in ("txt", "md", "py"):
        # `file` is a path, not a file object, so it has to be opened before
        # it can be read. Bytes are decoded explicitly to keep strict UTF-8.
        with open(file, "rb") as handle:
            data = [handle.read().decode("utf-8")]
    elif doc_type == "pdf":
        pdf_reader = PyPDF2.PdfReader(file)
        # Guard against pages for which no text could be extracted.
        data = [page.extract_text() or "" for page in pdf_reader.pages]
    elif doc_type == "docx":
        data = [paragraph.text for paragraph in docx.Document(file).paragraphs]
    else:
        raise gr.Error(f"ERROR: unsupported document type: {doc_type}")
    return "\n\n".join(data)
# Hub identifier shared by the tokenizer and the model.
model_name = "THUDM/LongCite-glm4-9b"

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

# Keyword arguments for loading the model weights.
# NOTE(review): `device` is not a standard `from_pretrained` argument; it is
# presumably consumed by the model's remote code -- confirm before changing.
_model_load_kwargs = dict(
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    device="cuda",
    attn_implementation="flash_attention_2",
)
model = AutoModelForCausalLM.from_pretrained(model_name, **_model_load_kwargs)
# Inline stylesheet prepended to every HTML fragment built by this app.
# One string piece per CSS rule; the pieces concatenate into a single
# <style> element followed by a newline.
html_styles = (
    "<style>\n"
    ".reference {\ncolor: blue;\ntext-decoration: underline;\n}\n"
    ".highlight {\nbackground-color: yellow;\n}\n"
    ".label {\nfont-family: sans-serif;\nfont-size: 16px;\nfont-weight: bold;\n}\n"
    ".Bold {\nfont-weight: bold;\n}\n"
    ".statement {\nbackground-color: lightgrey;\n}\n"
    "</style>\n"
)
# Characters that must not reach the browser verbatim, mapped to their HTML
# replacements. Newlines become <br> so line breaks survive rendering.
_HTML_ESCAPE_TABLE = str.maketrans(
    {
        "&": "&amp;",
        "'": "&#x27;",
        '"': "&quot;",
        "<": "&lt;",
        ">": "&gt;",
        "\n": "<br>",
    }
)


def process_text(text):
    """Escape *text* for embedding in HTML and turn newlines into ``<br>``.

    Args:
        text: Raw text (model output or a document excerpt).

    Returns:
        The text with ``&``, quotes, ``<`` and ``>`` replaced by HTML entities
        and every newline replaced by ``<br>``.
    """
    # A single translate pass handles each input character exactly once, so
    # the "&" inside a freshly produced entity is never escaped a second time.
    return text.translate(_HTML_ESCAPE_TABLE)
def convert_to_html(statements, clicked=-1):
    """Render answer statements and their citations as HTML fragments.

    Args:
        statements: Sequence of dicts, each with a "statement" string and a
            "citation" list. Every citation dict provides
            "start_sentence_idx", "end_sentence_idx", "start_char_idx",
            "end_char_idx" and "cite".
        clicked: Index of the statement to emphasise. The default -1 matches
            no statement.

    Returns:
        A 4-tuple ``(html, all_cite_html, clicked_cite_html, cite_num2idx)``:
        the answer markup, the markup listing every snippet, the markup for
        the clicked statement's snippets (``None`` when no statement index
        equals ``clicked``), and a dict mapping each reference label such as
        ``"[1,2]"`` to the index of its statement.
    """
    panel = (
        '<div style="overflow-y: auto; padding: 20px; border: 0px dashed black; '
        'border-radius: 6px; background-color: #EFF2F6;">{}</div>'
    )
    answer_parts = [html_styles + '<br><span class="label">Answer:</span><br>\n']
    all_snippets = []
    clicked_cite_html = None
    cite_num2idx = {}
    snippet_no = 0  # running snippet number across the whole answer

    for i, js in enumerate(statements):
        is_clicked = clicked == i
        statement = process_text(js["statement"])
        citations = js["citation"]
        if is_clicked:
            answer_parts.append(f"""<span class="statement">{statement}</span>""")
        else:
            answer_parts.append(f"<span>{statement}</span>")

        # Reset for every statement: a clicked statement without citations
        # must not see an unbound list or snippets left over from an earlier
        # statement.
        cite_html = []
        if citations:
            labels = []
            for c in citations:
                snippet_no += 1
                labels.append(str(snippet_no))
                cite = (
                    "[Sentence: {}-{}\t|\tChar: {}-{}]<br>\n<span {}>{}</span>".format(
                        c["start_sentence_idx"],
                        c["end_sentence_idx"],
                        c["start_char_idx"],
                        c["end_char_idx"],
                        'class="highlight"' if is_clicked else "",
                        process_text(c["cite"].strip()),
                    )
                )
                cite_html.append(
                    f"""<span><span class="Bold">Snippet [{snippet_no}]:</span><br>{cite}</span>"""
                )
            all_snippets.extend(cite_html)
            cite_num = "[{}]".format(",".join(labels))
            cite_num2idx[cite_num] = i
            answer_parts.append(
                """ <span class="reference" style="color: blue" id={}>{}</span>""".format(
                    i, cite_num
                )
            )
        answer_parts.append("\n")

        if is_clicked:
            clicked_cite_html = (
                html_styles
                + '<br><span class="label">Citations of current statement:</span><br>'
                + panel.format("<br><br>\n".join(cite_html))
            )

    if all_snippets:
        # The overview never highlights, so drop the marker that the clicked
        # statement's snippets carry.
        all_body = "<br><br>\n".join(all_snippets).replace(
            '<span class="highlight">', "<span>"
        )
    else:
        all_body = "No citation in the answer"
    all_cite_html = (
        html_styles
        + '<br><span class="label">All citations:</span><br>\n'
        + panel.format(all_body)
    )
    return "".join(answer_parts), all_cite_html, clicked_cite_html, cite_num2idx
def render_context(file):
    """Show the text of a freshly uploaded document.

    Args:
        file: Value delivered by the upload component; it must expose a
            ``name`` attribute holding the path of the uploaded file.

    Returns:
        A visible ``gr.Textbox`` filled with the extracted document text.

    Raises:
        gr.Error: If ``file`` has no ``name`` attribute.
    """
    if not hasattr(file, "name"):
        raise gr.Error("ERROR: no uploaded document")
    extracted = convert_to_txt(file.name)
    return gr.Textbox(extracted, visible=True)
# The file imports `spaces` but never used it. On a ZeroGPU Space the function
# that performs GPU work must be wrapped with `spaces.GPU` so a GPU is
# attached for the duration of the call.
# NOTE(review): the default time budget is kept; pass `duration=` if long
# documents need more -- confirm against real workloads.
@spaces.GPU
def infer(context, query):
    """Answer *query* about *context* with the LongCite model.

    Args:
        context: Full text of the document.
        query: The user's question.

    Returns:
        Whatever ``model.query_longcite`` returns; the caller in this file
        reads its "all_statements" entry.
    """
    return model.query_longcite(
        context=context,
        query=query,
        tokenizer=tokenizer,
        max_input_length=128000,
        max_new_tokens=1024,
    )
def run_llm(context, query):
    """Run the model on the document and build the UI updates for its answer.

    Args:
        context: Document text taken from the "Document content" textbox.
        query: Question typed by the user.

    Returns:
        A dict mapping output components to their new values: the raw
        statements, the rendered answer, the list of all citations, the
        label-to-statement map, the citation selector and the (hidden)
        per-statement citation panel.

    Raises:
        gr.Error: If the document text or the question is empty.
    """
    # Validate in the same order the user fills the form: document first.
    for value, message in (
        (context, "Error: no uploaded document"),
        (query, "Error: no query"),
    ):
        if not value:
            raise gr.Error(message)

    all_statements = infer(context=context, query=query)["all_statements"]
    rendered = convert_to_html(all_statements)
    answer_html, all_cite_html, _clicked_html, cite_num2idx_dict = rendered

    cite_nums = list(cite_num2idx_dict.keys())
    has_citations = len(cite_nums) > 0

    updates = {}
    updates[statements] = gr.JSON(all_statements)
    updates[answer] = gr.HTML(answer_html, visible=True)
    updates[all_citations] = gr.HTML(all_cite_html, visible=True)
    updates[cite_num2idx] = gr.JSON(cite_num2idx_dict)
    # The selector is only shown when there is something to select.
    updates[citation_choices] = gr.Radio(cite_nums, visible=has_citations)
    # A new answer starts with no statement selected.
    updates[clicked_citations] = gr.HTML(visible=False)
    return updates
def chose_citation(statements, cite_num2idx, clicked_cite_num):
    """Re-render the answer with the selected reference emphasised.

    Args:
        statements: The statements of the current answer.
        cite_num2idx: Dict mapping reference labels to statement indices.
        clicked_cite_num: The reference label chosen in the selector.

    Returns:
        A dict with updates for the answer panel and the panel showing the
        citations of the chosen statement.

    Raises:
        KeyError: If ``clicked_cite_num`` is not a key of ``cite_num2idx``.
    """
    statement_index = cite_num2idx[clicked_cite_num]
    rendered = convert_to_html(statements, clicked=statement_index)
    answer_html, clicked_cite_html = rendered[0], rendered[2]

    updates = {}
    updates[answer] = gr.HTML(answer_html, visible=True)
    updates[clicked_citations] = gr.HTML(clicked_cite_html, visible=True)
    return updates
# Title, project links and usage notice rendered at the top of the page.
_HEADER_MARKUP = """
<div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;">
LongCite-glm4-9b Huggingface Space🤗
</div>
<div style="text-align: center;">
<a href="https://huggingface.co/THUDM/LongCite-glm4-9b">🤗 Model Hub</a> |
<a href="https://github.com/THUDM/LongCite">🌐 Github</a> |
<a href="https://arxiv.org/abs/2409.02897">📜 arxiv </a>
</div>
<br>
<div style="text-align: center; font-size: 15px; font-weight: bold; margin-bottom: 20px; line-height: 1.5;">
If you plan to use it long-term, please consider deploying the model or forking this space yourself.
</div>
"""

with gr.Blocks() as demo:
    gr.Markdown(_HEADER_MARKUP)

    # Top row: inputs on the left, extracted document text on the right.
    with gr.Row():
        with gr.Column(scale=4):
            file = gr.File(
                label="Upload a document (supported type: pdf, docx, txt, md, py)"
            )
            query = gr.Textbox(label="Question")
            submit_btn = gr.Button("Submit")
        with gr.Column(scale=4):
            context = gr.Textbox(
                label="Document content",
                autoscroll=False,
                placeholder="No uploaded document.",
                max_lines=10,
                visible=False,
            )
    # Fill the document textbox as soon as a file is uploaded.
    file.upload(fn=render_context, inputs=[file], outputs=[context])

    # Bottom row: answer and selector on the left, citation panels on the right.
    with gr.Row():
        with gr.Column(scale=4):
            # Hidden JSON components carry state between the two callbacks.
            statements = gr.JSON(label="statements", visible=False)
            answer = gr.HTML(label="Answer", visible=True)
            cite_num2idx = gr.JSON(label="cite_num2idx", visible=False)
            citation_choices = gr.Radio(
                label="Chose citations for details", visible=False, interactive=True
            )
        with gr.Column(scale=4):
            clicked_citations = gr.HTML(
                label="Citations of the chosen statement", visible=False
            )
            all_citations = gr.HTML(label="All citations", visible=False)

    # Components updated when a question is submitted.
    run_outputs = [
        statements,
        answer,
        all_citations,
        cite_num2idx,
        citation_choices,
        clicked_citations,
    ]
    submit_btn.click(fn=run_llm, inputs=[context, query], outputs=run_outputs)

    # Picking a reference re-renders the answer and shows that statement's snippets.
    citation_choices.change(
        fn=chose_citation,
        inputs=[statements, cite_num2idx, citation_choices],
        outputs=[answer, clicked_citations],
    )

demo.queue()
demo.launch()