# bQAsummary2 / app.py
# (Hugging Face Space page metadata — "anujkum0x's picture / Update app.py /
# 3470c12 verified" — converted to comments: it is not valid Python source.)
import gradio as gr
import pandas as pd
import os
import re
import json
from collections import defaultdict
import requests
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from huggingface_hub import login, HfFolder
import time
# Model configuration - Get token from environment variable
HF_TOKEN = os.environ.get("HF_TOKEN", "")  # empty string when unset
USE_API = True  # Default to API mode to avoid downloading model

# Set cache directory for local models
# NOTE(review): TRANSFORMERS_CACHE is deprecated in recent transformers
# releases in favor of HF_HOME — confirm against the installed version.
os.environ["TRANSFORMERS_CACHE"] = os.path.join(os.path.expanduser("~"), ".cache", "huggingface")
CACHE_DIR = os.environ["TRANSFORMERS_CACHE"]

# Create cache directory if it doesn't exist
os.makedirs(CACHE_DIR, exist_ok=True)

# Define model - only Mistral-7B-Instruct.
# The Inference API URL is derived from the model name so the two cannot
# drift apart (the original paired a v0.3 name with a hard-coded v0.2 URL).
_MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.3"

MISTRAL_MODEL = {
    "name": _MODEL_NAME,
    "url": f"https://api-inference.huggingface.co/models/{_MODEL_NAME}",
    "type": "instruct",
    # Mistral instruct template: wrap the prompt in [INST] tags. The closing
    # </s> token is emitted BY the model at the end of its answer; appending
    # it to the prompt (as the original did) marks the sequence as already
    # finished before generation starts.
    "prompt_format": lambda p: f"<s>[INST] {p} [/INST]",
    "max_length": 8000,
    "max_input_length": 24000,  # Limit input length (characters, not tokens)
}

# For local model fallback
LOCAL_MODEL = _MODEL_NAME
# Try to login with the token from environment variable
try:
    if HF_TOKEN:
        # Authenticate this process with the Hugging Face Hub; required for
        # gated models and authorised Inference API requests.
        login(HF_TOKEN)
        print("Successfully logged in with token from environment variable")
    else:
        print("No HF_TOKEN found in environment variables. Please set it before running.")
except Exception as e:
    # Non-fatal here: subsequent API calls surface their own auth errors.
    print(f"Error logging in with token: {e}")
def truncate_text(text, max_length):
    """Clip *text* to at most *max_length* characters.

    Prefers to cut at the right-most sentence terminator ('.', '!' or '?')
    inside the allowed window; falls back to a hard cut when none exists.
    """
    if len(text) <= max_length:
        return text
    # Right-most sentence terminator within the first max_length characters
    # (-1 from every rfind means there is no boundary at all).
    cut = max(text.rfind(ch, 0, max_length) for ch in ".!?")
    if cut == -1:
        # No sentence boundary found — truncate at the hard limit.
        return text[:max_length]
    return text[:cut + 1]
def generate_summary_api(transcript):
    """Generate summary using the Hugging Face Inference API with Mistral model.

    Args:
        transcript: Raw call-transcript text to summarise.

    Returns:
        A dict keyed "Issue"/"Verification"/"Resolution"/"Outcome" on success,
        a dict with a single "Error" key when no token is configured or the
        model output is unusable, or rule_based_summary()'s output when the
        API cannot be reached.
    """
    if not HF_TOKEN:
        return {"Error": "No Hugging Face token found. Please set the HF_TOKEN environment variable."}

    headers = {"Authorization": f"Bearer {HF_TOKEN}"}
    model_name = MISTRAL_MODEL["name"]
    model_url = MISTRAL_MODEL["url"]
    prompt_formatter = MISTRAL_MODEL["prompt_format"]
    max_input_length = MISTRAL_MODEL["max_input_length"]

    try:
        print(f"Using Mistral model: {model_name}")

        # Truncate BEFORE building the prompt. The original built the prompt
        # from the full transcript and then truncated a copy it never used,
        # so oversized inputs were sent to the API anyway.
        truncated_transcript = truncate_text(transcript, max_input_length)
        if len(truncated_transcript) < len(transcript):
            print(f"Truncated transcript from {len(transcript)} to {len(truncated_transcript)} characters")

        # Short instruction prompt to avoid token-length issues.
        instruct_prompt = f"""Analyze this customer service call transcript and create a professional summary with these sections:
- Issue: Main reason for the call
- Verification: Identity verification steps
- Resolution: Actions agreed upon
- Outcome: Final resolution and next steps
Use professional language, focus on facts, and write in third person.
Transcript:
{truncated_transcript}
"""

        # Wrap in Mistral's [INST] template.
        formatted_prompt = prompt_formatter(instruct_prompt)
        data = {
            "inputs": formatted_prompt,
            "parameters": {
                "max_new_tokens": 512,
                "temperature": 0.1,
                "top_p": 0.95,
                "top_k": 40,
                "repetition_penalty": 1.1,
                "do_sample": True
            },
            "options": {
                "use_cache": True,
                "wait_for_model": True  # queue while the model loads instead of 503
            }
        }

        # Retry loop for transient API failures.
        max_retries = 3
        retry_delay = 5
        response = None  # stays None if every attempt raised an exception
        for attempt in range(max_retries):
            try:
                response = requests.post(model_url, headers=headers, json=data, timeout=120)
                # An HTML body is an error page, not a model response.
                if response.headers.get('content-type', '').startswith('text/html'):
                    print(f"Received HTML error response. Retrying in {retry_delay} seconds...")
                    time.sleep(retry_delay)
                    continue
                if response.status_code == 200:
                    break
                elif response.status_code in [429, 503, 500]:
                    # Rate limit or server error: back off and retry.
                    print(f"Received status code {response.status_code}. Retrying in {retry_delay} seconds...")
                    time.sleep(retry_delay * (attempt + 1))  # linearly increasing backoff
                else:
                    # Other error (e.g. 401/404): retrying won't help.
                    break
            except requests.exceptions.Timeout:
                print(f"Request timed out. Retrying in {retry_delay} seconds...")
                time.sleep(retry_delay * (attempt + 1))
            except Exception as e:
                print(f"Request error: {e}. Retrying in {retry_delay} seconds...")
                time.sleep(retry_delay * (attempt + 1))

        # The original dereferenced `response` unconditionally here, raising
        # NameError when every attempt threw an exception; guard against None.
        if response is not None and response.status_code == 200:
            # Extract the generated text from the API's JSON payload.
            response_json = response.json()
            if isinstance(response_json, list) and len(response_json) > 0:
                if "generated_text" in response_json[0]:
                    summary_text = response_json[0]["generated_text"]
                else:
                    summary_text = str(response_json[0])
            else:
                summary_text = str(response_json)

            # Some backends echo the prompt back; strip it if present.
            if isinstance(summary_text, str) and summary_text.startswith(formatted_prompt):
                summary_text = summary_text[len(formatted_prompt):].strip()

            # Parse the free text into the four expected sections.
            structured_summary = parse_summary(summary_text)

            # Require at least Issue and Resolution to consider it a success.
            if structured_summary.get("Issue", "").strip() and structured_summary.get("Resolution", "").strip():
                print(f"Successfully generated summary with {model_name}")
                return structured_summary
            else:
                return {"Error": f"Model {model_name} returned incomplete summary. Falling back to rule-based summary."}
        else:
            status = response.status_code if response is not None else "no response"
            print(f"Model {model_name} failed with status code {status}, falling back to rule-based summary.")
            return rule_based_summary(transcript)
    except Exception as e:
        print(f"Error with model {model_name}: {e}")
        return rule_based_summary(transcript)
