# Streamlit app: generate YouTube scripts from comma-separated keyword tags
# using a fine-tuned Gemma-2 PEFT model. Run with: streamlit run app.py
import os
import re

import streamlit as st
import torch
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer, StoppingCriteria, StoppingCriteriaList

# Set cache directory and get token
os.environ['HF_HOME'] = '/app/cache'
hf_token = os.getenv('HF_TOKEN')
class StopWordCriteria(StoppingCriteria):
    """Stop generation once the running sequence ends with a given stop word."""

    def __init__(self, tokenizer, stop_word):
        self.stop_word_id = tokenizer.encode(stop_word, add_special_tokens=False)

    def __call__(self, input_ids, scores, **kwargs):
        # Compare the tail of the generated ids against the stop word's token ids
        if len(input_ids[0]) >= len(self.stop_word_id) and \
                input_ids[0][-len(self.stop_word_id):].tolist() == self.stop_word_id:
            return True
        return False
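
# Illustration (comments only, not executed): for the stop word "script", the
# criteria behaves roughly as sketched below. The concrete ids depend on the
# Gemma tokenizer's vocabulary, so the value shown is a placeholder:
#
#     stop_ids = tokenizer.encode("script", add_special_tokens=False)  # e.g. [6358]
#     input_ids[0][-len(stop_ids):].tolist() == stop_ids  # True -> halt generation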
def load_model():
    try:
        # Ensure cache directory exists
        cache_dir = '/app/cache'
        os.makedirs(cache_dir, exist_ok=True)

        # Check for HF token
        if not hf_token:
            st.warning("HuggingFace token not found. Some models may not be accessible.")

        # Check CUDA availability
        if torch.cuda.is_available():
            device = torch.device("cuda")
            st.success(f"Using GPU: {torch.cuda.get_device_name(0)}")
        else:
            device = torch.device("cpu")
            st.warning("CUDA is not available. Using CPU.")

        # Fine-tuned model for generating scripts
        model_name = "Sidharthan/gemma2_scripter"

        try:
            tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                trust_remote_code=True,
                token=hf_token,
                cache_dir=cache_dir
            )
        except Exception as e:
            st.error(f"Error loading tokenizer: {str(e)}")
            if "401" in str(e):
                st.error("Authentication error. Please check your HuggingFace token.")
            raise

        try:
            # Load model with appropriate device settings
            model = AutoPeftModelForCausalLM.from_pretrained(
                model_name,
                device_map=None,  # Device placement is handled manually below
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                trust_remote_code=True,
                low_cpu_mem_usage=True,
                token=hf_token,
                cache_dir=cache_dir
            )
            # Move model to the selected device
            model = model.to(device)
            return model, tokenizer
        except Exception as e:
            st.error(f"Error loading model: {str(e)}")
            if "401" in str(e):
                st.error("Authentication error. Please check your HuggingFace token.")
            elif "disk space" in str(e).lower():
                st.error("Insufficient disk space in cache directory.")
            raise
    except Exception as e:
        st.error(f"General error during model loading: {str(e)}")
        raise
def generate_script(tags, model, tokenizer, params):
    device = next(model.parameters()).device

    # Build the prompt in the turn format this fine-tune expects:
    # a "keywords" turn followed by an open "script" turn for the model to fill
    prompt = f"<bos><start_of_turn>keywords\n{tags}<end_of_turn>\n<start_of_turn>script\n"

    # Tokenize and move to the model's device
    inputs = tokenizer(prompt, return_tensors='pt')
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Stop once the model starts another "script" turn
    stop_word = 'script'
    stopping_criteria = StoppingCriteriaList([StopWordCriteria(tokenizer, stop_word)])

    try:
        outputs = model.generate(
            **inputs,
            max_length=params['max_length'],  # counts prompt tokens plus generated tokens
            do_sample=True,
            temperature=params['temperature'],
            top_p=params['top_p'],
            top_k=params['top_k'],
            repetition_penalty=params['repetition_penalty'],
            num_return_sequences=1,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            stopping_criteria=stopping_criteria
        )

        # Move outputs back to CPU for decoding
        outputs = outputs.cpu()
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Strip the echoed prompt (keywords ... script) and any trailing stop word
        response = re.sub(r'keywords\s.*?script\s', '', response, flags=re.DOTALL)
        response = re.sub(r'\bscript\b.*$', '', response, flags=re.IGNORECASE).strip()
        return response
    except RuntimeError as e:
        if "out of memory" in str(e):
            st.error("GPU out of memory error. Try reducing max_length or using CPU.")
            return "Error: GPU out of memory"
        else:
            st.error(f"Error during generation: {str(e)}")
            return f"Error during generation: {str(e)}"
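
# Example call (comments only -- the argument values are illustrative, not
# taken from the app; in the UI they come from the sidebar sliders below):
#
#     model, tokenizer = load_model()
#     script = generate_script(
#         "tech, AI, future",
#         model, tokenizer,
#         {'max_length': 512, 'temperature': 0.7, 'top_p': 0.95,
#          'top_k': 50, 'repetition_penalty': 1.2},
#     )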
def main():
    st.title("🎥 YouTube Script Generator")

    # Sidebar for model parameters
    st.sidebar.title("Generation Parameters")
    params = {
        'max_length': st.sidebar.slider('Max Length', 64, 1024, 512),
        'temperature': st.sidebar.slider('Temperature', 0.1, 1.0, 0.7),
        'top_p': st.sidebar.slider('Top P', 0.1, 1.0, 0.95),
        'top_k': st.sidebar.slider('Top K', 1, 100, 50),
        'repetition_penalty': st.sidebar.slider('Repetition Penalty', 1.0, 2.0, 1.2)
    }
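    # Rough effect of each knob: Temperature sharpens or flattens the sampling
    # distribution, Top P / Top K bound the candidate token pool at each step,
    # and Repetition Penalty (> 1.0) discourages the model from repeating itself.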
    # Load model and tokenizer; st.cache_resource keeps the loaded model
    # across Streamlit reruns so it is not reloaded on every interaction
    @st.cache_resource
    def get_model():
        return load_model()
    try:
        model, tokenizer = get_model()

        # Tag input section
        st.markdown("### Add Tags")
        st.markdown("Enter tags separated by commas to generate a YouTube script")

        # Create columns for tag input and generate button
        col1, col2 = st.columns([3, 1])
        with col1:
            tags = st.text_input("Enter tags", placeholder="tech, AI, future, innovations...")
        with col2:
            generate_button = st.button("Generate Script", type="primary")

        # Generated script section
        if generate_button and tags:
            st.markdown("### Generated Script")
            with st.spinner("Generating script..."):
                script = generate_script(tags, model, tokenizer, params)
            st.text_area("Your script:", value=script, height=400)

            # Add download button
            st.download_button(
                label="Download Script",
                data=script,
                file_name="youtube_script.txt",
                mime="text/plain"
            )
        elif generate_button and not tags:
            st.warning("Please enter some tags first!")
    except Exception as e:
        st.error("Failed to initialize the application. Please check the logs for details.")
        st.error(f"Error: {str(e)}")


if __name__ == "__main__":
    main()