# ID2223_Lab2 / app.py
# (Hugging Face Space header residue: "Maoxt — Update app.py — 13bc5be verified",
# converted to comments so the file is valid Python.)
import gradio as gr
import time
import os
from llama_cpp import Llama # Import the necessary library
import numpy as np
# --- CONFIGURATION ---
# Define the paths to your uploaded GGUF files
GGUF_MODEL_PATH_1B = "./llama-3.2-1b-summary-q4_k_m.gguf"
GGUF_MODEL_PATH_3B = "./llama-3.2-3b-summary-q4_k_m.gguf"
# Define the Prompt template for summarization (using a simple instruction format)
SYSTEM_PROMPT = (
"You are an expert summarization bot. Your task is to provide a comprehensive "
"and concise summary of the user's document based on the requested length."
)
# ----------------------------------------------------
# 1. MODEL LOADING FUNCTION (Runs once on app startup)
# ----------------------------------------------------
def load_llm(model_path):
print(f"Attempting to load GGUF model: {model_path}...")
# Load the model using llama-cpp-python (n_gpu_layers=0 forces CPU usage)
# verbose=True shows loading status
try:
llm = Llama(
model_path=model_path,
n_gpu_layers=0, # Ensure it runs on CPU
n_ctx=2048, # Context window size
verbose=True
)
print(f"Successfully loaded model: {model_path}")
return llm
except Exception as e:
print(f"Error loading model {model_path}: {e}")
# In case of failure, return a placeholder function
return lambda prompt, **kwargs: {"choices": [{"text": f"Error: Model failed to load ({model_path}). Check logs. Error: {e}"}]}
# Load models globally so they are loaded only once at startup
llm_1b = load_llm(GGUF_MODEL_PATH_1B)
llm_3b = load_llm(GGUF_MODEL_PATH_3B)
# ----------------------------------------------------
# 2. CORE PROCESSING FUNCTION (GGUF Inference)
# ----------------------------------------------------
def generate_summary_and_compare(long_document, selected_model, summary_length):
# 1. Select the model and configuration
if "1B" in selected_model:
selected_llm = llm_1b
model_name_display = "Llama-3.2-1B (Faster)"
elif "3B" in selected_model:
selected_llm = llm_3b
model_name_display = "Llama-3.2-3B (Higher Quality)"
else:
return "Error: Invalid model selection.", ""
# 2. Build the instruction prompt
instruction = f"Please summarize the following document and keep the summary {summary_length}. Document: \n\n{long_document}"
# We use Llama 3 format for instruction
full_prompt = f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{SYSTEM_PROMPT}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{instruction}<|eot_id|><|start_header_id|>assistant<|end_header_id|>"
# 3. Run Inference and measure speed
start_time = time.time()
# Determine max_tokens based on length request (heuristic)
max_tokens = 250 if "Detailed" in summary_length else 100
try:
# Call the GGUF model's completion method
output = selected_llm(
full_prompt,
max_tokens=max_tokens,
stop=["<|eot_id|>"], # Stop sequence for Llama models
temperature=0.7,
echo=False,
# min_p=0.1 # Optional: Can improve output quality slightly
)
end_time = time.time()
total_latency = end_time - start_time
# Extract the text output
summary_output = output["choices"][0]["text"].strip()
except Exception as e:
total_latency = time.time() - start_time
summary_output = f"Inference Error on {model_name_display}. Error: {e}"
# 4. Generate Performance Report (Task 2 Output)
speed_report = f"Model: {model_name_display}\nTotal Latency: {total_latency:.2f} seconds\n(Used for A-grade speed/quality tradeoff analysis)"
return summary_output, speed_report
# ----------------------------------------------------
# 3. GRADIO INTERFACE DEFINITION (kept same as previous version)
# ----------------------------------------------------
with gr.Blocks(title="KTH ID2223 Lab 2: LLM Document Summarizer") as demo:
gr.Markdown(f"# 📚 LLM Document Summarizer & Model Comparison (KTH Lab 2)")
gr.Markdown(
"This tool demonstrates the summarization capability of a fine-tuned LLM. "
"Select a model and input a document. The speed comparison between 1B and 3B models on CPU fulfills the requirements for Task 2."
)
with gr.Row():
# Left Panel: User Input and Controls
with gr.Column(scale=1):
input_document = gr.Textbox(
lines=10,
label="Paste Long Document or Report Content",
placeholder="Paste the text you need summarized here..."
)
summary_control = gr.Radio(
["Concise (under 50 words)", "Detailed (under 200 words)"],
label="Select Summary Length Requirement",
value="Concise (under 50 words)"
)
model_selector = gr.Radio(
["Llama-3.2-1B (Faster)", "Llama-3.2-3B (Higher Quality)"],
label="Select Model for Comparison (Task 2)",
value="Llama-3.2-1B (Faster)"
)
process_button = gr.Button("Generate Summary & Compare Speed", variant="primary")
# Right Panel: Output and Performance Report
with gr.Column(scale=2):
output_summary = gr.Textbox(
label="Generated Document Summary",
lines=15,
interactive=False
)
performance_report = gr.Textbox(
label="Performance and Latency Report",
interactive=False,
lines=3
)
# Event Binding: Connect the button click to the processing function
process_button.click(
fn=generate_summary_and_compare,
inputs=[input_document, model_selector, summary_control],
outputs=[output_summary, performance_report]
)
demo.launch()