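"""Gradio app that generates business case studies with facebook/opt-2.7b,
augments the prompt with scraped article links and uploaded PDF text, and
stores each result in a FAISS index of sentence embeddings."""
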
import os
import gradio as gr
import pdfplumber
import requests
import faiss
import torch
from bs4 import BeautifulSoup
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from sentence_transformers import SentenceTransformer
import numpy as np
import tempfile
import logging
from datetime import datetime
from typing import List, Dict
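
# Likely dependencies for this Space's requirements.txt (an assumption from the
# imports; faiss-cpu vs faiss-gpu depends on the hardware, and accelerate plus
# bitsandbytes are only needed for the 8-bit fallback below):
#   gradio, pdfplumber, requests, faiss-cpu, torch, beautifulsoup4,
#   transformers, sentence-transformers, numpy, accelerate, bitsandbytes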
# Reduce CUDA memory fragmentation (must be set before CUDA is initialized)
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class CaseStudyGenerator:
    def __init__(self):
        self.model_name = "facebook/opt-2.7b"
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Release any cached GPU memory before loading the model
        if self.device == "cuda":
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
        model_kwargs = {
            'torch_dtype': torch.float16 if self.device == "cuda" else torch.float32
        }
        try:
            self.model = AutoModelForCausalLM.from_pretrained(self.model_name, **model_kwargs)
            if self.device == "cuda":
                self.model = self.model.to(self.device)
        except RuntimeError as e:
            logger.warning(f"Memory issue detected: {e}; attempting 8-bit loading.")
            try:
                from transformers import BitsAndBytesConfig
                quantization_config = BitsAndBytesConfig(load_in_8bit=True)
                # device_map="auto" lets accelerate place the quantized weights
                self.model = AutoModelForCausalLM.from_pretrained(
                    self.model_name,
                    quantization_config=quantization_config,
                    device_map="auto",
                )
            except ImportError:
                logger.error("Missing 'bitsandbytes'. Install it with 'pip install -U bitsandbytes'.")
                logger.info("Falling back to CPU.")
                self.device = "cpu"
                self.model = AutoModelForCausalLM.from_pretrained(self.model_name, torch_dtype=torch.float32)
        # Models dispatched by accelerate (the 8-bit fallback) must not be given
        # an explicit device; set one only when we placed the model ourselves.
        pipeline_device = None
        if getattr(self.model, "hf_device_map", None) is None:
            pipeline_device = 0 if self.device == "cuda" else -1
        self.generator = pipeline(
            "text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
            device=pipeline_device,
            num_return_sequences=1,
            temperature=0.8,
            top_p=0.95,
            do_sample=True,
            pad_token_id=self.tokenizer.eos_token_id
        )
        # Sentence embeddings + FAISS index for storing generated case studies
        self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
        self.dimension = self.embedding_model.get_sentence_embedding_dimension()  # 384 for MiniLM
        self.index = faiss.IndexFlatL2(self.dimension)
        self.stored_texts: List[Dict] = []
    def clean_url(self, url: str) -> str:
        # Keep only absolute http(s) URLs; drop query strings and cap the length
        if not url.startswith(('http://', 'https://')):
            return ""
        return url.split('?')[0][:100]
    def fetch_articles(self, topic: str) -> List[str]:
        # Note: scraping Google results is brittle and may be rate-limited or
        # blocked; a proper search API would be more reliable.
        try:
            search_url = f"https://www.google.com/search?q={topic.replace(' ', '+')}+case+study+manufacturing+strategy"
            headers = {"User-Agent": "Mozilla/5.0"}
            response = requests.get(search_url, headers=headers, timeout=10)
            response.raise_for_status()
            soup = BeautifulSoup(response.text, "html.parser")
            urls = [self.clean_url(link.get("href", "")) for link in soup.find_all("a")
                    if "google" not in link.get("href", "")]
            # clean_url returns "" for relative links, so filter those out
            articles = [u for u in urls if u]
            return articles[:5] or ["No articles found"]
        except Exception as e:
            logger.error(f"Error fetching articles: {str(e)}")
            return ["Error fetching articles"]
    def process_pdf(self, pdf_file) -> str:
        try:
            if pdf_file is None:
                return "No PDF provided"
            # gr.File(type="binary") hands the upload over as raw bytes
            data = pdf_file if isinstance(pdf_file, bytes) else pdf_file.read()
            with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as temp_pdf:
                temp_pdf.write(data)
                temp_path = temp_pdf.name
            try:
                with pdfplumber.open(temp_path) as pdf:
                    # Walrus operator avoids calling extract_text() twice per page
                    text = [t.strip() for page in pdf.pages if (t := page.extract_text())]
            finally:
                os.unlink(temp_path)
            return "\n".join(text) or "No text extracted from PDF"
        except Exception as e:
            logger.error(f"Error processing PDF: {str(e)}")
            return "Error processing PDF"
    def generate_case_study(self, topic: str, pdf=None) -> str:
        try:
            if self.device == "cuda":
                torch.cuda.empty_cache()
            articles = self.fetch_articles(topic)
            pdf_text = self.process_pdf(pdf) if pdf is not None else "No PDF provided"
            prompt = f"""Write a professional case study about {topic}.
Background Information:
- Topic: {topic}
- Supporting Documents: {pdf_text[:500]}
- Related Sources: {', '.join(articles)}
Format your response as:
1. Executive Summary
2. Company Background
3. Challenge Analysis
4. Strategic Implementation
5. Results and Impact
6. Key Learnings
"""
            output = self.generator(
                prompt,
                max_new_tokens=1024,
                num_return_sequences=1,
                temperature=0.8,
                top_p=0.95,
                do_sample=True,
                repetition_penalty=1.2,
                no_repeat_ngram_size=3
            )
            # The pipeline echoes the prompt, so strip it from the output
            case_study = output[0]['generated_text'].replace(prompt, "").strip()
            # Embed and index the result so it is kept alongside its metadata
            embedding = self.embedding_model.encode([case_study])[0]
            self.index.add(np.asarray(embedding, dtype="float32").reshape(1, -1))
            self.stored_texts.append({
                "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                "topic": topic,
                "content": case_study
            })
            return case_study
        except Exception as e:
            logger.error(f"Error generating case study: {str(e)}")
            return f"Error generating case study: {str(e)}"
    def retrieve_past_case_studies(self) -> str:
        try:
            if not self.stored_texts:
                return "No case studies generated yet."
            # Show the five most recent case studies
            result = ""
            for idx, case in enumerate(self.stored_texts[-5:], start=1):
                result += (
                    f"Case Study {idx}\n"
                    f"Topic: {case['topic']}\n"
                    f"Generated on: {case['timestamp']}\n\n"
                    f"{case['content']}\n\n"
                    "=== End of Case Study ===\n\n"
                )
            return result
        except Exception as e:
            logger.error(f"Error retrieving past case studies: {str(e)}")
            return "Error retrieving past case studies"
# Gradio interface
with gr.Blocks() as app:
    gr.Markdown("# AI Case Study Generator (Optimized for GPU-T4 & CPU)")
    with gr.Row():
        topic = gr.Textbox(label="Enter Topic")
        pdf = gr.File(label="Upload PDF", type="binary")
    with gr.Row():
        generate_btn = gr.Button("Generate Case Study")
        retrieve_btn = gr.Button("Retrieve Past Case Studies")
    output = gr.Textbox(label="Generated Case Study", lines=20)
    past_cases = gr.Textbox(label="Past Case Studies", lines=20)

    generator = CaseStudyGenerator()
    generate_btn.click(generator.generate_case_study, inputs=[topic, pdf], outputs=output)
    retrieve_btn.click(generator.retrieve_past_case_studies, outputs=past_cases)
# Launch the application
if __name__ == "__main__":
    app.launch(share=True)  # launch() no longer accepts enable_queue
    # To enable request queuing for long generations (Gradio 3.x+):
    # app.queue().launch(share=True)