# NOTE: The three lines that were here ("Spaces: / Sleeping / Sleeping") were
# status-banner residue from scraping the Hugging Face Space web page — not
# application code. Preserved as a comment so the module parses.
import os

import torch
import gradio as gr
import pandas as pd
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline

# Model repo and length limits are overridable via environment variables so the
# Space can be re-pointed or re-tuned without a code change.
REPO_ID = os.environ.get("HF_MODEL_ID", "HuyTran1301/constrative_cont_so_phase2_SI")
MAX_LENGTH = int(os.environ.get("MAX_LENGTH", "512"))          # tokenizer truncation limit
GEN_MAX_LENGTH = int(os.environ.get("GEN_MAX_LENGTH", "64"))   # max generated summary length

# Keep CPU thread usage predictable on shared Space hardware.
torch.set_num_threads(int(os.environ.get("TORCH_NUM_THREADS", "1")))

# Loaded once at import time and reused by every request.
tokenizer = AutoTokenizer.from_pretrained(REPO_ID)
model = AutoModelForSeq2SeqLM.from_pretrained(REPO_ID)
def summarize_one(lang: str, desc: str, code: str) -> pd.DataFrame:
    """Generate up to 5 beam-search summaries for a code snippet.

    Args:
        lang: Programming language name (may be None or empty).
        desc: Natural-language description of the code (may be None or empty).
        code: The source code to summarize (may be None or empty).

    Returns:
        A two-column DataFrame ("#", "Summary") with one row per generated
        candidate, or a single row of empty strings when all inputs are blank.
    """
    # This function is exposed over the HTTP API (api_name="predict"), where a
    # caller can pass None for an omitted field; coerce to "" so .strip()
    # cannot raise AttributeError.
    lang = (lang or "").strip()
    desc = (desc or "").strip()
    code = (code or "").strip()

    if not (lang or desc or code):
        # Nothing to summarize: return an empty placeholder row so the UI
        # table keeps its shape.
        return pd.DataFrame([["", ""]], columns=["#", "Summary"])

    merged_text = f"{lang}: {desc} <code> {code}"
    input_ids = tokenizer(
        merged_text,
        return_tensors="pt",
        truncation=True,
        max_length=MAX_LENGTH,
    ).input_ids

    with torch.no_grad():
        outputs = model.generate(
            input_ids,
            max_length=GEN_MAX_LENGTH,
            num_beams=5,
            num_return_sequences=5,  # one row per beam
            min_length=4,
            length_penalty=0.0,      # no length bias between beams
        )

    summaries = [tokenizer.decode(o, skip_special_tokens=True).strip() for o in outputs]
    return pd.DataFrame(list(enumerate(summaries, start=1)), columns=["#", "Summary"])
# --- Gradio UI -------------------------------------------------------------
with gr.Blocks(title="Code Summarization") as demo:
    gr.Markdown("# Code Summarization")

    # NOTE(review): original indentation was lost in the scraped source; the
    # three inputs directly follow `with gr.Row():` in source order, so they
    # are assumed to share one row — confirm against the deployed layout.
    with gr.Row():
        lang = gr.Textbox(label="Language", placeholder="e.g., Python, Java, etc.")
        desc = gr.Textbox(label="Description", placeholder="What does the code do?")
        code = gr.Textbox(lines=8, label="Code", placeholder="Paste your code here...")

    btn = gr.Button("Generate Summaries")
    out_table = gr.Dataframe(
        headers=["#", "Summary"],
        label="Generated Summaries",
        interactive=False,
    )

    # api_name exposes this handler as the /predict endpoint of the Space API.
    btn.click(
        summarize_one,
        inputs=[lang, desc, code],
        outputs=[out_table],
        api_name="predict",
    )

    gr.Markdown(f"**Model:** `{REPO_ID}` • **Input max length:** {MAX_LENGTH} • **Output max length:** {GEN_MAX_LENGTH} • **num_beams:** 5")
if __name__ == "__main__":
    # Bind on all interfaces; the hosting platform injects PORT (7860 locally).
    serve_port = int(os.environ.get("PORT", 7860))
    demo.launch(server_name="0.0.0.0", server_port=serve_port)