def parse_summary(text):
    """Parse free-form summary text into a dict keyed by section name.

    Looks for the literal headers "Issue:", "Verification:", "Resolution:"
    and "Outcome:". A section's content runs from its header to the nearest
    *present* later header, or to the end of the text. Sections whose header
    is absent are filled with "Not provided in summary".
    """
    sections = ["Issue:", "Verification:", "Resolution:", "Outcome:"]
    structured_summary = {}
    for i, section in enumerate(sections):
        start_idx = text.find(section)
        if start_idx == -1:
            continue
        start_idx += len(section)
        # End at the nearest later header that actually occurs. The original
        # only looked for the immediately-next header, so when that one was
        # missing find() returned -1 and text[start:-1] silently dropped the
        # final character and swallowed all later sections.
        end_idx = len(text)
        for later in sections[i + 1:]:
            pos = text.find(later, start_idx)
            if pos != -1 and pos < end_idx:
                end_idx = pos
        structured_summary[section[:-1]] = text[start_idx:end_idx].strip()
    # Fill in any missing sections with a placeholder.
    for section in sections:
        structured_summary.setdefault(section[:-1], "Not provided in summary")
    return structured_summary
def rule_based_summary(transcript):
    """
    Fallback summariser: derive a coarse structured summary from the raw
    transcript with regular expressions and keyword checks. Used whenever
    the model API is unavailable; requires no model access at all.
    """
    lowered = transcript.lower()

    # Dollar amounts such as $12.34 (numeric part captured without the '$').
    amounts = re.findall(r'\$(\d+\.\d+)', transcript)

    # Month-name dates, optionally with an ordinal suffix (e.g. "March 3rd").
    found_dates = re.findall(
        r'(?:January|February|March|April|May|June|July|August|September|October|November|December)\s+\d+(?:st|nd|rd|th)?',
        transcript,
        re.IGNORECASE,
    )

    # Identity-verification cues anywhere in the transcript.
    identity_checked = any(
        re.search(pattern, transcript, re.IGNORECASE)
        for pattern in (
            r'(?:verify|verification|confirm|security).*?(?:phone|number|date|birth|address)',
            r'(?:could you please verify|for security|for verification).*?(?:\?|\.)',
        )
    )

    # Billing / payment vocabulary.
    mentions_payment = any(
        term in lowered
        for term in ('payment', 'balance', 'due', 'bill', 'pay', 'account', 'charge', 'fee')
    )

    return {
        "Issue": ("Customer called regarding account balance or payment"
                  if mentions_payment
                  else "Customer called for general inquiry or support"),
        "Verification": ("Agent verified customer identity"
                         if identity_checked
                         else "No explicit verification mentioned"),
        "Resolution": (f"Discussion involved payment amount(s) of ${', $'.join(amounts)}"
                       if amounts
                       else "No specific payment resolution mentioned"),
        "Outcome": (f"Next steps may involve date(s): {', '.join(found_dates)}"
                    if found_dates
                    else "Call completed with customer acknowledgment"),
    }
# Define the Gradio interface - simplified with only Generate Summary tab
with gr.Blocks() as iface:
    gr.Markdown("## bQA Summary")
    # Single view: paste a transcript, get a structured JSON summary back.
    with gr.Row():
        with gr.Column(scale=3):
            transcript_box = gr.Textbox(
                label="Call Transcript",
                placeholder="Paste your customer service call transcript here...",
                lines=10,
            )
            generate_btn = gr.Button("Generate Summary", variant="primary")
            json_view = gr.JSON(label="Structured Summary")
            # Wire the button straight to the summariser; the identity
            # lambda in the original added nothing.
            generate_btn.click(
                generate_summary_api,
                inputs=[transcript_box],
                outputs=[json_view],
            )

# Launch the Gradio app
iface.launch()