Spaces:
Runtime error
Runtime error
| import os | |
| import torch | |
| from fastapi import FastAPI, File, UploadFile | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline | |
| import traceback | |
| import re | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from huggingface_hub import login | |
| from pydantic import BaseModel, Field | |
| from typing import Optional | |
class GenerateRequest(BaseModel):
    # Request model consumed by generate_text.
    prompt: str  # user message placed inside the [INST] ... [/INST] template
    max_tokens: int = Field(default=1000)  # passed to generate() as max_new_tokens
    image: Optional[str] = Field(
        default=None,
        description="This field should be None. If an image is detected, the request will be rejected.",
    )
class SentimentRequest(BaseModel):
    # Request model consumed by analyze_sentiment: the text to classify.
    text: str = Field(...)
# Authenticate with the Hugging Face Hub when a token is supplied through the
# HF_TOKEN environment variable, so the model downloads below can use it.
HF_TOKEN = os.environ.get("HF_TOKEN")
if not HF_TOKEN:
    print("Warning: No HF_TOKEN found in environment variables")
else:
    login(token=HF_TOKEN)
# Runtime flags, presumably meant to keep Triton out of the torch/bitsandbytes
# stack, prefer the PyTorch backend and mute the bitsandbytes banner.
# NOTE(review): confirm which installed libraries actually honor each flag.
os.environ.update(
    {
        "TRITON_DISABLE": "1",
        "BNB_DISABLE_TRITON": "1",
        "USE_TORCH": "1",
        "BITSANDBYTES_NOWELCOME": "1",
    }
)

# Point the Hugging Face and torch caches at a scratch directory under /tmp.
# NOTE(review): transformers and huggingface_hub are imported at the top of the
# file, before these assignments, and likely resolve HF_HOME /
# TRANSFORMERS_CACHE when imported, so these two values may not redirect their
# downloads. Confirm, or move this block above the imports.
os.makedirs("/tmp/hf_cache", exist_ok=True)
os.environ.update(
    dict.fromkeys(("HF_HOME", "TRANSFORMERS_CACHE", "TORCH_HOME"), "/tmp/hf_cache")
)
# ASGI application with fully open CORS: every origin, method and header.
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True lets
# any website send credentialed cross-origin requests (Starlette may reflect
# the caller's Origin header in that case). Confirm this is intended.
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
    allow_credentials=True,
)
# Load models once at startup
print("Loading models and tokenizers...")

# Cache directory for downloaded weights, passed explicitly to every loader.
# transformers and huggingface_hub are imported at the top of this file, before
# HF_HOME is assigned, and resolve their default cache location from the
# environment when imported. Setting the variable later therefore does not
# redirect these downloads on its own.
model_cache_dir = os.environ.get("HF_HOME", "/tmp/hf_cache")

model_name = "mistralai/Mistral-7B-Instruct-v0.3"
tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=model_cache_dir)
# NOTE(review): float16 weights are requested even on CPU-only hardware, where
# half precision can be slow or unsupported for some ops. Confirm the target
# hardware has a GPU.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    cache_dir=model_cache_dir,
    low_cpu_mem_usage=True,
    device_map="auto",  # place weights across the available devices automatically
    torch_dtype=torch.float16,
)

