# Hugging Face Space: customer-service call transcript summarizer
import gradio as gr
import pandas as pd
import os
import re
import json
from collections import defaultdict
import requests
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from huggingface_hub import login, HfFolder
import time
# --- Model configuration ---
# The Hugging Face access token is read from the environment; never hard-code it.
HF_TOKEN = os.environ.get("HF_TOKEN", "")
USE_API = True  # Default to the hosted Inference API to avoid downloading model weights.

# Cache directory for any locally downloaded models.
os.environ["TRANSFORMERS_CACHE"] = os.path.join(os.path.expanduser("~"), ".cache", "huggingface")
CACHE_DIR = os.environ["TRANSFORMERS_CACHE"]
os.makedirs(CACHE_DIR, exist_ok=True)

# Single model definition: Mistral-7B-Instruct.
# BUG FIX: "url" previously pointed at ...v0.2 while "name" said v0.3; the URL
# is now spelled with the same version as "name" so the two cannot disagree.
# BUG FIX: the Mistral instruct prompt must NOT append "</s>" after [/INST] —
# the end-of-sequence token is produced by the model, not supplied in the prompt.
MISTRAL_MODEL = {
    "name": "mistralai/Mistral-7B-Instruct-v0.3",
    "url": "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.3",
    "type": "instruct",
    "prompt_format": lambda p: f"<s>[INST] {p} [/INST]",
    "max_length": 8000,         # generation budget
    "max_input_length": 24000,  # transcripts longer than this are truncated
}
# Model id used for the local-inference fallback path.
LOCAL_MODEL = MISTRAL_MODEL["name"]
# Authenticate against the Hugging Face Hub with the environment token.
# Failure is reported but never fatal: the app can still serve the
# rule-based fallback summarizer without Hub access.
try:
    if not HF_TOKEN:
        print("No HF_TOKEN found in environment variables. Please set it before running.")
    else:
        login(HF_TOKEN)
        print("Successfully logged in with token from environment variable")
except Exception as e:
    print(f"Error logging in with token: {e}")
def truncate_text(text, max_length):
    """Truncate *text* to at most *max_length* characters.

    Prefers to cut at the right-most sentence terminator ('.', '!' or '?')
    inside the window; falls back to a hard cut when none exists.
    """
    if len(text) <= max_length:
        return text
    # Right-most sentence boundary within the first max_length characters.
    boundary = max(text.rfind(mark, 0, max_length) for mark in ".!?")
    if boundary == -1:
        # No sentence boundary at all: hard-truncate at the limit.
        return text[:max_length]
    return text[:boundary + 1]
def generate_summary_api(transcript):
    """Generate a structured call summary via the Hugging Face Inference API.

    Truncates the transcript to the model's input budget, sends it to the
    hosted Mistral model, and parses the reply into the
    Issue/Verification/Resolution/Outcome structure. Any API failure falls
    back to rule_based_summary().

    Args:
        transcript: Raw customer-service call transcript text.

    Returns:
        dict: Structured summary keyed by section name, or a dict with an
        "Error" key when no HF token is configured.
    """
    # NOTE: the original declared `global HF_TOKEN` needlessly — the token is
    # only read here, never assigned.
    if not HF_TOKEN:
        return {"Error": "No Hugging Face token found. Please set the HF_TOKEN environment variable."}

    headers = {"Authorization": f"Bearer {HF_TOKEN}"}
    model_name = MISTRAL_MODEL["name"]
    model_url = MISTRAL_MODEL["url"]
    prompt_formatter = MISTRAL_MODEL["prompt_format"]
    max_input_length = MISTRAL_MODEL["max_input_length"]

    try:
        print(f"Using Mistral model: {model_name}")

        # Truncate the transcript BEFORE building the prompt.
        # BUG FIX: previously the truncated text was computed but the prompt
        # embedded the full transcript, so the truncation was dead code.
        truncated_transcript = truncate_text(transcript, max_input_length)
        if len(truncated_transcript) < len(transcript):
            print(f"Truncated transcript from {len(transcript)} to {len(truncated_transcript)} characters")

        # Short instruction prompt to leave the token budget to the transcript.
        instruct_prompt = f"""Analyze this customer service call transcript and create a professional summary with these sections:
- Issue: Main reason for the call
- Verification: Identity verification steps
- Resolution: Actions agreed upon
- Outcome: Final resolution and next steps
Use professional language, focus on facts, and write in third person.
Transcript:
{truncated_transcript}
"""
        formatted_prompt = prompt_formatter(instruct_prompt)

        data = {
            "inputs": formatted_prompt,
            "parameters": {
                "max_new_tokens": 512,
                "temperature": 0.1,
                "top_p": 0.95,
                "top_k": 40,
                "repetition_penalty": 1.1,
                "do_sample": True
            },
            "options": {
                "use_cache": True,
                "wait_for_model": True
            }
        }

        # Retry with linear backoff on rate limits, server errors and timeouts.
        # BUG FIX: `response` is initialized so it is never unbound when every
        # attempt raises (the old code hit a NameError in that case).
        response = None
        max_retries = 3
        retry_delay = 5
        for attempt in range(max_retries):
            try:
                response = requests.post(model_url, headers=headers, json=data, timeout=120)
                # An HTML body means an error page, not a model reply.
                if response.headers.get('content-type', '').startswith('text/html'):
                    print(f"Received HTML error response. Retrying in {retry_delay} seconds...")
                    time.sleep(retry_delay)
                    continue
                if response.status_code == 200:
                    break
                elif response.status_code in (429, 503, 500):
                    # Rate limit or transient server error: back off and retry.
                    print(f"Received status code {response.status_code}. Retrying in {retry_delay} seconds...")
                    time.sleep(retry_delay * (attempt + 1))
                else:
                    # Any other status is not retryable.
                    break
            except requests.exceptions.Timeout:
                print(f"Request timed out. Retrying in {retry_delay} seconds...")
                time.sleep(retry_delay * (attempt + 1))
            except Exception as e:
                print(f"Request error: {e}. Retrying in {retry_delay} seconds...")
                time.sleep(retry_delay * (attempt + 1))

        if response is not None and response.status_code == 200:
            # Extract the generated text from the (list-or-dict) API payload.
            response_json = response.json()
            if isinstance(response_json, list) and len(response_json) > 0:
                first = response_json[0]
                if isinstance(first, dict) and "generated_text" in first:
                    summary_text = first["generated_text"]
                else:
                    summary_text = str(first)
            else:
                summary_text = str(response_json)

            # Some endpoints echo the prompt; strip it if present.
            if isinstance(summary_text, str) and summary_text.startswith(formatted_prompt):
                summary_text = summary_text[len(formatted_prompt):].strip()

            structured_summary = parse_summary(summary_text)

            # Require at least Issue and Resolution to call the summary usable.
            if structured_summary.get("Issue", "").strip() and structured_summary.get("Resolution", "").strip():
                print(f"Successfully generated summary with {model_name}")
                return structured_summary
            # BUG FIX: the old code *said* it would fall back here but returned
            # an error dict instead; now it really falls back.
            print(f"Model {model_name} returned incomplete summary. Falling back to rule-based summary.")
            return rule_based_summary(transcript)

        status = response.status_code if response is not None else "no response"
        print(f"Model {model_name} failed with status code {status}, falling back to rule-based summary.")
        return rule_based_summary(transcript)
    except Exception as e:
        print(f"Error with model {model_name}: {e}")
        return rule_based_summary(transcript)
