Spaces:
Running
Running
| import gradio as gr | |
| from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer | |
| import huggingface_hub | |
| import os | |
| import torch | |
# --- Configuration ---
# Hugging Face Hub repository id of the model this app serves.
MODEL_ID = "Fastweb/FastwebMIIA-7B"
# Hub access token read from the environment (None when unset).
# On Hugging Face Spaces, define it as a Secret named HF_TOKEN.
HF_TOKEN = os.environ.get("HF_TOKEN")
# Shared module state written by load_model_and_pipeline():
#   text_generator_pipeline - the ready-to-use pipeline, None until loading succeeds;
#   model_load_error        - last loading error message, None when there is none.
text_generator_pipeline = model_load_error = None
# --- Hugging Face Login and Model Loading ---
def load_model_and_pipeline():
    """Log in to the Hugging Face Hub, load ``MODEL_ID`` and build the shared pipeline.

    The pipeline is stored in the module-level ``text_generator_pipeline``; if it
    is already set, the function returns immediately without reloading. On
    failure a user-facing message is stored in ``model_load_error`` (displayed
    by the UI) and the full traceback is printed, so that the "check Space logs"
    advice in that message is actually actionable.

    Returns:
        bool: True if the pipeline is available after the call, False otherwise.
    """
    global text_generator_pipeline, model_load_error
    if text_generator_pipeline is not None:
        print("Model already loaded.")
        return True  # Already loaded
    if not HF_TOKEN:
        model_load_error = "Hugging Face token (HF_TOKEN) not found in Space secrets. Please add it and restart the Space."
        print(f"ERROR: {model_load_error}")
        return False
    try:
        print("Attempting to login to Hugging Face Hub with token...")
        huggingface_hub.login(token=HF_TOKEN)
        print("Login successful.")

        print(f"Loading tokenizer for {MODEL_ID}...")
        # NOTE: trust_remote_code=True runs Python code shipped in the model repo.
        tokenizer = AutoTokenizer.from_pretrained(
            MODEL_ID,
            trust_remote_code=True,
            use_fast=False,  # As recommended by the model card
        )
        # Llama-style tokenizers often ship without a pad token; reuse EOS so that
        # generation (which is passed pad_token_id) has one.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        print("Tokenizer loaded.")

        print(f"Loading model {MODEL_ID}...")
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,  # 2 bytes/parameter: half the memory of float32
            device_map="auto",  # let transformers/accelerate place layers on GPU(s)/CPU
        )
        print("Model loaded.")

        # The model is already placed by device_map, so no `device=` argument here.
        text_generator_pipeline = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
        )
        print("Text generation pipeline created successfully.")
        model_load_error = None
        return True
    except Exception as e:
        # Top-level boundary: report the failure to the UI instead of crashing the app.
        import traceback  # only needed on this error path

        model_load_error = f"Error loading model/pipeline: {str(e)}. Check model name, token, and Space resources (RAM/GPU)."
        print(f"ERROR: {model_load_error}")
        # Keep the full stack trace in the logs; str(e) alone is often too terse
        # to diagnose OOMs, gated-repo access errors, or missing dependencies.
        traceback.print_exc()
        text_generator_pipeline = None  # Ensure it's None on error
        return False
# --- Text Analysis Function ---
def _read_uploaded_file(file_upload):
    """Return the UTF-8 text of an uploaded file.

    Depending on the Gradio version, ``gr.File`` delivers either a path string or
    a temp-file object exposing its path as ``.name``; both are accepted.
    """
    path = getattr(file_upload, "name", file_upload)
    with open(path, "r", encoding="utf-8") as f:
        return f.read()


def _resolve_input_text(text_input, file_upload):
    """Choose the text to analyse: a non-blank upload wins, otherwise the textbox.

    Returns:
        tuple: ``(content, error_message)``; exactly one of the two is None.
    """
    text_input = text_input or ""  # the textbox value may arrive as None
    if file_upload is None:
        if text_input:
            return text_input, None
        return None, "Please provide text directly or upload a document."
    try:
        file_text = _read_uploaded_file(file_upload)
    except Exception as e:  # missing file, decode error, ...: shown to the user
        return None, f"Error reading uploaded file: {str(e)}"
    if file_text.strip():
        return file_text, None
    # The upload is blank: fall back to the textbox if it holds real text.
    if text_input.strip():
        return text_input, None
    return None, "Uploaded file is empty and no direct text input provided. Please provide some text."