device = 0 if torch.cuda.is_available() else -1  # first GPU if present, otherwise CPU
sentiment_analyzer = pipeline(
    "text-classification",
    model="nlptown/bert-base-multilingual-uncased-sentiment",
    # Score every label; analyze_sentiment reads the per-label list at result[0].
    return_all_scores=True,
    device=device,
    # Forwarded to the config/model/tokenizer loaders so they share the cache dir.
    model_kwargs={"cache_dir": model_cache_dir},
)
print("Models and tokenizers loaded successfully!")
async def generate_text(
    request: GenerateRequest = None,
    prompt: str = None,
    max_tokens: int = 1000,
    file: Optional[UploadFile] = None
):
    """Generate a reply to a single user prompt with the Mistral instruct model.

    Args:
        request: Optional request body. When given, its ``prompt`` and
            ``max_tokens`` take precedence over the keyword arguments. A
            request carrying a non-empty ``image`` is rejected, matching the
            contract stated in the ``GenerateRequest.image`` field description.
        prompt: Prompt used when ``request`` is not supplied.
        max_tokens: Upper bound on newly generated tokens when ``request`` is
            not supplied.
        file: Accepted but ignored; its contents are never read.

    Returns:
        ``{"response": text}`` on success. ``{"error": message}`` when no
        prompt is given or an image is attached. ``{"error": message,
        "traceback": trace}`` if tokenization or generation raises.
    """
    # Uploaded files are not supported by this endpoint; `file` is ignored.
    if request is not None:
        if request.image:
            return {"error": "Image input is not supported"}
        user_prompt = request.prompt
        tokens = request.max_tokens
    else:
        user_prompt = prompt
        tokens = max_tokens

    if not user_prompt:
        return {"error": "No prompt provided"}

    try:
        formatted_prompt = f"<s>[INST] {user_prompt} [/INST]"
        # The template already begins with the BOS marker "<s>". Disable the
        # tokenizer's own special tokens so BOS is not inserted twice.
        inputs = tokenizer(
            formatted_prompt, return_tensors="pt", add_special_tokens=False
        ).to(model.device)
        # NOTE: model.generate is synchronous, so it blocks the event loop for
        # the whole generation even though this function is async.
        outputs = model.generate(
            **inputs,
            max_new_tokens=tokens,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
        )
        # generate() returns the prompt tokens followed by the continuation.
        # Decode only the continuation so the user's message is never echoed.
        # Stripping a text prefix instead would also cut replies that
        # legitimately start with the prompt's words.
        prompt_length = inputs["input_ids"].shape[-1]
        raw_response = tokenizer.decode(
            outputs[0][prompt_length:], skip_special_tokens=True
        )
        # Remove any instruction/sentence tags the model emitted as plain text.
        clean_response = re.sub(r'</?s>|\[/?s\]|\[/?INST\]|\[/?INSR\]|\{/?INSST\}', '', raw_response).strip()
        return {"response": clean_response}
    except Exception as e:
        error_msg = str(e)
        error_trace = traceback.format_exc()
        print(f"Error generating text: {error_msg}")
        print(f"Traceback: {error_trace}")
        return {"error": error_msg, "traceback": error_trace}
async def analyze_sentiment(request: SentimentRequest):
    """Classify ``request.text`` with the 5-star model and name the sentiment.

    The highest-scoring star label is translated to one of ``very_negative``,
    ``negative``, ``neutral``, ``positive`` or ``very_positive``.

    Returns:
        ``{"sentiment": label, "raw_scores": {star_label: score, ...}}`` on
        success. Any exception is logged and returned as ``{"error": message,
        "traceback": trace}``. This includes a pipeline failure, an empty score
        list, or a star label missing from the mapping.
    """
    star_to_sentiment = {
        "1 star": "very_negative",
        "2 stars": "negative",
        "3 stars": "neutral",
        "4 stars": "positive",
        "5 stars": "very_positive",
    }
    try:
        predictions = sentiment_analyzer(request.text)
        # predictions[0] is a list of {"label": ..., "score": ...} entries.
        raw_scores = {entry["label"]: entry["score"] for entry in predictions[0]}
        print(f"Raw scores for '{request.text}': {raw_scores}")

        # On ties, max() keeps the first label in iteration order.
        top_label = max(raw_scores, key=raw_scores.get)
        sentiment = star_to_sentiment[top_label]
        print(f"Final sentiment: {sentiment}")
        return {"sentiment": sentiment, "raw_scores": raw_scores}
    except Exception as exc:
        message = str(exc)
        trace = traceback.format_exc()
        print(f"Error analyzing sentiment: {message}")
        print(f"Traceback: {trace}")
        return {"error": message, "traceback": trace}