def parse_summary(text):
    """Parse free-form summary text into a dict keyed by section name.

    Finds each of the four expected headers ("Issue:", "Verification:",
    "Resolution:", "Outcome:") and captures the text between a header and
    the nearest subsequent header that actually occurs. Sections absent
    from the text are filled with a placeholder value.

    Args:
        text: Model-generated summary text.

    Returns:
        dict: Mapping of section name (without colon) to its content.
    """
    sections = ["Issue:", "Verification:", "Resolution:", "Outcome:"]
    structured_summary = {}
    for i, section in enumerate(sections):
        start_idx = text.find(section)
        if start_idx == -1:
            continue
        start_idx += len(section)
        # BUG FIX: the old code only looked for the *immediately* next header
        # and used find()'s -1 directly as a slice end, which dropped the last
        # character and bled into later sections whenever that one header was
        # missing. Use the nearest of ALL subsequent headers, defaulting to
        # end-of-text.
        end_idx = len(text)
        for later in sections[i + 1:]:
            idx = text.find(later, start_idx)
            if idx != -1:
                end_idx = min(end_idx, idx)
        structured_summary[section[:-1]] = text[start_idx:end_idx].strip()
    # Fill in any missing sections with a placeholder.
    for section in sections:
        structured_summary.setdefault(section[:-1], "Not provided in summary")
    return structured_summary
def rule_based_summary(transcript):
    """Fallback summarizer used when the model API is unavailable.

    Builds a coarse structured summary from regex extraction and keyword
    checks alone — no model access required.
    """
    lowered = transcript.lower()

    # Dollar amounts like "$12.34".
    amounts = re.findall(r'\$(\d+\.\d+)', transcript)

    # Month-name dates such as "March 3rd".
    found_dates = re.findall(
        r'(?:January|February|March|April|May|June|July|August|September|October|November|December)\s+\d+(?:st|nd|rd|th)?',
        transcript,
        re.IGNORECASE,
    )

    # Phrasing that suggests an identity-verification step took place.
    verification_found = any(
        re.search(pattern, transcript, re.IGNORECASE)
        for pattern in (
            r'(?:verify|verification|confirm|security).*?(?:phone|number|date|birth|address)',
            r'(?:could you please verify|for security|for verification).*?(?:\?|\.)',
        )
    )

    # Any payment/billing vocabulary marks the call as payment-related.
    payment_related = any(
        term in lowered
        for term in ('payment', 'balance', 'due', 'bill', 'pay', 'account', 'charge', 'fee')
    )

    return {
        "Issue": ("Customer called regarding account balance or payment"
                  if payment_related
                  else "Customer called for general inquiry or support"),
        "Verification": ("Agent verified customer identity"
                         if verification_found
                         else "No explicit verification mentioned"),
        "Resolution": (f"Discussion involved payment amount(s) of ${', $'.join(amounts)}"
                       if amounts
                       else "No specific payment resolution mentioned"),
        "Outcome": (f"Next steps may involve date(s): {', '.join(found_dates)}"
                    if found_dates
                    else "Call completed with customer acknowledgment"),
    }
# --- Gradio interface: a single "Generate Summary" view ---
with gr.Blocks() as iface:
    gr.Markdown("## bQA Summary")
    with gr.Row():
        with gr.Column(scale=3):
            transcript_input = gr.Textbox(
                label="Call Transcript",
                placeholder="Paste your customer service call transcript here...",
                lines=10
            )
            summary_button = gr.Button("Generate Summary", variant="primary")
            summary_output = gr.JSON(label="Structured Summary")
    # IDIOM FIX: pass the handler directly instead of wrapping it in a
    # one-argument lambda that only forwarded its argument.
    summary_button.click(
        generate_summary_api,
        inputs=[transcript_input],
        outputs=[summary_output]
    )

# Launch the Gradio app
iface.launch()