def _build_prompt(tokenizer, custom_instruction, content_to_analyze):
    """Render the system + user messages into a single prompt string.

    Uses the tokenizer's chat template when available and otherwise falls back
    to a hand-written Llama 2 chat format:
    <s>[INST] <<SYS>>\\n{system_prompt}\\n<</SYS>>\\n\\n{user_prompt} [/INST]
    """
    system_prompt = "You are a helpful AI assistant specialized in text analysis. Perform the requested task on the provided text."
    instruction = (custom_instruction or "").strip()
    user_message = f"{instruction}\n\nHere is the text:\n```\n{content_to_analyze}\n```"
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_message},
    ]
    try:
        # apply_chat_template requires transformers >= 4.34 and a chat template.
        return tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
    except Exception as e:
        print(f"Warning: Could not use apply_chat_template ({e}). Falling back to manual formatting.")
        return f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n{user_message} [/INST]"


def analyze_text(text_input, file_upload, custom_instruction, max_new_tokens, temperature, top_p):
    """Run the loaded model on the user's text and return only its answer.

    Args:
        text_input: Text typed into the textbox (may be empty or None).
        file_upload: Value of the ``gr.File`` component: None, a path string, or
            a file object exposing ``.name``. A non-blank upload takes precedence
            over ``text_input``.
        custom_instruction: Task to perform on the text (e.g. "summarize").
        max_new_tokens: Upper bound on generated tokens (converted with ``int``).
        temperature: Sampling temperature; values below 0.01 are clamped to 0.01.
        top_p: Nucleus-sampling probability mass.

    Returns:
        str: The generated answer, or a human-readable status/error message.
        Input, loading, and generation problems are reported through this
        string rather than raised, so they appear in the output textbox.
    """
    if text_generator_pipeline is None:
        if model_load_error:
            return f"Model not loaded. Error: {model_load_error}"
        return "Model is not loaded or still loading. Please check Space logs for errors (especially OOM) and ensure HF_TOKEN is set and you've accepted model terms. If on CPU, it may take a very long time or fail due to memory."

    content_to_analyze, input_error = _resolve_input_text(text_input, file_upload)
    if input_error is not None:
        return input_error
    if not content_to_analyze.strip():
        return "Input text is empty."

    tokenizer = text_generator_pipeline.tokenizer
    prompt = _build_prompt(tokenizer, custom_instruction, content_to_analyze)

    print("\n--- Sending to Model ---")
    print(f"Full Prompt:\n{prompt}")
    print(f"Max New Tokens: {max_new_tokens}, Temperature: {temperature}, Top P: {top_p}")
    print("------------------------\n")

    try:
        generated_outputs = text_generator_pipeline(
            prompt,
            max_new_tokens=int(max_new_tokens),
            do_sample=True,
            # Temperature 0 is invalid with do_sample=True; clamp to a small positive value.
            temperature=max(0.01, float(temperature)),
            top_p=float(top_p),
            num_return_sequences=1,
            # Return only the newly generated text. This avoids splitting the output
            # on a template marker such as "[/INST]", which breaks when the user's
            # text contains that marker or the chat template uses another format.
            return_full_text=False,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,  # Use the set pad_token
        )
        return generated_outputs[0]["generated_text"].strip()
    except Exception as e:
        return f"Error during text generation: {str(e)}"
