# app.py
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel  # loads the LoRA adapter on top of the base model
import gc
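# NOTE: a sketch of the requirements.txt this script assumes for the Space
# (exact version pins are up to you; accelerate is needed because
# from_pretrained is called with device_map below):
#   gradio
#   torch
#   transformers
#   peft
#   accelerate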
# --- Configuration ---
# Base model ID (the one you fine-tuned FROM)
base_model_id = "Qwen/Qwen2-0.5B"
# Path WITHIN THE SPACE where you will upload your adapter files
# Create a folder named 'adapter' in your Space and upload the files there
adapter_path = "./adapter"
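# The adapter folder is expected to contain the files produced by PEFT's
# save_pretrained(): typically adapter_config.json plus adapter_model.safetensors
# (or adapter_model.bin on older PEFT versions)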
# Determine device (use GPU if available in the Space)
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
# --- Load Model and Tokenizer ---
print(f"Loading base model: {base_model_id}")
# Pick a dtype: bf16 on GPUs that support it, fp16 on other GPUs, fp32 on CPU
# (fp16 generation on CPU is unreliable, so fall back to full precision there;
# the bf16 check is also guarded so it only runs when CUDA is available)
if device == "cuda":
    dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
else:
    dtype = torch.float32
# Load the base model unquantized: merge_and_unload() below folds the adapter
# into the base weights, which is simplest when the base is not 4-bit quantized
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_id,
    quantization_config=None,  # no bitsandbytes quantization at inference time
    torch_dtype=dtype,
    device_map=device,  # place the whole model on one device instead of sharding with "auto"
    trust_remote_code=True
)
base_model.config.use_cache = True  # enable the KV cache for faster generation
print(f"Base model loaded to device: {device}")
# --- Load PEFT Adapter ---
print(f"Loading PEFT adapter from: {adapter_path}")
# Load the PEFT adapter on top of the base model
# (the base model is already on the correct device at this point)
model = PeftModel.from_pretrained(base_model, adapter_path)
print("Adapter loaded.")
# --- Merge Adapter ---
print("Merging adapter weights...")
model = model.merge_and_unload()  # returns a plain transformers model with the adapter folded in
print("Adapter merged.")  # the merged model stays on the device chosen above
# --- Load Tokenizer ---
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(base_model_id, trust_remote_code=True)
# Set a padding token if necessary (same logic as the training script)
if tokenizer.pad_token is None:
    if tokenizer.eos_token:
        tokenizer.pad_token = tokenizer.eos_token
        print(f"Set tokenizer pad_token to eos_token: {tokenizer.pad_token}")
    else:
        print("Warning: EOS token not found, cannot set pad_token automatically.")
tokenizer.padding_side = "left"  # left-pad so generated tokens directly follow the prompt in batched generation
print("Model and tokenizer loaded successfully.")
# --- Inference Function ---
def summarize_text(article_text):
    if not article_text:
        return "Please enter some text to summarize."
    # Format the prompt for the Qwen base model (same template as the training script)
    prompt = f"Summarize the following text:\n\n{article_text}\n\nSummary:"
    try:
        print("Tokenizing input...")
        # Truncate very long inputs; the 2048-token budget here is an assumption --
        # adjust it to whatever context length your Space's hardware can handle
        inputs = tokenizer(
            prompt,
            return_tensors="pt",
            return_attention_mask=True,
            truncation=True,
            max_length=2048,
        ).to(device)
        print("Generating summary...")
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=100,  # maximum length of the summary
                temperature=0.6,
                top_p=0.9,
                do_sample=True,
                pad_token_id=tokenizer.pad_token_id,
            )
        # Decode only the newly generated tokens (everything after the prompt)
        response_ids = outputs[0][inputs["input_ids"].shape[1]:]
        summary = tokenizer.decode(response_ids, skip_special_tokens=True).strip()
        print("Summary generated.")
        # Free memory after generation
        del inputs, outputs
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return summary
    except Exception as e:
        print(f"Error during inference: {e}")
        return f"An error occurred: {e}"
# --- Create Gradio Interface ---
print("Creating Gradio interface...")
iface = gr.Interface(
    fn=summarize_text,
    inputs=gr.Textbox(lines=10, placeholder="Paste the text you want to summarize here...", label="Article Text"),
    outputs=gr.Textbox(label="Generated Summary"),
    title="Qwen2-0.5B Base - Fine-tuned Summarizer (GRPO/QLoRA)",
    description="Enter text to get a summary generated by the fine-tuned Qwen2-0.5B base model.",
    examples=[
        ["SUBREDDIT: r/relationships TITLE: I (f/22) have to figure out if I want to still know these girls or not and would hate to sound insulting POST: Not sure if this belongs here but it's worth a try... (rest of example text from your logs)"]
        # Add more examples if you like
    ]
)
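# Optional: on slow (e.g. free CPU-tier) hardware, Gradio's built-in request queue
# helps long generations survive request timeouts, e.g. iface.queue().launch()
# instead of iface.launch() below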
# --- Launch the App ---
print("Launching Gradio app...")
iface.launch()