# --- Gradio Interface ---
# NOTE(review): this section's indentation was lost in the copy under review; the
# nesting of the layout containers below (Row / Tab / Column / Accordion) is
# reconstructed from statement order and should be confirmed against the live app.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Page header.
    gr.Markdown(f"""
# 📝 Text Analysis with {MODEL_ID}
Test the capabilities of the `{MODEL_ID}` model for text analysis tasks on Italian or English texts.
Provide an instruction and your text (directly or via upload).
**Important:** Model loading can take a few minutes, especially on the first run or on CPU.
This app is best run on a Hugging Face Space with GPU resources (e.g., T4-small or A10G-small) for this 7B model.
""")
    # Status bar: read-only model status (filled in by startup_load_model) and a hardware label.
    with gr.Row():
        status_textbox = gr.Textbox(label="Model Status", value="Initializing...", interactive=False, scale=3)
        # Falls back to a placeholder label when SPACE_HARDWARE is unset.
        # NOTE(review): confirm that Hugging Face Spaces actually exports SPACE_HARDWARE.
        current_hardware = os.getenv("SPACE_HARDWARE", "Unknown (likely local or unspecified)")
        gr.Markdown(f"Running on: **{current_hardware}**")
    with gr.Tab("Text Input & Analysis"):
        with gr.Row():
            # Left column: the instruction plus the text to analyse (typed or uploaded).
            with gr.Column(scale=2):
                instruction_prompt = gr.Textbox(
                    label="Instruction for the Model (Cosa vuoi fare con il testo?)",
                    value="Riassumi questo testo in 3 frasi concise.",
                    lines=3,
                    placeholder="Example: Riassumi questo testo. / Summarize this text. / Estrai le entità nominate. / Identify named entities."
                )
                text_area_input = gr.Textbox(label="Enter Text Directly / Inserisci il testo direttamente", lines=10, placeholder="Paste your text here or upload a file below...")
                file_input = gr.File(label="Or Upload a Document (.txt) / O carica un documento (.txt)", file_types=['.txt'])
            # Right column: the model's answer (read-only).
            with gr.Column(scale=3):
                output_text = gr.Textbox(label="Model Output / Risultato del Modello", lines=20, interactive=False)
        # Sampling parameters forwarded to analyze_text (collapsed by default).
        with gr.Accordion("Advanced Generation Parameters", open=False):
            max_new_tokens_slider = gr.Slider(minimum=10, maximum=2048, value=256, step=10, label="Max New Tokens")
            temperature_slider = gr.Slider(minimum=0.01, maximum=2.0, value=0.7, step=0.01, label="Temperature (higher is more creative, 0.01 for more deterministic)")
            top_p_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top P (nucleus sampling)")
        analyze_button = gr.Button("🧠 Analyze Text / Analizza Testo", variant="primary")
    # Component values are passed to analyze_text positionally, so the order of
    # `inputs` must match its parameter order (text, file, instruction, tokens, temperature, top_p).
    analyze_button.click(
        fn=analyze_text,
        inputs=[text_area_input, file_input, instruction_prompt, max_new_tokens_slider, temperature_slider, top_p_slider],
        outputs=output_text
    )
    # Load the model when the app is opened in a browser.
    # The returned message is shown in status_textbox.
    def startup_load_model():
        """Load the model (a no-op if already loaded) and return a status message for the UI."""
        print("Gradio app starting, attempting to load model...")
        if load_model_and_pipeline():
            return "Model loaded successfully and ready."
        else:
            return f"Failed to load model. Error: {model_load_error or 'Unknown error during startup. Check Space logs.'}"
    # Blocks.load fires each time the page is loaded in a browser; repeat calls are
    # cheap because load_model_and_pipeline returns early once the pipeline exists.
    demo.load(startup_load_model, outputs=status_textbox)
if __name__ == "__main__":
    # Local entry point. Provide a token either via the environment,
    #   HF_TOKEN="your_hf_token_here" python app.py
    # or by logging in once with `huggingface-cli login`.
    if not HF_TOKEN:
        print("WARNING: HF_TOKEN environment variable not set.")
        print("For local execution, either set HF_TOKEN or ensure you are logged in via 'huggingface-cli login'.")
        try:
            # HfApi().token is only the token passed to HfApi's constructor (None
            # here), so it can never see a CLI login. get_token() reads the HF_TOKEN
            # env var and the token file written by `huggingface-cli login`.
            try:
                from huggingface_hub import get_token
            except ImportError:  # older huggingface_hub releases without get_token()
                from huggingface_hub import HfFolder
                get_token = HfFolder.get_token
            token = get_token()
            if token:
                os.environ['HF_TOKEN'] = token  # Set it for the current process
                HF_TOKEN = token  # Module scope: rebinds the global read by load_model_and_pipeline
                print("Using token from huggingface-cli login.")
            else:
                print("Could not retrieve token from CLI login. Model access might fail.")
        except Exception as e:
            # Best effort only: the app still launches, and loading reports its own error.
            print(f"Could not check CLI login status: {e}. Model access might fail.")
    print("Launching Gradio interface...")
    demo.queue().launch(debug=True, share=False)  # share=True for public link if local